Skip to content

Commit e1e59c1

Browse files
committed
add op, gen_instances and test
1 parent b765fe7 commit e1e59c1

File tree

3 files changed

+207
-0
lines changed

3 files changed

+207
-0
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import functools
2+
from .op import CKTileGemmOperation
3+
4+
5+
def _ck_tile_gemm_instances(pipeline, tile_shapes, schedulers):
    """
    Build every CKTileGemmOperation for a single pipeline.

    Args:
        pipeline: pipeline name stored on each instance (e.g. "CompV3").
        tile_shapes: iterable of (tile_m, tile_n, tile_k) tuples to enumerate.
        schedulers: iterable of scheduler names to enumerate.

    Returns:
        A list with one instance per combination of layout, datatype,
        tile shape, warp layout, warp tile, M/N/K padding flag, scheduler
        and epilogue. Earlier axes vary slowest, matching the order of the
        nested comprehension this helper replaces.
    """
    import itertools

    layouts = [
        ("Row", "Row", "Row"),
        ("Row", "Col", "Row"),
    ]
    datatypes = [("FP16",) * 3, ("BF16",) * 3]
    warp_layouts = [(2, 2, 1)]
    warp_tiles = [(32, 32, 16)]
    # Padding flags are kept as the strings "true"/"false", not Python bools.
    padding_flags = ["true", "false"]
    epilogues = ["Default", "CShuffle"]

    return [
        CKTileGemmOperation(
            layout_a=layout_a,
            layout_b=layout_b,
            layout_c=layout_c,
            datatype_a=datatype_a,
            datatype_b=datatype_b,
            datatype_c=datatype_c,
            tile_m=tile_m,
            tile_n=tile_n,
            tile_k=tile_k,
            warp_m=warp_m,
            warp_n=warp_n,
            warp_k=warp_k,
            warp_tile_m=warp_tile_m,
            warp_tile_n=warp_tile_n,
            warp_tile_k=warp_tile_k,
            m_is_padded=m_is_padded,
            n_is_padded=n_is_padded,
            k_is_padded=k_is_padded,
            pipeline=pipeline,
            scheduler=scheduler,
            epilogue=epilogue,
        )
        for (
            (layout_a, layout_b, layout_c),
            (datatype_a, datatype_b, datatype_c),
            (tile_m, tile_n, tile_k),
            (warp_m, warp_n, warp_k),
            (warp_tile_m, warp_tile_n, warp_tile_k),
            m_is_padded,
            n_is_padded,
            k_is_padded,
            scheduler,
            epilogue,
        ) in itertools.product(
            layouts,
            datatypes,
            tile_shapes,
            warp_layouts,
            warp_tiles,
            padding_flags,
            padding_flags,
            padding_flags,
            schedulers,
            epilogues,
        )
    ]


@functools.cache
def _cached_ops():
    """
    Build the full instance set once and keep it as an immutable tuple.
    """
    compute_v3_instances = _ck_tile_gemm_instances(
        pipeline="CompV3",
        tile_shapes=[(256, 256, 32), (256, 256, 64)],
        schedulers=["Intrawave"],
    )

    compute_v4_instances = _ck_tile_gemm_instances(
        pipeline="CompV4",
        # half the tile size since it has double buffering
        tile_shapes=[(256, 256, 32)],
        schedulers=["Intrawave"],
    )

    mem_instances = _ck_tile_gemm_instances(
        pipeline="Mem",
        tile_shapes=[(256, 256, 32), (256, 256, 64)],
        schedulers=["Intrawave", "Interwave"],
    )

    return (*compute_v3_instances, *compute_v4_instances, *mem_instances)


def ops():
    """
    Generate the supported instance dataclasses

    Returns:
        A new list on every call holding the CompV3, CompV4 and Mem
        instances, in that order. The enumeration itself runs only once;
        the list wrapper is fresh so that a caller appending to, sorting or
        filtering its result in place cannot alter what later callers see.
        The instance objects inside the list are still shared between calls.
    """
    return list(_cached_ops())
