Skip to content

Commit adac8a7

Browse files
authored
Make to_triton_code config arg optional (#291)
1 parent 37e8af3 commit adac8a7

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

helion/runtime/kernel.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None:
274274
super().__init__()
275275
self.kernel = kernel
276276
self._run: Callable[..., _R] | None = None
277+
self._config: Config | None = None
277278
self._compile_cache: dict[Config, CompiledConfig] = {}
278279
self.env = CompileEnvironment(_find_device(args), self.kernel.settings)
279280
with self.env:
@@ -338,7 +339,7 @@ def configs(self) -> list[Config]:
338339
"""
339340
return self.kernel.configs
340341

341-
def to_triton_code(self, config: ConfigLike) -> str:
342+
def to_triton_code(self, config: ConfigLike | None = None) -> str:
342343
"""
343344
Generate Triton code for the kernel based on the given configuration.
344345
@@ -348,6 +349,8 @@ def to_triton_code(self, config: ConfigLike) -> str:
348349
Returns:
349350
str: The generated Triton code as a string.
350351
"""
352+
if config is None:
353+
config = self._require_implicit_config()
351354
with self.env:
352355
if not isinstance(config, Config):
353356
config = Config(**config) # pyright: ignore[reportArgumentType]
@@ -356,7 +359,7 @@ def to_triton_code(self, config: ConfigLike) -> str:
356359
return get_needed_imports(root) + unparse(root)
357360

358361
def compile_config(
359-
self, config: ConfigLike, *, allow_print: bool = True
362+
self, config: ConfigLike | None = None, *, allow_print: bool = True
360363
) -> CompiledConfig:
361364
"""
362365
Compile the kernel for a specific configuration.
@@ -368,6 +371,8 @@ def compile_config(
368371
Returns:
369372
CompiledConfig: A callable object representing the compiled kernel.
370373
"""
374+
if config is None:
375+
config = self._require_implicit_config()
371376
if not isinstance(config, Config):
372377
config = Config(
373378
**config # pyright: ignore[reportArgumentType]
@@ -458,6 +463,7 @@ def set_config(self, config: ConfigLike) -> None:
458463
**config # pyright: ignore[reportArgumentType]
459464
)
460465
self._run = self.compile_config(config)
466+
self._config = config
461467

462468
def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
463469
"""
@@ -492,6 +498,27 @@ def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
492498
extractors.append(make_extractor(source))
493499
return extractors
494500

501+
def _implicit_config(self) -> Config | None:
502+
"""
503+
Returns a single config that is implicitly used by this kernel, if any.
504+
"""
505+
configs = self.kernel.configs
506+
if self._config is not None:
507+
return self._config
508+
if len(configs) == 1:
509+
return configs[0]
510+
if len(configs) == 0 and self.kernel.settings.use_default_config:
511+
return self.config_spec.default_config()
512+
return None
513+
514+
def _require_implicit_config(self) -> Config:
515+
"""
516+
Returns the implicit config for this kernel, or raises an error if no implicit config is available.
517+
"""
518+
if (config := self._implicit_config()) is None:
519+
raise RuntimeError("no config provided and no implicit config available")
520+
return config
521+
495522
def __call__(self, *args: object) -> _R:
496523
"""
497524
Execute the kernel with the given arguments.
@@ -503,8 +530,8 @@ def __call__(self, *args: object) -> _R:
503530
_R: The result of the kernel execution.
504531
"""
505532
if self._run is None:
506-
if not self.configs and self.settings.use_default_config:
507-
self.set_config(self.config_spec.default_config())
533+
if (config := self._implicit_config()) is not None:
534+
self.set_config(config)
508535
else:
509536
self.autotune(args)
510537
assert self._run is not None

test/test_misc.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,57 @@ def test_tile_block_size_usage(x: torch.Tensor) -> torch.Tensor:
236236
# The result should have 1s at positions that are last in their tile
237237
self.assertTrue(result.sum().item() > 0)
238238

239+
def test_to_triton_code_optional_config(self):
240+
"""Test that to_triton_code() works without explicit config argument."""
241+
242+
# Test 1: Kernel with single config - should use that config
243+
@helion.kernel(config={"block_sizes": [64]})
244+
def kernel_single_config(x: torch.Tensor) -> torch.Tensor:
245+
result = torch.empty_like(x)
246+
for tile in hl.tile(x.shape):
247+
result[tile] = x[tile] * 2
248+
return result
249+
250+
x = torch.randn([32], device=DEVICE)
251+
bound_kernel = kernel_single_config.bind((x,))
252+
253+
# Should work without config argument
254+
code_without_config = bound_kernel.to_triton_code()
255+
code_with_config = bound_kernel.to_triton_code({"block_sizes": [64]})
256+
self.assertEqual(code_without_config, code_with_config)
257+
258+
# Test 2: Kernel with use_default_config - should use default config
259+
@helion.kernel(use_default_config=True)
260+
def kernel_default_config(x: torch.Tensor) -> torch.Tensor:
261+
result = torch.empty_like(x)
262+
for tile in hl.tile(x.shape):
263+
result[tile] = x[tile] * 3
264+
return result
265+
266+
bound_kernel_default = kernel_default_config.bind((x,))
267+
268+
# Should work without config argument using default config
269+
code_default = bound_kernel_default.to_triton_code()
270+
self.assertIsInstance(code_default, str)
271+
self.assertIn("def", code_default) # Basic sanity check
272+
273+
# Test 3: Kernel with no configs and no default - should raise error
274+
@helion.kernel
275+
def kernel_no_config(x: torch.Tensor) -> torch.Tensor:
276+
result = torch.empty_like(x)
277+
for tile in hl.tile(x.shape):
278+
result[tile] = x[tile] * 4
279+
return result
280+
281+
bound_kernel_no_config = kernel_no_config.bind((x,))
282+
283+
# Should raise RuntimeError when no implicit config available
284+
with self.assertRaises(RuntimeError) as cm:
285+
bound_kernel_no_config.to_triton_code()
286+
self.assertIn(
287+
"no config provided and no implicit config available", str(cm.exception)
288+
)
289+
239290

240291
if __name__ == "__main__":
241292
unittest.main()

0 commit comments

Comments
 (0)