Skip to content

Commit 8efc252

Browse files
authored
Support hl.tile_{begin,end,block_size} (#150)
1 parent b64bf00 commit 8efc252

17 files changed

+323
-30
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,4 @@ venv
8585
.watchman
8686
.watchmanconfig
8787
*.zip
88+
CLAUDE.md

helion/_compiler/compile_environment.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
169169
self.debug_shape_renames[sym._sympy_()] = sympy.Symbol(debug_name, integer=True)
170170
return sym
171171

172+
def create_unbacked_symint(self, hint: int = 8192) -> torch.SymInt:
    """
    Create a new unbacked SymInt and record a hint value for it.

    The symbol is created through ``shape_env.create_unbacked_symint()``
    while ``shape_env.ignore_fresh_unbacked_symbols()`` is active, and
    ``hint`` (converted with ``sympy.sympify``) is stored in
    ``shape_env.var_to_val`` under the new symbol.

    Args:
        hint: Value recorded for the new symbol in ``var_to_val``.

    Returns:
        The newly created symbolic integer.
    """
    shape_env = self.shape_env
    with shape_env.ignore_fresh_unbacked_symbols():
        fresh = shape_env.create_unbacked_symint()
        # TODO(jansel): this is a hack to get us past some == 1 checks
        # we should probably have a better way to handle this
        shape_env.var_to_val[fresh._sympy_()] = sympy.sympify(hint)
        return fresh
179+
172180
def to_fake(self, obj: object, origin: Origin) -> object:
173181
if isinstance(obj, torch.Tensor):
174182
return self._to_fake_tensor(obj, origin.to_source())
@@ -177,12 +185,7 @@ def to_fake(self, obj: object, origin: Origin) -> object:
177185
with self.shape_env.ignore_fresh_unbacked_symbols():
178186
return self.shape_env.create_unbacked_symbool()
179187
if isinstance(obj, int):
180-
with self.shape_env.ignore_fresh_unbacked_symbols():
181-
sym = self.shape_env.create_unbacked_symint()
182-
# TODO(jansel): this is a hack to get us past some == 1 checks
183-
# we should probably have a better way to handle this
184-
self.shape_env.var_to_val[sym._sympy_()] = sympy.sympify(8192)
185-
return sym
188+
return self.create_unbacked_symint()
186189
if isinstance(obj, float):
187190
with self.shape_env.ignore_fresh_unbacked_symbols():
188191
return self.shape_env.create_unbacked_symfloat()

helion/_compiler/device_ir.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
from .type_propagation import CallableType
5050
from .type_propagation import GridIndexType
5151
from .type_propagation import IterType
52+
from .type_propagation import LiteralType
53+
from .type_propagation import NumericType
5254
from .type_propagation import SequenceType
5355
from .type_propagation import TensorType
5456
from .type_propagation import TileIndexType
@@ -739,7 +741,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
739741
rhs_type = node.value._type_info
740742
assert isinstance(target, ExtendedAST)
741743
lhs_type = target._type_info
742-
if not isinstance(lhs_type, TensorType) or not isinstance(rhs_type, TensorType):
744+
if not isinstance(lhs_type, TensorType) or not isinstance(
745+
rhs_type, (TensorType, NumericType, LiteralType)
746+
):
743747
raise exc.NonTensorSubscriptAssign(lhs_type, rhs_type)
744748
assert isinstance(target.value, ExtendedAST)
745749
target_origin = target.value._type_info.origin

helion/_compiler/generate_ast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
6363
stmt = statement_from_string(stmt)
6464
self.statements_stack[-1].append(stmt)
6565

66-
def tmpvar(self, dce: bool = False) -> str:
67-
return self.device_function.unique_name("v", dce=dce)
66+
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
    """
    Request a unique variable name from the device function.

    Both arguments are keyword-only and are forwarded unchanged to
    ``self.device_function.unique_name``.

    Args:
        dce: Forwarded as the ``dce`` keyword of ``unique_name``.
        prefix: Forwarded as the first positional argument of ``unique_name``.

    Returns:
        The name returned by ``unique_name``.
    """
    name = self.device_function.unique_name(prefix, dce=dce)
    return name
6868

69-
def lift(self, expr: ast.AST, dce: bool = False) -> ast.Name:
69+
def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
7070
if isinstance(expr, ast.Name):
7171
return expr
7272
assert isinstance(expr, ExtendedAST), expr
7373
with expr:
74-
varname = self.tmpvar(dce=dce)
74+
varname = self.tmpvar(dce=dce, prefix=prefix)
7575
self.add_statement(statement_from_string(f"{varname} = expr", expr=expr))
7676
return create(ast.Name, id=varname, ctx=ast.Load())
7777

helion/_compiler/indexing_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def create(
306306
ast_index = state.ast_args[1]
307307
assert isinstance(ast_index, (list, tuple))
308308
assert len(ast_index) == len(index)
309-
index_var = state.codegen.lift(ast_index[n]).id
309+
index_var = state.codegen.lift(ast_index[n], prefix="index").id
310310
index_values.append(f"({index_var}){expand}")
311311
if (
312312
block_idx := TileStrategy.get_block_index(output_size[output_idx])
@@ -321,7 +321,7 @@ def create(
321321
ast_index = state.ast_args[1]
322322
assert isinstance(ast_index, (list, tuple))
323323
assert len(ast_index) == 1
324-
index_var = state.codegen.lift(ast_index[0]).id
324+
index_var = state.codegen.lift(ast_index[0], prefix="index").id
325325
index_values.append(index_var)
326326
output_idx += k.ndim
327327
for n, s in enumerate(output_size):

helion/_compiler/reduction_strategy.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,13 @@ def codegen_preamble(self, state: CodegenState) -> None:
176176
state.add_statement(
177177
f"{mask_var} = {index_var} < {self.fn.sympy_expr(numel)}"
178178
)
179-
state.codegen.set_active_loops(PersistentReductionState(self))
179+
# Extract end_var_name from the numel expression
180+
env = CompileEnvironment.current()
181+
numel = env.block_sizes[self.block_index].numel
182+
end_var_name = {self.block_index: self.fn.sympy_expr(numel)}
183+
state.codegen.set_active_loops(
184+
PersistentReductionState(self, end_var_name=end_var_name)
185+
)
180186

181187
def codegen_reduction(
182188
self,
@@ -254,11 +260,14 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
254260
orelse=[],
255261
type_comment=None,
256262
)
263+
# Extract end_var_name from the actual numel expression used in the range()
264+
end_var_name = {block_index: device_function.sympy_expr(numel)}
257265
return DeviceLoopState(
258266
self,
259267
for_node=for_node,
260268
inner_statements=body,
261269
end_bounds={block_index: numel},
270+
end_var_name=end_var_name,
262271
)
263272

264273
def codegen_reduction(

helion/_compiler/tile_dispatch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from helion._compiler.reduction_strategy import PersistentReductionStrategy
1414
from helion._compiler.reduction_strategy import ReductionStrategy
1515
from helion._compiler.tile_strategy import CompactedShape
16-
from helion._compiler.tile_strategy import DeviceGridState
1716
from helion._compiler.tile_strategy import DeviceLoopState
1817
from helion._compiler.tile_strategy import FlattenedTileStrategy
1918
from helion._compiler.tile_strategy import NDGridTileStrategy
@@ -111,11 +110,11 @@ def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
111110

112111
def codegen_grid(self, state: CodegenState, block_ids: list[int]) -> None:
113112
strategy = self.block_id_to_strategy[tuple(block_ids)]
114-
strategy.codegen_grid(state)
113+
grid_state = strategy.codegen_grid(state)
115114
for other_strategy in self.strategies:
116115
if other_strategy is not strategy:
117116
other_strategy.codegen_preamble(state)
118-
state.codegen.set_active_loops(DeviceGridState(strategy))
117+
state.codegen.set_active_loops(grid_state)
119118

120119
def codegen_device_loop(
121120
self, state: CodegenState, block_ids: list[int]

helion/_compiler/tile_index_proxy.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,33 @@ def index(self) -> torch.Tensor:
9191

9292
return tile_index(self)
9393

94+
@property
95+
def begin(self) -> int:
96+
"""
97+
Alias for hl.tile_begin, which retrieves the start offset of a tile.
98+
"""
99+
from ..language.tiles import tile_begin
100+
101+
return tile_begin(self)
102+
103+
@property
104+
def end(self) -> int:
105+
"""
106+
Alias for hl.tile_end, which retrieves the end offset of a tile.
107+
"""
108+
from ..language.tiles import tile_end
109+
110+
return tile_end(self)
111+
112+
@property
113+
def block_size(self) -> int:
114+
"""
115+
Alias for hl.tile_block_size, which retrieves the block_size of a tile.
116+
"""
117+
from ..language.tiles import tile_block_size
118+
119+
return tile_block_size(self)
120+
94121

95122
class CheckForIndexCalls:
96123
"""

helion/_compiler/tile_strategy.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
@dataclasses.dataclass
4343
class DeviceLoopOrGridState:
4444
strategy: TileStrategy
45+
end_var_name: dict[int, str]
4546

4647
@property
4748
def block_ids(self) -> list[int]:
@@ -106,7 +107,7 @@ def block_size_var(self, block_idx: int) -> str | None:
106107
def user_size(self, block_index: int) -> sympy.Expr:
107108
raise NotImplementedError
108109

109-
def codegen_grid(self, state: CodegenState) -> None:
110+
def codegen_grid(self, state: CodegenState) -> DeviceGridState:
110111
raise NotImplementedError
111112

112113
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
@@ -255,7 +256,7 @@ def _codegen_common(
255256
)
256257
return block_size_var, offsets_var, total_numel, statements
257258

258-
def codegen_grid(self, state: CodegenState) -> None:
259+
def codegen_grid(self, state: CodegenState) -> DeviceGridState:
259260
block_size_var, offsets_var, total_numel, statements = self._codegen_common(
260261
state
261262
)
@@ -273,6 +274,12 @@ def codegen_grid(self) -> ast.AST:
273274

274275
state.device_function.set_pid(TmpPid())
275276

277+
end_var_name = {}
278+
for block_id in self.block_ids:
279+
end_bound = CompileEnvironment.current().block_sizes[block_id].numel
280+
end_var_name[block_id] = state.device_function.sympy_expr(end_bound)
281+
return DeviceGridState(self, end_var_name=end_var_name)
282+
276283
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
277284
block_size_var, offsets_var, total_numel, statements = self._codegen_common(
278285
state
@@ -301,6 +308,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
301308
for_node=for_node,
302309
inner_statements=body,
303310
end_bounds=self.get_end_bounds(state),
311+
end_var_name={},
304312
)
305313

306314
@classmethod
@@ -361,7 +369,7 @@ def __init__(
361369
f"_BLOCK_SIZE_{block_idx}"
362370
)
363371

364-
def codegen_grid(self, state: CodegenState) -> None:
372+
def codegen_grid(self, state: CodegenState) -> DeviceGridState:
365373
block_ids = self.block_ids
366374
env = CompileEnvironment.current()
367375
device_function = state.device_function
@@ -417,6 +425,13 @@ def codegen_grid(self, state: CodegenState) -> None:
417425
else:
418426
state.device_function.set_pid(pids)
419427

428+
# Extract end_var_name from end bound expressions
429+
end_var_name = {}
430+
for block_id in self.block_ids:
431+
end_bound = CompileEnvironment.current().block_sizes[block_id].numel
432+
end_var_name[block_id] = state.device_function.sympy_expr(end_bound)
433+
return DeviceGridState(self, end_var_name=end_var_name)
434+
420435
def select_pid_strategy(self) -> ProgramIDs:
421436
if 1 < len(self.block_ids) <= 3 and self.fn.config.use_yz_grid:
422437
return GridProgramIDs()
@@ -447,6 +462,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
447462
_, begins, ends, _ = state.ast_args
448463
assert isinstance(begins, list)
449464
assert isinstance(ends, list)
465+
end_var_name = {}
450466
for block_idx, block_size, begin, end in self._reorder(
451467
[*zip(block_ids, block_sizes, begins, ends, strict=True)]
452468
):
@@ -463,6 +479,9 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
463479
)
464480
else:
465481
block_size_var = "1"
482+
end_var_name[block_idx] = state.codegen.lift(
483+
self._to_ast(end, to_dtype=dtype), dce=True, prefix="end"
484+
).id
466485
for_node = create(
467486
ast.For,
468487
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
@@ -494,6 +513,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
494513
for_node=for_node,
495514
inner_statements=innermost_body,
496515
end_bounds=self.get_end_bounds(state),
516+
end_var_name=end_var_name,
497517
)
498518

499519
def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,9 @@ def propagate_setitem(
503503
rhs_rank,
504504
f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
505505
)
506+
elif isinstance(value, (NumericType, LiteralType)):
507+
# Allow scalar assignment to tensor (broadcasts to tensor shape)
508+
pass
506509
elif isinstance(value, UnknownType):
507510
raise exc.TypePropagationError(value)
508511
else:

0 commit comments

Comments
 (0)