Commit b9e93c0

Add hl.register_tunable (#154)
1 parent 59bf929 commit b9e93c0

File tree

11 files changed: +459 -9 lines changed

examples/matmul_split_k.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
from __future__ import annotations

import torch

import helion
from helion.autotuner import PowerOfTwoFragment
import helion.language as hl


# static_shapes=True gives a performance boost for matmuls
@helion.kernel(static_shapes=True)
def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.zeros(
        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
    )
    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(1, 256))
    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
    for tile_m, tile_n, outer_k in hl.tile([m, n, k], block_size=[None, None, k_block]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for inner_k in hl.tile(outer_k.begin, outer_k.end):
            acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
        hl.atomic_add(out, [tile_m, tile_n], acc)
    return out


def check(m: int, k: int, n: int) -> None:
    from triton.testing import do_bench

    x = torch.randn([m, k], device="cuda", dtype=torch.float16)
    y = torch.randn([k, n], device="cuda", dtype=torch.float16)
    result = matmul_split_k(x, y)
    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1)
    sec = do_bench(lambda: matmul_split_k(x, y))
    baseline_sec = do_bench(lambda: torch.matmul(x, y))
    print(
        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}ms, speedup: {baseline_sec / sec:.2f}x"
    )


def main() -> None:
    check(64, 32768, 64)


if __name__ == "__main__":
    main()
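
Because split_k is registered as an ordinary config key, a known-good value can be pinned instead of re-autotuned. A minimal sketch, assuming helion.Config forwards user-defined tunable keys as keyword arguments; the block_sizes and split_k values are illustrative, not tuned:

# Hypothetical pinned-config variant of the kernel above.
@helion.kernel(config=helion.Config(block_sizes=[64, 64, 4096], split_k=16))
def matmul_split_k_pinned(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    ...  # same body as matmul_split_k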

helion/_compiler/compile_environment.py

Lines changed: 7 additions & 1 deletion
@@ -189,7 +189,13 @@ def to_fake(self, obj: object, origin: Origin) -> object:
             return self.shape_env.create_unbacked_symfloat()
         if isinstance(
             obj,
-            (torch.dtype, torch.device, types.BuiltinFunctionType, types.ModuleType),
+            (
+                torch.dtype,
+                torch.device,
+                types.BuiltinFunctionType,
+                types.ModuleType,
+                type,
+            ),
         ):
             return obj
         if isinstance(obj, types.FunctionType):

helion/_compiler/type_propagation.py

Lines changed: 47 additions & 7 deletions
@@ -12,6 +12,7 @@
 from typing import NoReturn
 from typing import Protocol
 from typing import TypeVar
+from typing import cast
 from unittest.mock import patch

 import sympy
@@ -21,6 +22,7 @@
 from torch.utils._pytree import tree_map_only

 from .. import exc
+from ..autotuner.config_fragment import ConfigSpecFragment
 from ..autotuner.config_spec import BlockSizeSpec
 from ..language._decorators import get_device_func_replacement
 from ..language._decorators import is_api_func
@@ -249,6 +251,8 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
                 )
             ),
         )
+        if isinstance(value, ConfigSpecFragment):
+            return ConfigFragmentType(origin, value)
         if dataclasses.is_dataclass(value):
            keys = value.__dataclass_fields__.keys()  # pyre-ignore[16]
            return ClassType(
@@ -695,6 +699,16 @@ def as_literal(self) -> object:
         return self.value


+class ConfigFragmentType(LiteralType):
+    """TypeInfo for config fragments, which are treated as constant literals during compilation."""
+
+    value: ConfigSpecFragment
+
+    def __init__(self, origin: Origin, fragment: ConfigSpecFragment) -> None:
+        assert isinstance(fragment, ConfigSpecFragment)
+        super().__init__(origin, fragment)
+
+
 class CallableType(LiteralType):
     value: Callable[..., object]

@@ -745,6 +759,19 @@ def to_proxy(arg: TypeInfo) -> object:
         env: CompileEnvironment = CompileEnvironment.current()
         proxy_args = [x.tree_map(to_proxy) for x in args]
         proxy_kwargs = {k: v.tree_map(to_proxy) for k, v in kwargs.items()}
+
+        # special handling for symint arguments
+        if any(
+            (isinstance(x, torch.SymInt) and not isinstance(x._sympy_(), sympy.Integer))
+            for x in proxy_args
+        ):
+            if self.value in self._new_symint_on_host_fns() and origin.is_host():
+                return SymIntType.new_unbacked(origin)
+            if isinstance(self.value, type) and issubclass(
+                self.value, ConfigFragmentType
+            ):
+                raise exc.ConfigSpecFragmentWithSymInt(args)
+
         try:
             with patch.object(torch.SymInt, "__index__", _raise_shape_specializing):
                 output_type = TypeInfo.from_example(
@@ -782,6 +809,15 @@ def to_proxy(arg: TypeInfo) -> object:
             # TODO(jansel): point to other tracing modes
             raise exc.TorchOpTracingError(e) from e

+    @staticmethod
+    @functools.cache
+    def _new_symint_on_host_fns() -> dict[object, None]:
+        """Functions that should return a new unbacked symint when called on host with a symint argument."""
+        from triton import cdiv
+        from triton import next_power_of_2
+
+        return cast("dict[object, None]", dict.fromkeys([cdiv, next_power_of_2]))
+

@@ -890,12 +926,10 @@ class SymIntType(NumericType):

     @classmethod
     def new_unbacked(cls, origin: Origin) -> Self:
-        shape_env = CompileEnvironment.current().shape_env
-        with shape_env.ignore_fresh_unbacked_symbols():
-            return cls(
-                origin,
-                shape_env.create_unbacked_symint(),
-            )
+        return cls(
+            origin,
+            CompileEnvironment.current().create_unbacked_symint(),
+        )

     @property
     def python_type(self) -> type[int]:
@@ -953,7 +987,13 @@ def _get_hint(numel: int | torch.SymInt | AutoSize | None) -> int:
     if numel is None or isinstance(numel, AutoSize):
         # For data-dependent sizes, use arbitrary hint of 8192
         return 8192
-    return CompileEnvironment.current().size_hint(numel)
+
+    hint = CompileEnvironment.current().size_hint(numel)
+    # If the hint is invalid (like 0), use a reasonable default
+    # This can happen when other hints cancel out in expressions
+    if hint <= 1:
+        return 8192
+    return hint


 class TileIndexType(TypeInfo):
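
The _new_symint_on_host_fns() table above is what makes the split-k example work: when triton.cdiv or triton.next_power_of_2 is called on the host with a non-constant SymInt, type propagation hands back a fresh unbacked symint instead of specializing. A sketch of the pattern this enables, mirroring the example file (helion re-exporting these triton helpers is assumed from its usage there):

# split_k is an unbacked symint supplied by the autotuner; on the host,
# cdiv/next_power_of_2 over it stay symbolic rather than specializing.
split_k = hl.register_tunable("split_k", PowerOfTwoFragment(1, 256))
k_block = helion.next_power_of_2(helion.cdiv(k, split_k))  # still a symint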

helion/autotuner/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
 from __future__ import annotations

+from .config_fragment import BooleanFragment as BooleanFragment
+from .config_fragment import EnumFragment as EnumFragment
+from .config_fragment import IntegerFragment as IntegerFragment
+from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
 from .config_spec import ConfigSpec as ConfigSpec
 from .differential_evolution import (
     DifferentialEvolutionSearch as DifferentialEvolutionSearch,

helion/autotuner/config_fragment.py

Lines changed: 7 additions & 0 deletions
@@ -63,6 +63,13 @@ class BaseIntegerFragment(ConfigSpecFragment):
     high: int  # maximum value (inclusive)
     default_val: int

+    def __init__(self, low: int, high: int, default_val: int | None = None) -> None:
+        self.low = low
+        self.high = high
+        if default_val is None:
+            default_val = low
+        self.default_val = default_val
+
     def default(self) -> int:
         return self.clamp(self.default_val)
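
The new constructor makes default_val optional, falling back to low; default() still clamps into [low, high]. A quick behavioral sketch, assuming IntegerFragment and PowerOfTwoFragment inherit this constructor from BaseIntegerFragment:

from helion.autotuner import IntegerFragment, PowerOfTwoFragment

frag = PowerOfTwoFragment(1, 256)  # default_val omitted -> falls back to low
assert frag.default() == 1
frag2 = IntegerFragment(1, 10, 4)  # explicit default, clamped into [1, 10]
assert frag2.default() == 4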

helion/autotuner/config_spec.py

Lines changed: 14 additions & 1 deletion
@@ -61,6 +61,9 @@ class ConfigSpec:
     reduction_loops: BlockIdSequence[ReductionLoopSpec] = dataclasses.field(
         default_factory=BlockIdSequence
     )
+    user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
+        default_factory=dict
+    )
     allow_use_yz_grid: bool | None = None

     def _remove_duplicates(self) -> None:
@@ -110,7 +113,10 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             config.setdefault("use_yz_grid", False)

         config.setdefault("indexing", "pointer")
-        if invalid_keys := ({*config} - VALID_KEYS):
+
+        # Allow tunable parameter keys in addition to VALID_KEYS
+        allowed_keys = VALID_KEYS | {*self.user_defined_tunables.keys()}
+        if invalid_keys := ({*config} - allowed_keys):
             raise InvalidConfig(f"Invalid config keys {sorted(invalid_keys)!r}")

     def default_config(self) -> helion.Config:
@@ -134,6 +140,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
                 )
             ),
         }
+        # Add tunable parameters
+        for key, fragment in self.user_defined_tunables.items():
+            config[key] = fn(fragment)
+
         if self.allow_use_yz_grid:
             use_yz_grid = fn(BooleanFragment())
             # pyre-ignore[16]
@@ -191,6 +201,7 @@ def __init__(
         self.max_size: int = (
             next_power_of_2(size_hint) if max_size is None else max_size
         )
+        assert self.min_size <= self.max_size

     def __repr__(self) -> str:
         fields = []
@@ -207,6 +218,8 @@ def __repr__(self) -> str:

     def update_min(self, value: int) -> None:
         self.min_size = assert_integer_power_of_two(max(value, self.min_size))
+        if self.max_size < self.min_size:
+            self.max_size = self.min_size

     def update_max(self, value: int) -> None:
         self.max_size = assert_integer_power_of_two(min(value, self.max_size))
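
With user_defined_tunables in the spec, flat_config() emits each registered fragment's value under its own key and normalize() accepts that key alongside VALID_KEYS. A plausible tuned config for the split-k example might therefore look like this (values illustrative):

# "split_k" passes normalize() only because register_tunable() put it in
# user_defined_tunables; any other unknown key still raises InvalidConfig.
helion.Config(block_sizes=[64, 64, 4096], indexing="pointer", split_k=16)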

helion/exc.py

Lines changed: 18 additions & 0 deletions
@@ -140,6 +140,24 @@ class FailedToUnpackTupleAssign(BaseError):
     message = "Failed to unpack values in tuple assignment. Expected a sequence of size {0}, got type: {1!s}."


+class RegisterTunableArgTypes(BaseError):
+    message = "expected string literal and ConfigSpecFragment literal, got {0} and {1}"
+
+
+class TunableTypeNotSupported(BaseError):
+    message = "hl.register_tunable() only supports integer, float, and boolean types, got {0!s}."
+
+
+class TunableNameConflict(BaseError):
+    message = (
+        "Tunable parameter with name {0!s} already exists. Please use a different name."
+    )
+
+
+class ConfigSpecFragmentWithSymInt(BaseError):
+    message = "ConfigSpecFragment with SymInt arg is not supported. hl.constexpr or hl.specialize may be used to specialize the SymInt value."
+
+
 class FailedToUnpackTile(BaseError):
     message = (
         "Failed to unpack a tile into a tuple assignment. "

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,4 +17,5 @@
 from .tiles import tile_block_size as tile_block_size
 from .tiles import tile_end as tile_end
 from .tiles import tile_index as tile_index
+from .tunable_ops import register_tunable as register_tunable
 from .view_ops import subscript as subscript

helion/language/tunable_ops.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from torch._inductor.codegen.simd import constant_repr

from .. import exc
from .._compiler.ast_extension import expr_from_string
from ..autotuner.config_fragment import ConfigSpecFragment
from ..autotuner.config_spec import VALID_KEYS
from ..exc import NotInsideKernel
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState
    from .._compiler.type_propagation import TypeInfo
    from .._compiler.variable_origin import Origin

__all__ = ["register_tunable"]


@_decorators.api(is_device_only=False)
def register_tunable(name: str, fragment: ConfigSpecFragment) -> int:
    """
    Register a tunable parameter for autotuning.

    This function allows you to define parameters that can be automatically tuned
    during the autotuning process. The fragment defines the search space and default value.

    :param name: The key for the tunable parameter in the Config().
    :param fragment: A ConfigSpecFragment that defines the search space (e.g., PowerOfTwoFragment)
    :return: The value assigned to this tunable parameter in the current configuration.
    """
    raise NotInsideKernel


@_decorators.type_propagation(register_tunable)
def _register_tunable_type(
    name: TypeInfo, fragment: TypeInfo, *, origin: Origin
) -> TypeInfo:
    # During type propagation, register the tunable parameter and return unbacked symint
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.type_propagation import NumericType

    env = CompileEnvironment.current()

    try:
        fragment_val = fragment.as_literal()
        name_val = name.as_literal()
    except NotImplementedError:
        fragment_val = None
        name_val = None
    if not (isinstance(name_val, str) and isinstance(fragment_val, ConfigSpecFragment)):
        raise exc.RegisterTunableArgTypes(name, fragment)
    del name, fragment

    if name_val in VALID_KEYS or f"{name_val}s" in VALID_KEYS:
        raise exc.TunableNameConflict(name_val)
    if (
        name_val in env.config_spec.user_defined_tunables
        and env.config_spec.user_defined_tunables[name_val] != fragment_val
    ):
        raise exc.TunableNameConflict(name_val)

    # register the value for tuning
    env.config_spec.user_defined_tunables[name_val] = fragment_val

    python_type = type(fragment_val.default())
    if not issubclass(python_type, (int, float, bool)):
        raise exc.TunableTypeNotSupported(python_type)
    return NumericType.subtype(python_type).new_unbacked(origin)


@_decorators.codegen(register_tunable)
def _register_tunable_codegen(state: CodegenState) -> ast.AST:
    name = state.proxy_arg(0)
    assert isinstance(name, str)
    config_value = state.config[name]
    assert isinstance(config_value, (int, float, bool))
    return expr_from_string(constant_repr(config_value))
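
Putting the pieces together, a minimal end-to-end sketch beyond the split-k example. It assumes the tunable symint can be fed directly to hl.tile as a block size (the example file instead routes it through cdiv/next_power_of_2 first), and the row-sum body follows common Helion reduction patterns:

import torch

import helion
import helion.language as hl
from helion.autotuner import PowerOfTwoFragment


@helion.kernel()
def row_sums(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    # "chunk" becomes a config key whose value is searched over powers of
    # two in [16, 256] during autotuning.
    chunk = hl.register_tunable("chunk", PowerOfTwoFragment(16, 256))
    for tile_m in hl.tile(m, block_size=chunk):
        out[tile_m] = x[tile_m, :].sum(-1)
    return out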
