@@ -274,6 +274,7 @@ def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None:
274
274
super ().__init__ ()
275
275
self .kernel = kernel
276
276
self ._run : Callable [..., _R ] | None = None
277
+ self ._config : Config | None = None
277
278
self ._compile_cache : dict [Config , CompiledConfig ] = {}
278
279
self .env = CompileEnvironment (_find_device (args ), self .kernel .settings )
279
280
with self .env :
@@ -338,7 +339,7 @@ def configs(self) -> list[Config]:
338
339
"""
339
340
return self .kernel .configs
340
341
341
- def to_triton_code (self , config : ConfigLike ) -> str :
342
+ def to_triton_code (self , config : ConfigLike | None = None ) -> str :
342
343
"""
343
344
Generate Triton code for the kernel based on the given configuration.
344
345
@@ -348,6 +349,8 @@ def to_triton_code(self, config: ConfigLike) -> str:
348
349
Returns:
349
350
str: The generated Triton code as a string.
350
351
"""
352
+ if config is None :
353
+ config = self ._require_implicit_config ()
351
354
with self .env :
352
355
if not isinstance (config , Config ):
353
356
config = Config (** config ) # pyright: ignore[reportArgumentType]
@@ -356,7 +359,7 @@ def to_triton_code(self, config: ConfigLike) -> str:
356
359
return get_needed_imports (root ) + unparse (root )
357
360
358
361
def compile_config (
359
- self , config : ConfigLike , * , allow_print : bool = True
362
+ self , config : ConfigLike | None = None , * , allow_print : bool = True
360
363
) -> CompiledConfig :
361
364
"""
362
365
Compile the kernel for a specific configuration.
@@ -368,6 +371,8 @@ def compile_config(
368
371
Returns:
369
372
CompiledConfig: A callable object representing the compiled kernel.
370
373
"""
374
+ if config is None :
375
+ config = self ._require_implicit_config ()
371
376
if not isinstance (config , Config ):
372
377
config = Config (
373
378
** config # pyright: ignore[reportArgumentType]
@@ -458,6 +463,7 @@ def set_config(self, config: ConfigLike) -> None:
458
463
** config # pyright: ignore[reportArgumentType]
459
464
)
460
465
self ._run = self .compile_config (config )
466
+ self ._config = config
461
467
462
468
def _specialize_extra (self ) -> list [Callable [[Sequence [object ]], Hashable ]]:
463
469
"""
@@ -492,6 +498,27 @@ def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
492
498
extractors .append (make_extractor (source ))
493
499
return extractors
494
500
501
+ def _implicit_config (self ) -> Config | None :
502
+ """
503
+ Returns a single config that is implicitly used by this kernel, if any.
504
+ """
505
+ configs = self .kernel .configs
506
+ if self ._config is not None :
507
+ return self ._config
508
+ if len (configs ) == 1 :
509
+ return configs [0 ]
510
+ if len (configs ) == 0 and self .kernel .settings .use_default_config :
511
+ return self .config_spec .default_config ()
512
+ return None
513
+
514
+ def _require_implicit_config (self ) -> Config :
515
+ """
516
+ Returns the implicit config for this kernel, or raises an error if no implicit config is available.
517
+ """
518
+ if (config := self ._implicit_config ()) is None :
519
+ raise RuntimeError ("no config provided and no implicit config available" )
520
+ return config
521
+
495
522
def __call__ (self , * args : object ) -> _R :
496
523
"""
497
524
Execute the kernel with the given arguments.
@@ -503,8 +530,8 @@ def __call__(self, *args: object) -> _R:
503
530
_R: The result of the kernel execution.
504
531
"""
505
532
if self ._run is None :
506
- if not self .configs and self . settings . use_default_config :
507
- self .set_config (self . config_spec . default_config () )
533
+ if ( config := self ._implicit_config ()) is not None :
534
+ self .set_config (config )
508
535
else :
509
536
self .autotune (args )
510
537
assert self ._run is not None
0 commit comments