38
38
39
39
tune_logger = logging .getLogger ("tune" )
40
40
41
- # TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
41
+
42
42
def apply_configuration (
43
43
template : list [str ],
44
44
configuration : Configuration ,
45
- workgroup_sizes : list [int ],
46
- reduction_sizes : list [int ],
47
45
) -> str :
48
- intrinsic = get_intrinsic (configuration )
49
- subgroup_m_count = get_subgroup_m_count (configuration )
50
- subgroup_n_count = get_subgroup_n_count (configuration )
46
+ lowering_config = configuration .lowering_config
47
+ intrinsic = lowering_config .mma_kind
48
+ (
49
+ subgroup_m_count ,
50
+ subgroup_n_count ,
51
+ ) = lowering_config .subgroup_count_mn
52
+ workgroup_sizes = lowering_config .workgroup_tile_sizes
53
+ reduction_sizes = lowering_config .reduction_tile_sizes
51
54
tune_logger .info (f"Applying: { configuration } " )
52
55
expr0 = re .compile (
53
56
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
125
128
def get_transform_function_mmt (
126
129
self , problem_size : ProblemSize , functionName : str , configuration : Configuration
127
130
) -> str :
128
- intrinsic = get_intrinsic (configuration )
129
- subgroup_m_count = get_subgroup_m_count (configuration )
130
- subgroup_n_count = get_subgroup_n_count (configuration )
131
+ lowering_config = configuration .lowering_config
132
+ intrinsic = lowering_config .mma_kind
133
+ (
134
+ subgroup_m_count ,
135
+ subgroup_n_count ,
136
+ ) = lowering_config .subgroup_count_mn
131
137
132
138
wg_x , wg_y , wg_z = configuration .workgroup_size
133
139
extra_config = get_pipeline_config (configuration )
@@ -167,8 +173,6 @@ def apply_params(
167
173
modified += apply_configuration (
168
174
template ,
169
175
configuration ,
170
- get_mmt_workgroup_sizes (configuration ),
171
- get_mmt_reduction_sizes (configuration ),
172
176
)
173
177
embeddable = indent (
174
178
self .get_transform_function_mmt (problem_size , f"match_op" , configuration ),
@@ -193,15 +197,12 @@ def get_transform_function_conv(
193
197
filter = f"tensor<{ problem_size .rhs_type } >"
194
198
output = f"tensor<{ dynamic_batch_output_ty } >"
195
199
196
- workgroup_sizes = ", " .join (
197
- map (str , self .get_conv_workgroup_sizes (configuration ))
198
- )
199
- reduction_sizes = ", " .join (
200
- map (str , self .get_conv_reduction_sizes (configuration ))
201
- )
202
- intrinsic = get_intrinsic (configuration )
203
- subgroup_m_count = get_subgroup_m_count (configuration )
204
- subgroup_n_count = get_subgroup_n_count (configuration )
200
+ lowering_config = configuration .lowering_config
201
+ intrinsic = lowering_config .mma_kind
202
+ (
203
+ subgroup_m_count ,
204
+ subgroup_n_count ,
205
+ ) = lowering_config .subgroup_count_mn
205
206
206
207
wg_x , wg_y , wg_z = configuration .workgroup_size
207
208
extra_config = get_pipeline_config (configuration )
@@ -246,8 +247,6 @@ def apply_params(
246
247
modified += apply_configuration (
247
248
template ,
248
249
configuration ,
249
- self .get_conv_workgroup_sizes (configuration ),
250
- self .get_conv_reduction_sizes (configuration ),
251
250
)
252
251
embeddable = indent (
253
252
self .get_transform_function_conv (problem_size , f"match_op" , configuration ),
@@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
263
262
functionName : str ,
264
263
configuration : Configuration ,
265
264
) -> str :
266
- workgroup_sizes = ", " .join (
267
- map (str , get_batch_mmt_workgroup_sizes (configuration ))
268
- )
269
- reduction_sizes = ", " .join (
270
- map (str , get_batch_mmt_reduction_sizes (configuration ))
271
- )
272
- intrinsic = get_intrinsic (configuration )
273
- subgroup_m_count = get_subgroup_m_count (configuration )
274
- subgroup_n_count = get_subgroup_n_count (configuration )
265
+ lowering_config = configuration .lowering_config
266
+ intrinsic = lowering_config .mma_kind
267
+ (
268
+ subgroup_m_count ,
269
+ subgroup_n_count ,
270
+ ) = lowering_config .subgroup_count_mn
275
271
276
272
wg_x , wg_y , wg_z = configuration .workgroup_size
277
273
extra_config = get_pipeline_config (configuration )
@@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
316
312
modified += apply_configuration (
317
313
template ,
318
314
configuration ,
319
- get_batch_mmt_workgroup_sizes (configuration ),
320
- get_batch_mmt_reduction_sizes (configuration ),
321
315
)
322
316
323
317
embeddable = indent (
@@ -345,8 +339,6 @@ def apply_params(
345
339
apply_configuration (
346
340
template ,
347
341
configuration ,
348
- get_contract_workgroup_sizes (configuration , self .tile_dims ),
349
- get_contract_reduction_sizes (configuration , self .tile_dims ),
350
342
),
351
343
"" ,
352
344
)
@@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
359
351
functionName : str ,
360
352
configuration : Configuration ,
361
353
) -> str :
362
- intrinsic = get_intrinsic (configuration )
363
- subgroup_m_count = get_subgroup_m_count (configuration )
364
- subgroup_n_count = get_subgroup_n_count (configuration )
354
+ lowering_config = configuration .lowering_config
355
+ intrinsic = lowering_config .mma_kind
356
+ (
357
+ subgroup_m_count ,
358
+ subgroup_n_count ,
359
+ ) = lowering_config .subgroup_count_mn
365
360
366
361
wg_x , wg_y , wg_z = configuration .workgroup_size
367
362
extra_config = get_pipeline_config (configuration )
@@ -403,8 +398,6 @@ def apply_params(
403
398
modified += apply_configuration (
404
399
template ,
405
400
configuration ,
406
- get_batch_mmt_workgroup_sizes (configuration ),
407
- get_batch_mmt_reduction_sizes (configuration ),
408
401
)
409
402
410
403
embeddable = indent (
@@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
428
421
input1 = f"tensor<{ problem_size .rhs_type } >"
429
422
output = f"tensor<{ problem_size .res_type } >"
430
423
431
- intrinsic = get_intrinsic (configuration )
432
- subgroup_m_count = get_subgroup_m_count (configuration )
433
- subgroup_n_count = get_subgroup_n_count (configuration )
424
+ lowering_config = configuration .lowering_config
425
+ intrinsic = lowering_config .mma_kind
426
+ (
427
+ subgroup_m_count ,
428
+ subgroup_n_count ,
429
+ ) = lowering_config .subgroup_count_mn
434
430
435
431
wg_x , wg_y , wg_z = configuration .workgroup_size
436
432
extra_config = get_pipeline_config (configuration )
@@ -476,8 +472,6 @@ def apply_params(
476
472
modified += apply_configuration (
477
473
template ,
478
474
configuration ,
479
- get_contract_workgroup_sizes (configuration , self .tile_dims ),
480
- get_contract_reduction_sizes (configuration , self .tile_dims ),
481
475
)
482
476
483
477
embeddable = indent (
0 commit comments