132+
133+
134+
if __name__ == "__main__":
    # Run as a script: print the name of each generated instance, one per line.
    for instance in ops():
        print(instance.name())
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from dataclasses import asdict, dataclass
2+
3+
4+
@dataclass
class CKTileGemmOperation:
    """
    Plain-data description of one CK-Tile universal GEMM instance.

    All fields are stored exactly as given; no validation is performed.
    The helper methods only format the fields into strings.
    """

    # Layout names for operands A, B and C. Only the first character of
    # each is used when building the instance name.
    layout_a: str
    layout_b: str
    layout_c: str

    # Datatype names for operands A, B and C.
    datatype_a: str
    datatype_b: str
    datatype_c: str

    # Tile extents along M, N and K.
    tile_m: int
    tile_n: int
    tile_k: int

    # Warp counts along M, N and K.
    warp_m: int
    warp_n: int
    warp_k: int

    # Per-warp tile extents along M, N and K.
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    # Padding flags, held as strings rather than bools.
    # NOTE(review): presumably "true"/"false" so they can be emitted
    # verbatim into generated source -- confirm against the consumer.
    m_is_padded: str
    n_is_padded: str
    k_is_padded: str

    # Names selecting the pipeline, scheduler and epilogue variants.
    pipeline: str
    scheduler: str
    epilogue: str

    def layout_repr(self) -> str:
        """Return the first letters of the A, B and C layouts, e.g. "RCR"."""
        return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}"

    def dtype_repr(self) -> str:
        """Return the A, B and C datatype names concatenated."""
        return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}"

    def tile_sizes(self) -> str:
        """
        Return tile, warp and warp-tile shapes, e.g. "256x256x32_2x2x1_32x32x16".

        The extents inside each shape are separated by "x". Without a
        separator different shapes render identically, for example
        (128, 12, 8) and (12, 812, 8) would both become "128128".
        """
        shapes = [
            (self.tile_m, self.tile_n, self.tile_k),
            (self.warp_m, self.warp_n, self.warp_k),
            (self.warp_tile_m, self.warp_tile_n, self.warp_tile_k),
        ]
        return "_".join("x".join(str(extent) for extent in shape) for shape in shapes)

    def padding_repr(self) -> str:
        """Return the M, N and K padding flags, e.g. "pad_true_false_false"."""
        return f"pad_{self.m_is_padded}_{self.n_is_padded}_{self.k_is_padded}"

    def name(self) -> str:
        """
        Return an identifier that is distinct for every distinct instance.

        Every field contributes to the name. The padding flags are included
        because instances that differ only in padding would otherwise share
        one name.
        """
        return "ck_tile_gemm_universal_" + "_".join(
            [
                self.layout_repr(),
                self.dtype_repr(),
                self.tile_sizes(),
                self.padding_repr(),
                f"{self.pipeline}",
                f"{self.scheduler}",
                f"{self.epilogue}",
            ]
        )

    def dict_items(self):
        """Return the (field name, value) pairs of this instance."""
        return asdict(self).items()

python/test/test_gen_instances.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from ck4inductor.batched_universal_gemm.gen_instances import (
1717
gen_ops_library as gen_batched_gemm_ops_library,
1818
)
19+
from ck4inductor.ck_tile_universal_gemm.gen_instances import (
20+
ops as gen_ck_tile_gemm_ops_library
21+
)
1922

2023
log = logging.getLogger(__name__)
2124

@@ -44,3 +47,9 @@ def test_gen_batched_gemm_instances(self):
4447

4548
log.debug("%d gemm instances from library" % len(instances))
4649
self.assertTrue(instances)
50+
51+
def test_gen_ck_tile_universal_gemm_instances(self):
    """Check that the CK-Tile universal GEMM generator yields instances."""
    instances = gen_ck_tile_gemm_ops_library()

    # Pass the count as a logging argument rather than pre-formatting with
    # "%", so the message is only built when DEBUG records are emitted.
    log.debug("%d ck-tile gemm instances from library", len(instances))
    self.assertTrue(instances)

0 commit comments

Comments
 (0)