Commit bcc4ad5

[tuner]: use property function from iree lowering config python binding (#662)
After landing iree-org/iree#19376, all of the tuner's helper functions for reading the lowering configuration can be removed; we can instead use the property functions exposed by the LoweringConfig Python bindings directly. This PR is part of the task in #453: use IREE bindings for compilation info (including lowering_config and translation_info).

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
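In practice, attribute lookups that the tuner previously wrapped in helper functions become plain property reads on the lowering config. A minimal sketch of the before/after mapping (the describe_config helper is hypothetical; config stands for a tuner Configuration from tuner/tuner/common.py, and the property names are the ones used in the diffs below):

    # Hypothetical helper illustrating the property-based access pattern.
    def describe_config(config):
        lowering_config = config.lowering_config
        intrinsic = lowering_config.mma_kind             # was get_intrinsic(config)
        (
            subgroup_m_count,
            subgroup_n_count,
        ) = lowering_config.subgroup_count_mn            # was get_subgroup_m_count / get_subgroup_n_count
        workgroup_sizes = lowering_config.workgroup_tile_sizes  # was get_workgroup_tile_sizes(config)
        reduction_sizes = lowering_config.reduction_tile_sizes  # was get_reduction_tile_sizes(config)
        print(intrinsic, subgroup_m_count, subgroup_n_count, workgroup_sizes, reduction_sizes)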
Parent: 217690e

6 files changed: +68 −156 lines

tuner/tuner/candidate_gen.py

Lines changed: 39 additions & 45 deletions
@@ -38,16 +38,19 @@
 
 tune_logger = logging.getLogger("tune")
 
-# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
+
 def apply_configuration(
     template: list[str],
     configuration: Configuration,
-    workgroup_sizes: list[int],
-    reduction_sizes: list[int],
 ) -> str:
-    intrinsic = get_intrinsic(configuration)
-    subgroup_m_count = get_subgroup_m_count(configuration)
-    subgroup_n_count = get_subgroup_n_count(configuration)
+    lowering_config = configuration.lowering_config
+    intrinsic = lowering_config.mma_kind
+    (
+        subgroup_m_count,
+        subgroup_n_count,
+    ) = lowering_config.subgroup_count_mn
+    workgroup_sizes = lowering_config.workgroup_tile_sizes
+    reduction_sizes = lowering_config.reduction_tile_sizes
     tune_logger.info(f"Applying: {configuration}")
     expr0 = re.compile(
         r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
     def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
     ) -> str:
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -167,8 +173,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_mmt_workgroup_sizes(configuration),
-            get_mmt_reduction_sizes(configuration),
         )
         embeddable = indent(
             self.get_transform_function_mmt(problem_size, f"match_op", configuration),
@@ -193,15 +197,12 @@ def get_transform_function_conv(
         filter = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{dynamic_batch_output_ty}>"
 
-        workgroup_sizes = ", ".join(
-            map(str, self.get_conv_workgroup_sizes(configuration))
-        )
-        reduction_sizes = ", ".join(
-            map(str, self.get_conv_reduction_sizes(configuration))
-        )
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -246,8 +247,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            self.get_conv_workgroup_sizes(configuration),
-            self.get_conv_reduction_sizes(configuration),
         )
         embeddable = indent(
             self.get_transform_function_conv(problem_size, f"match_op", configuration),
@@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
         functionName: str,
         configuration: Configuration,
     ) -> str:
-        workgroup_sizes = ", ".join(
-            map(str, get_batch_mmt_workgroup_sizes(configuration))
-        )
-        reduction_sizes = ", ".join(
-            map(str, get_batch_mmt_reduction_sizes(configuration))
-        )
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
         modified += apply_configuration(
             template,
             configuration,
-            get_batch_mmt_workgroup_sizes(configuration),
-            get_batch_mmt_reduction_sizes(configuration),
         )
 
         embeddable = indent(
@@ -345,8 +339,6 @@ def apply_params(
             apply_configuration(
                 template,
                 configuration,
-                get_contract_workgroup_sizes(configuration, self.tile_dims),
-                get_contract_reduction_sizes(configuration, self.tile_dims),
             ),
             "",
         )
@@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
         functionName: str,
         configuration: Configuration,
     ) -> str:
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -403,8 +398,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_batch_mmt_workgroup_sizes(configuration),
-            get_batch_mmt_reduction_sizes(configuration),
         )
 
         embeddable = indent(
@@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
         input1 = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{problem_size.res_type}>"
 
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -476,8 +472,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_contract_workgroup_sizes(configuration, self.tile_dims),
-            get_contract_reduction_sizes(configuration, self.tile_dims),
         )
 
         embeddable = indent(
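
Call sites simplify accordingly: the tile sizes are no longer threaded through as arguments, since apply_configuration now reads them from the configuration's lowering config. Sketched from the hunks above:

    # Before: per-dispatch size helpers computed at every call site.
    # modified += apply_configuration(
    #     template,
    #     configuration,
    #     get_mmt_workgroup_sizes(configuration),
    #     get_mmt_reduction_sizes(configuration),
    # )
    # After: the configuration itself carries the tile sizes.
    modified += apply_configuration(template, configuration)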

tuner/tuner/candidate_gen_test.py

Lines changed: 18 additions & 18 deletions
@@ -106,15 +106,15 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
         'gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
     ]
 
-    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
+    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16
 
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[464, 320, 0],
-        reduction=[0, 0, 16],
+        workgroup=[n, oh, ow, oc, fh, fw, 0],
+        reduction=[0, 0, 0, 0, 0, 0, ic],
         subgroup_m_count=1,
         subgroup_n_count=4,
     )
@@ -155,7 +155,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
         in modified
     )
-    assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified
+    assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified
     assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified
     assert (
         "gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>"
@@ -186,8 +186,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[480, 384, 0],
-        reduction=[0, 0, 32],
+        workgroup=[1, 480, 384, 0],
+        reduction=[0, 0, 0, 32],
         subgroup_m_count=1,
         subgroup_n_count=4,
     )
@@ -241,8 +241,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[416, 320, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 416, 320, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -299,8 +299,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[128, 64, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 128, 64, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -355,8 +355,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[128, 64, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 128, 64, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
         "%config = transform.param.constant #iree_codegen.compilation_info<"
         in embeddable
     )
-    assert "workgroup = [128, 64, 0]" in embeddable
-    assert "reduction = [0, 0, 128]" in embeddable
+    assert "workgroup = [1, 128, 64, 0]" in embeddable
+    assert "reduction = [0, 0, 0, 128]" in embeddable
     assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
 
@@ -435,8 +435,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[128, 64, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 128, 64, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
         "%config = transform.param.constant #iree_codegen.compilation_info<"
         in embeddable
     )
-    assert "workgroup = [128, 64, 0]" in embeddable
-    assert "reduction = [0, 0, 128]" in embeddable
+    assert "workgroup = [1, 128, 64, 0]" in embeddable
+    assert "reduction = [0, 0, 0, 128]" in embeddable
     assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
 

tuner/tuner/common.py

Lines changed: 0 additions & 34 deletions
@@ -119,40 +119,6 @@ class Configuration:
     waves_per_eu: int
 
 
-def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
-    if "mma_kind" in config.lowering_config.attributes:
-        return config.lowering_config.attributes["mma_kind"]
-    return None
-
-
-def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
-    if "workgroup" in config.lowering_config.attributes:
-        workgroup_attrs = config.lowering_config.attributes["workgroup"]
-        return [attr.value for attr in workgroup_attrs]
-    return []
-
-
-def get_reduction_tile_sizes(config: Configuration) -> list[int]:
-    if "reduction" in config.lowering_config.attributes:
-        reduction_attrs = config.lowering_config.attributes["reduction"]
-        return [attr.value for attr in reduction_attrs]
-    return []
-
-
-def get_subgroup_m_count(config: Configuration) -> Optional[int]:
-    if "subgroup_m_count" in config.lowering_config.attributes:
-        attr = config.lowering_config.attributes["subgroup_m_count"]
-        return attr.value
-    return None
-
-
-def get_subgroup_n_count(config: Configuration) -> Optional[int]:
-    if "subgroup_n_count" in config.lowering_config.attributes:
-        attr = config.lowering_config.attributes["subgroup_n_count"]
-        return attr.value
-    return None
-
-
 def get_lowering_config(
     tuner_ctx: TunerContext,
     **kwargs: Any,
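
Each deleted helper probed config.lowering_config.attributes by key and fell back to None or []; the bindings now surface the same data as properties. A rough mapping, written out here for illustration (the None-on-absence behavior of mma_kind matches the updated test in common_test.py below):

    # Removed helper                    ->  binding property
    # get_intrinsic(config)             ->  config.lowering_config.mma_kind
    # get_workgroup_tile_sizes(config)  ->  config.lowering_config.workgroup_tile_sizes
    # get_reduction_tile_sizes(config)  ->  config.lowering_config.reduction_tile_sizes
    # get_subgroup_m_count(config),
    # get_subgroup_n_count(config)      ->  config.lowering_config.subgroup_count_mn  # an (m, n) pair
    intrinsic = config.lowering_config.mma_kind  # None when no MMA intrinsic is set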

tuner/tuner/common_test.py

Lines changed: 2 additions & 3 deletions
@@ -215,6 +215,5 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
         waves_per_eu=2,
     )
 
-    assert common.get_intrinsic(config) is None
-    assert common.get_subgroup_m_count(config) == 1
-    assert common.get_subgroup_n_count(config) == 1
+    assert config.lowering_config.mma_kind is None
+    assert config.lowering_config.subgroup_count_mn == (1, 1)

tuner/tuner/dispatch_parser.py

Lines changed: 2 additions & 34 deletions
@@ -20,18 +20,10 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
     return ShapedType(shaped_ty.shape, shaped_ty.element_type)
 
 
-def get_mmt_workgroup_sizes(configuration: Configuration):
-    return get_workgroup_tile_sizes(configuration)
-
-
-def get_mmt_reduction_sizes(configuration: Configuration):
-    return get_reduction_tile_sizes(configuration)
-
-
 def get_contract_workgroup_sizes(
     configuration: Configuration, tile_dims: str
 ) -> list[int]:
-    m, n, _k = get_workgroup_tile_sizes(configuration)
+    m, n, _k = configuration.lowering_config.workgroup_tile_sizes
 
     workgroup_size = [1] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
@@ -48,7 +40,7 @@ def get_contract_workgroup_sizes(
 def get_contract_reduction_sizes(
     configuration: Configuration, tile_dims: str
 ) -> list[int]:
-    _m, _n, k = get_reduction_tile_sizes(configuration)
+    _m, _n, k = configuration.lowering_config.reduction_tile_sizes
     reduction_size = [0] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
         if dim == "k":
@@ -57,14 +49,6 @@
     return reduction_size
 
 
-def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
-    return [1] + get_workgroup_tile_sizes(configuration)
-
-
-def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
-    return [0] + get_reduction_tile_sizes(configuration)
-
-
 class MlirRegex(Enum):
     ssa_value = r"%[a-zA-Z0-9-_]+"
     tensor_type = r"tensor<([^>]+)>"
@@ -164,22 +148,6 @@ class ConvParser(DispatchParser):
     def supports(self, op_name: str) -> bool:
         return "conv_2d_nhwc_hwcf" in op_name
 
-    def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
-        batch = 1
-        fh = 1
-        fw = 1
-
-        oh = 1
-
-        ow, oc, _ic = get_workgroup_tile_sizes(configuration)
-
-        return [batch, oh, ow, oc, fh, fw, 0]
-
-    def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
-        _ow, _oc, ic = get_reduction_tile_sizes(configuration)
-
-        return [0, 0, 0, 0, 0, 0, ic]
-
     def get_shapes(self, template: list[str]) -> ProblemSize:
         for line in template:
             if "linalg.conv_2d_nhwc_hwcf" not in line:
