26
26
from typing import Callable
27
27
28
28
# Maps tritonbench op names to Helion kernel examples
29
- KERNEL_MAPPINGS : dict [str , tuple [str , str , str ]] = {
29
+ # Can map to a single kernel or a list of kernels
30
+ KERNEL_MAPPINGS : dict [str , tuple [str , str , str ] | list [tuple [str , str , str ]]] = {
30
31
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
31
32
# "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
32
33
# "embedding": (
75
76
"examples.fp8_gemm" ,
76
77
"fp8_gemm_tritonbench" ,
77
78
),
79
+ "gemm" : [
80
+ # List of gemm variants
81
+ (
82
+ "tritonbench.operators.gemm.operator" ,
83
+ "examples.matmul" ,
84
+ "matmul" ,
85
+ ),
86
+ (
87
+ "tritonbench.operators.gemm.operator" ,
88
+ "examples.matmul_split_k" ,
89
+ "matmul_split_k" ,
90
+ ),
91
+ ],
78
92
}
79
93
80
94
@@ -196,16 +210,36 @@ def check_and_setup_tritonbench() -> None:
196
210
197
211
198
212
def run_kernel (kernel_name : str , tritonbench_args : list [str ]) -> None :
199
- """Run a single kernel benchmark."""
213
+ """Run a kernel benchmark, handling both single and multiple variants ."""
200
214
# Check if kernel is in the mapping table
201
215
if kernel_name not in KERNEL_MAPPINGS :
202
216
print (f"Error: Unknown kernel '{ kernel_name } '" , file = sys .stderr )
203
217
print (
204
218
f"Available kernels: { ', ' .join (KERNEL_MAPPINGS .keys ())} " , file = sys .stderr
205
219
)
206
220
sys .exit (1 )
221
+
222
+ mapping = KERNEL_MAPPINGS [kernel_name ]
223
+
224
+ # Check if it's a list of variants or a single kernel
225
+ if isinstance (mapping , list ):
226
+ # Run each variant
227
+ for i , (tritonbench_module , module_path , func_name ) in enumerate (mapping ):
228
+ # Extract variant name from func_name for display
229
+ variant_name = func_name
230
+ if i > 0 :
231
+ print (f"\n { '=' * 60 } " , file = sys .stderr )
232
+ print (f"Kernel: { kernel_name } (variant: { variant_name } )" , file = sys .stderr )
233
+ print (f"{ '=' * 60 } \n " , file = sys .stderr )
234
+ run_single_kernel_variant (kernel_name , tritonbench_module , module_path , func_name , tritonbench_args .copy (), variant_name )
235
+ else :
236
+ # Single kernel
237
+ tritonbench_module , module_path , func_name = mapping
238
+ run_single_kernel_variant (kernel_name , tritonbench_module , module_path , func_name , tritonbench_args )
239
+
207
240
208
- tritonbench_module , module_path , func_name = KERNEL_MAPPINGS [kernel_name ]
241
+ def run_single_kernel_variant (kernel_name : str , tritonbench_module : str , module_path : str , func_name : str , tritonbench_args : list [str ], variant_name : str | None = None ) -> None :
242
+ """Run a single kernel variant."""
209
243
210
244
# Import from the mapped module
211
245
try :
@@ -305,7 +339,10 @@ def _inner() -> Callable[..., Any] | object:
305
339
return _inner
306
340
307
341
# Method name for the benchmark
308
- helion_method_name = f"helion_{ kernel_name } "
342
+ if variant_name :
343
+ helion_method_name = f"helion_{ kernel_name } _{ variant_name } "
344
+ else :
345
+ helion_method_name = f"helion_{ kernel_name } "
309
346
310
347
# Import register_benchmark API
311
348
from tritonbench .utils .triton_op import ( # pyright: ignore[reportMissingImports]
@@ -325,10 +362,16 @@ def _inner() -> Callable[..., Any] | object:
325
362
# Set the decorated method on the Operator class
326
363
setattr (Operator , helion_method_name , decorated_method )
327
364
328
- print (
329
- f"Running { operator_name } benchmark with Helion implementation...\n " ,
330
- file = sys .stderr ,
331
- )
365
+ if variant_name :
366
+ print (
367
+ f"Running { operator_name } benchmark with Helion implementation (variant: { variant_name } )...\n " ,
368
+ file = sys .stderr ,
369
+ )
370
+ else :
371
+ print (
372
+ f"Running { operator_name } benchmark with Helion implementation...\n " ,
373
+ file = sys .stderr ,
374
+ )
332
375
333
376
# Create and run the operator with unknown args
334
377
op = Operator (tb_args = tb_args , extra_args = unknown_args )
0 commit comments