Skip to content

Commit 16740ac

Browse files
authored
Fix bug with computations based on hl.register_block_size (#157)
1 parent b9e93c0 commit 16740ac

File tree

11 files changed

+151
-43
lines changed

11 files changed

+151
-43
lines changed

helion/_compiler/device_function.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def __init__(self, name: str, config: Config) -> None:
152152
self._variable_renames: dict[str, list[str]] = {}
153153
self.dce_vars: list[str] = []
154154
self.block_size_var_cache: dict[tuple[int, ...], str] = {}
155+
self.expr_to_var_name: dict[sympy.Expr, str] = {}
155156

156157
from .indexing_strategy import IndexingStrategy
157158
from .tile_dispatch import TileStrategyDispatch
@@ -175,17 +176,26 @@ def set_pid(self, pid: SharedProgramID | ProgramIDs) -> None:
175176
self.pid = pid
176177

177178
def sympy_expr(self, expr: sympy.Expr) -> str:
178-
expr_to_origin = HostFunction.current().expr_to_origin
179179
expr = CompileEnvironment.current().shape_env.simplify(expr)
180180
if not expr.free_symbols:
181181
return texpr(expr)
182+
if expr in self.expr_to_var_name:
183+
return self.expr_to_var_name[expr]
184+
expr_to_origin = HostFunction.current().expr_to_origin
182185
if expr in expr_to_origin:
183186
return self._lift_sympy_arg(expr)
184187
replacements = {}
185188
for sym in sorted(expr.free_symbols, key=lambda x: x.name):
186189
assert isinstance(sym, sympy.Symbol)
187-
assert sym in expr_to_origin, f"no origin found for {sym.name}"
188-
replacements[sym] = sympy.Symbol(self._lift_sympy_arg(sym), integer=True)
190+
if sym in self.expr_to_var_name:
191+
replacements[sym] = sympy.Symbol(
192+
self.expr_to_var_name[sym], integer=True
193+
)
194+
else:
195+
assert sym in expr_to_origin, f"no origin found for {sym.name}"
196+
replacements[sym] = sympy.Symbol(
197+
self._lift_sympy_arg(sym), integer=True
198+
)
189199
return texpr(expr.xreplace(replacements))
190200

191201
def _lift_sympy_arg(self, expr: sympy.Expr) -> str:

helion/_compiler/generate_ast.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,15 @@ def set_statements(self, new_statements: list[ast.AST] | None) -> Iterator[None]
8080
if new_statements is None:
8181
yield
8282
else:
83+
expr_to_var_name = self.device_function.expr_to_var_name
84+
# Vars assigned inside a nested scope must not be reused once we leave it, so work on a copy and restore the original mapping on exit
85+
self.device_function.expr_to_var_name = expr_to_var_name.copy()
8386
self.statements_stack.append(new_statements)
8487
try:
8588
yield
8689
finally:
8790
self.statements_stack.pop()
91+
self.device_function.expr_to_var_name = expr_to_var_name
8892

8993
@contextlib.contextmanager
9094
def set_on_device(self) -> Iterator[None]:

helion/_compiler/inductor_lowering.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -927,13 +927,26 @@ def run_node(self, n: Node) -> object:
927927
return result
928928
assert isinstance(result, ast.expr)
929929
if len(n.users) > 0:
930-
if isinstance(result, (ast.Name, ast.Constant)):
931-
return result
932-
name = self.cg.device_function.new_var(n.name)
933-
self.cg.add_statement(
934-
statement_from_string(f"{name} = result", result=result)
935-
)
936-
return create(ast.Name, id=name, ctx=ast.Load())
930+
if not isinstance(result, (ast.Name, ast.Constant)):
931+
name = self.cg.device_function.new_var(n.name)
932+
self.cg.add_statement(
933+
statement_from_string(f"{name} = result", result=result)
934+
)
935+
result = create(ast.Name, id=name, ctx=ast.Load())
936+
if (
937+
isinstance(val := n.meta["val"], torch.SymInt)
938+
and len((expr := val._sympy_()).free_symbols) > 0
939+
):
940+
# Record the variable name (or constant repr) that holds each symbolic SymInt so DeviceFunction.sympy_expr() can reuse it
941+
expr = CompileEnvironment.current().shape_env.simplify(expr)
942+
if isinstance(result, ast.Name):
943+
self.cg.device_function.expr_to_var_name[expr] = result.id
944+
else:
945+
assert isinstance(result, ast.Constant)
946+
self.cg.device_function.expr_to_var_name[expr] = repr(
947+
result.value
948+
)
949+
return result
937950
if not isinstance(result, (ast.Name, ast.Constant)):
938951
self.cg.add_statement(create(ast.Expr, value=result))
939952
return None
@@ -997,3 +1010,6 @@ def config(self) -> Config:
9971010

9981011
def add_statement(self, statement: ast.AST | str) -> None:
9991012
return self.codegen.add_statement(statement)
1013+
1014+
def sympy_expr(self, expr: sympy.Expr) -> str:
1015+
return self.codegen.device_function.sympy_expr(expr)

helion/_compiler/program_id.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def host_cdiv(self) -> str:
2828
return f"triton.cdiv({numel_str}, {self.block_size_var})"
2929

3030
def device_cdiv(self, state: CodegenState) -> str:
31-
numel_str = state.device_function.sympy_expr(self.numel)
31+
numel_str = state.sympy_expr(self.numel)
3232
if self.block_size_var == "1":
3333
return numel_str
3434
return f"tl.cdiv({numel_str}, {self.block_size_var})"

helion/_compiler/reduction_strategy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ def __init__(
227227
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
228228
env = CompileEnvironment.current()
229229
block_index = self.block_index
230-
device_function = state.device_function
231230
numel = env.block_sizes[block_index].numel
232231
offset_var = self.offset_var(block_index)
233232
index_var = self.index_var(block_index)
@@ -245,21 +244,21 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
245244
if (mask_var := self._mask_var) is not None:
246245
body.append(
247246
statement_from_string(
248-
f"{mask_var} = {index_var} < {device_function.sympy_expr(numel)}"
247+
f"{mask_var} = {index_var} < {state.sympy_expr(numel)}"
249248
)
250249
)
251250
for_node = create(
252251
ast.For,
253252
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
254253
iter=expr_from_string(
255-
f"range(0, ({device_function.sympy_expr(numel)}), {block_size_var})"
254+
f"range(0, ({state.sympy_expr(numel)}), {block_size_var})"
256255
),
257256
body=body,
258257
orelse=[],
259258
type_comment=None,
260259
)
261260
# Extract end_var_name from the actual numel expression used in the range()
262-
end_var_name = {block_index: device_function.sympy_expr(numel)}
261+
end_var_name = {block_index: state.sympy_expr(numel)}
263262
return DeviceLoopState(
264263
self,
265264
for_node=for_node,

helion/_compiler/tile_strategy.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def _codegen_common(
212212
block_ids = self.block_ids
213213
env = CompileEnvironment.current()
214214
total_numel = sympy.S.One
215-
device_function = state.device_function
216215
offsets_var = self.new_var("offsets", dce=True)
217216
block_size_var = self.block_size_var(-1)
218217
statements = []
@@ -227,17 +226,17 @@ def _codegen_common(
227226
block_index_var = self.index_var(block_idx)
228227
expr = offsets_var
229228
if total_numel != sympy.S.One:
230-
expr = f"({expr}) // ({device_function.sympy_expr(total_numel)})"
229+
expr = f"({expr}) // ({state.sympy_expr(total_numel)})"
231230
if i + 1 < len(block_ids):
232-
expr = f"({expr}) % ({device_function.sympy_expr(numel)})"
231+
expr = f"({expr}) % ({state.sympy_expr(numel)})"
233232
statements.append(statement_from_string(f"{block_index_var} = {expr}"))
234233
total_numel = total_numel * numel
235234

236235
mask_var = self.mask_var(-1)
237236
if mask_var is not None:
238237
statements.append(
239238
statement_from_string(
240-
f"{mask_var} = {offsets_var} < ({device_function.sympy_expr(total_numel)})"
239+
f"{mask_var} = {offsets_var} < ({state.sympy_expr(total_numel)})"
241240
)
242241
)
243242
return block_size_var, offsets_var, total_numel, statements
@@ -264,7 +263,7 @@ def codegen_grid(self) -> ast.AST:
264263
end_var_name = {}
265264
for block_id in self.block_ids:
266265
end_bound = env.block_sizes[block_id].numel
267-
end_var_name[block_id] = state.device_function.sympy_expr(end_bound)
266+
end_var_name[block_id] = state.sympy_expr(end_bound)
268267
return DeviceGridState(self, end_var_name=end_var_name)
269268

270269
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
@@ -277,7 +276,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
277276
ast.For,
278277
target=create(ast.Name, id=lid, ctx=ast.Store()),
279278
iter=expr_from_string(
280-
f"range(tl.cdiv({state.device_function.sympy_expr(total_numel)}, {block_size_var}))"
279+
f"range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}))"
281280
),
282281
body=(
283282
body := [
@@ -417,7 +416,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
417416
end_var_name = {}
418417
for block_id in self.block_ids:
419418
end_bound = env.block_sizes[block_id].numel
420-
end_var_name[block_id] = state.device_function.sympy_expr(end_bound)
419+
end_var_name[block_id] = state.sympy_expr(end_bound)
421420
return DeviceGridState(self, end_var_name=end_var_name)
422421

423422
def select_pid_strategy(self) -> ProgramIDs:

helion/_compiler/variable_origin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ def host_str(self) -> str:
237237
return "1"
238238
return var
239239

240+
def suggest_var_name(self) -> str:
241+
return f"block_size_{self.block_id}"
242+
240243

241244
@dataclasses.dataclass
242245
class ReductionDimensionOrigin(Origin):

helion/language/_tracing_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _(state: CodegenState) -> ast.AST:
4343
# this should be unused
4444
return expr_from_string("block_size_var_optimized_away")
4545
return state.codegen.lift(
46-
expr_from_string(state.device_function.sympy_expr(val._sympy_())),
46+
expr_from_string(state.sympy_expr(val._sympy_())),
4747
dce=True,
4848
prefix="symnode",
4949
)

helion/language/loops.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import overload
99

1010
import torch
11+
from torch._inductor.codegen.simd import constant_repr
1112
from torch._inductor.runtime.runtime_utils import next_power_of_2
1213
from torch._inductor.runtime.triton_heuristics import get_max_y_grid
1314

@@ -374,7 +375,7 @@ def _(state: CodegenState) -> ast.AST:
374375

375376

376377
@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
377-
def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> Tile:
378+
def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> int:
378379
"""
379380
Explicitly register a block size that should be autotuned and can be used for
380381
allocations and inside hl.tile(..., block_size=...).
@@ -396,6 +397,8 @@ def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> T
396397
def _(
397398
min_or_max: TypeInfo, max_or_none: TypeInfo | None = None, /, *, origin: Origin
398399
) -> TypeInfo:
400+
from .._compiler.type_propagation import SymIntType
401+
399402
min_type, max_type = _normalize_begin_end(min_or_max, max_or_none, origin=origin)
400403
min_proxy = _to_proxy(min_type)
401404
max_proxy = _to_proxy(max_type)
@@ -412,13 +415,43 @@ def _(
412415
loop_spec = env.config_spec.block_sizes.block_id_lookup(result.block_id)
413416
loop_spec.min_size = assert_integer_power_of_two(max(1, min_proxy))
414417
loop_spec.max_size = next_power_of_2(env.size_hint(max_proxy))
415-
return result
418+
block_id = result.block_id
419+
return SymIntType(origin, env.block_sizes[block_id].var)
420+
421+
422+
def _block_id_from_state(state: CodegenState) -> int:
423+
"""Extract the block_id from the current state for nodes hl.register_block_size."""
424+
from .._compiler.type_propagation import SymIntType
425+
426+
env = CompileEnvironment.current()
427+
if state.fx_node is not None:
428+
val = state.fx_node.meta["val"]
429+
assert isinstance(val, SymIntType)
430+
block_id = env.get_block_id(val.value)
431+
assert block_id is not None
432+
return block_id
433+
current_node = ExtendedAST.current()[-1]
434+
type_info = current_node._type_info
435+
assert isinstance(type_info, SymIntType)
436+
block_id = env.get_block_id(type_info.value)
437+
assert block_id is not None
438+
return block_id
439+
440+
441+
@_decorators.codegen(register_block_size)
442+
def _(state: CodegenState) -> ast.AST:
443+
env = CompileEnvironment.current()
444+
block_size = env.config_spec.block_sizes.config_get(
445+
state.config.block_sizes, _block_id_from_state(state)
446+
)
447+
assert block_size is not None
448+
return expr_from_string(constant_repr(block_size))
416449

417450

418451
@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
419452
def register_reduction_dim(
420453
size: int,
421-
) -> torch.SymInt:
454+
) -> int:
422455
"""
423456
Explicitly register a reduction dimension that should be used for reduction operations.
424457
@@ -432,21 +465,10 @@ def register_reduction_dim(
432465
raise exc.NotInsideKernel
433466

434467

435-
@_decorators.register_fake(register_reduction_dim)
436-
def _(size: int) -> torch.SymInt:
437-
"""Fake implementation that returns the registered reduction dimension size(s)"""
438-
from .._compiler.compile_environment import CompileEnvironment
439-
440-
env = CompileEnvironment.current()
441-
442-
rdim = env.allocate_reduction_dimension(size)
443-
return rdim.var
444-
445-
446468
@_decorators.type_propagation(register_reduction_dim)
447469
def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
448470
from .._compiler.compile_environment import CompileEnvironment
449-
from .._compiler.type_propagation import ReductionDimType
471+
from .._compiler.type_propagation import SymIntType
450472

451473
try:
452474
proxy_sizes = sizes.proxy()
@@ -464,16 +486,16 @@ def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
464486
env = CompileEnvironment.current()
465487

466488
rdim = env.allocate_reduction_dimension(proxy_sizes)
467-
return ReductionDimType(origin, rdim.block_id)
489+
return SymIntType(origin, rdim.var)
468490

469491

470492
@_decorators.codegen(register_reduction_dim)
471493
def _(state: CodegenState) -> ast.AST:
472494
"""Generate code for register_reduction_dim - return the size expression"""
473-
from .._compiler.type_propagation import ReductionDimType
495+
from .._compiler.type_propagation import SymIntType
474496

475497
current_node = ExtendedAST.current()[-1]
476498
type_info = current_node._type_info
477499

478-
assert isinstance(type_info, ReductionDimType)
500+
assert isinstance(type_info, SymIntType)
479501
return current_node.args[0] # pyre-ignore[16]

helion/language/tiles.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def tile_begin(tile: Tile) -> int:
5858

5959

6060
@_decorators.register_fake(tile_begin)
61-
def _(state: CodegenState) -> torch.SymInt:
61+
def _(tile: torch.SymInt) -> torch.SymInt:
6262
return CompileEnvironment.current().create_unbacked_symint()
6363

6464

@@ -93,7 +93,7 @@ def tile_end(tile: Tile) -> int:
9393

9494

9595
@_decorators.register_fake(tile_end)
96-
def _(state: CodegenState) -> torch.SymInt:
96+
def _(tile: torch.SymInt) -> torch.SymInt:
9797
return CompileEnvironment.current().create_unbacked_symint()
9898

9999

0 commit comments

Comments
 (0)