Skip to content

Commit c432607

Browse files
authored
Rename TileStrategy.get_block_index to CompileEnvironment.get_block_id (#151)
1 parent 8efc252 commit c432607

11 files changed

+57
-58
lines changed

helion/_compiler/compile_environment.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
7474
self.loop_dependency_checker = LoopDependencyChecker()
7575

7676
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
77-
from .tile_strategy import TileStrategy
78-
7977
for size in sizes:
8078
if isinstance(size, torch.SymInt):
81-
block_idx = TileStrategy.get_block_index(size)
79+
block_idx = self.get_block_id(size)
8280
if block_idx is None:
8381
value = self.shape_env.replace(size._sympy_())
8482
if value.free_symbols:
@@ -315,6 +313,33 @@ def has_current() -> bool:
315313
except NoCurrentEnvironment:
316314
return False
317315

316+
def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
317+
"""
318+
Get the block ID associated with a given size expression.
319+
320+
This method determines if a size expression corresponds to a registered block size
321+
in the current compilation environment. It looks up the origin information of
322+
symbolic expressions to find their associated block IDs.
323+
324+
Args:
325+
size: The size expression to check. Can be an integer, torch.SymInt, or sympy.Expr.
326+
327+
Returns:
328+
The block ID if the size corresponds to a registered block size, None otherwise.
329+
"""
330+
if isinstance(size, torch.SymInt):
331+
return self.get_block_id(size._sympy_())
332+
if isinstance(size, sympy.Symbol):
333+
from .host_function import HostFunction
334+
335+
origin_info = HostFunction.current().expr_to_origin.get(size)
336+
if origin_info is not None and isinstance(
337+
origin_info.origin,
338+
BlockSizeOrigin,
339+
):
340+
return origin_info.origin.block_id
341+
return None
342+
318343

319344
class NoCurrentEnvironment(RuntimeError):
320345
pass

helion/_compiler/device_function.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from .host_function import HostFunction
3030
from .host_function import NoCurrentFunction
3131
from .output_header import reserved_names
32-
from .tile_strategy import TileStrategy
3332
from .variable_origin import BlockSizeOrigin
3433
from .variable_origin import Origin
3534
from .variable_origin import TensorSizeOrigin
@@ -209,7 +208,7 @@ def user_sympy_expr(self, expr: sympy.Expr) -> str:
209208
replacements = {}
210209
for sym in sorted(expr.free_symbols, key=lambda s: s.name):
211210
assert isinstance(sym, sympy.Symbol)
212-
block_idx = TileStrategy.get_block_index(sym)
211+
block_idx = CompileEnvironment.current().get_block_id(sym)
213212
if block_idx is not None:
214213
replacements[sym] = self.tile_strategy.user_size(block_idx)
215214
if replacements:

helion/_compiler/indexing_strategy.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from .compile_environment import CompileEnvironment
1515
from .host_function import HostFunction
1616
from .tile_strategy import DeviceLoopState
17-
from .tile_strategy import TileStrategy
1817
from .variable_origin import BlockSizeOrigin
1918

2019
if TYPE_CHECKING:
@@ -207,6 +206,7 @@ def compute_shape(
207206
assert isinstance(index, (list, tuple)), index
208207
input_size = collections.deque(tensor.size())
209208
output_size = []
209+
env = CompileEnvironment.current()
210210
for k in index:
211211
if k is None:
212212
output_size.append(1)
@@ -218,11 +218,7 @@ def compute_shape(
218218
if isinstance(symbol, sympy.Symbol):
219219
origin = HostFunction.current().expr_to_origin.get(symbol)
220220
if origin and isinstance(origin.origin, BlockSizeOrigin):
221-
if (
222-
CompileEnvironment.current()
223-
.block_sizes[origin.origin.block_id]
224-
.is_grid()
225-
):
221+
if env.block_sizes[origin.origin.block_id].is_grid():
226222
pass
227223
elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
228224
output_size.append(k)
@@ -231,9 +227,7 @@ def compute_shape(
231227
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
232228
size = input_size.popleft()
233229
if size != 1:
234-
rdim = CompileEnvironment.current().allocate_reduction_dimension(
235-
size
236-
)
230+
rdim = env.allocate_reduction_dimension(size)
237231
output_size.append(rdim.var)
238232
else:
239233
output_size.append(1)
@@ -308,9 +302,7 @@ def create(
308302
assert len(ast_index) == len(index)
309303
index_var = state.codegen.lift(ast_index[n], prefix="index").id
310304
index_values.append(f"({index_var}){expand}")
311-
if (
312-
block_idx := TileStrategy.get_block_index(output_size[output_idx])
313-
) is not None:
305+
if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
314306
if mask := state.codegen.mask_var(block_idx):
315307
mask_values.setdefault(f"({mask}){expand}")
316308
output_idx += 1
@@ -325,7 +317,7 @@ def create(
325317
index_values.append(index_var)
326318
output_idx += k.ndim
327319
for n, s in enumerate(output_size):
328-
if (block_idx := TileStrategy.get_block_index(s)) is not None and (
320+
if (block_idx := env.get_block_id(s)) is not None and (
329321
mask := state.codegen.mask_var(block_idx)
330322
):
331323
mask_values.setdefault(

helion/_compiler/inductor_lowering.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from .node_masking import getitem_masked_value
5252
from .node_masking import inductor_masked_value
5353
from .node_masking import mask_node_inputs
54-
from .tile_strategy import TileStrategy
5554

5655
if TYPE_CHECKING:
5756
from collections.abc import Callable
@@ -429,7 +428,7 @@ def __init__(
429428
reduction_var = reduction_ranges[0]
430429
assert isinstance(reduction_var, sympy.Symbol)
431430

432-
block_index = TileStrategy.get_block_index(reduction_var)
431+
block_index = CompileEnvironment.current().get_block_id(reduction_var)
433432
assert block_index is not None
434433
self.block_index: int = block_index
435434

@@ -497,7 +496,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
497496
(dim,) = [
498497
i
499498
for i, v in enumerate(repr_input.shape)
500-
if TileStrategy.get_block_index(v) == self.block_index
499+
if CompileEnvironment.current().get_block_id(v) == self.block_index
501500
]
502501

503502
return strategy.codegen_reduction(
@@ -764,7 +763,7 @@ def apply_dot_requirements(
764763
a, b, c = min_dot_size(lproxy.device, lproxy.dtype, rproxy.dtype)
765764
env = CompileEnvironment.current()
766765
for shape, min_size in [(n, a), (k, b), (m, c)]:
767-
block_idx = TileStrategy.get_block_index(shape)
766+
block_idx = CompileEnvironment.current().get_block_id(shape)
768767
if block_idx is not None:
769768
env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)
770769
# inputs to the dot operation must be zero-masked

helion/_compiler/reduction_strategy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,6 @@ def codegen_preamble(self, state: CodegenState) -> None:
177177
f"{mask_var} = {index_var} < {self.fn.sympy_expr(numel)}"
178178
)
179179
# Extract end_var_name from the numel expression
180-
env = CompileEnvironment.current()
181-
numel = env.block_sizes[self.block_index].numel
182180
end_var_name = {self.block_index: self.fn.sympy_expr(numel)}
183181
state.codegen.set_active_loops(
184182
PersistentReductionState(self, end_var_name=end_var_name)

helion/_compiler/roll_reduction.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import torch
77
from torch.fx import map_arg
88

9+
from helion._compiler.compile_environment import CompileEnvironment
910
from helion._compiler.inductor_lowering import APIFuncLowering
1011
from helion._compiler.inductor_lowering import ReductionLowering
1112
from helion._compiler.inductor_lowering import aten_lowering_dispatch
12-
from helion._compiler.tile_strategy import TileStrategy
1313
from helion.language._tracing_ops import _for_loop
1414
from helion.language._tracing_ops import _get_symnode
1515
from helion.language._tracing_ops import _host_tensor
@@ -102,7 +102,7 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
102102
num_rdims = 0
103103
if isinstance(val, torch.Tensor):
104104
for size in val.size():
105-
block_idx = TileStrategy.get_block_index(size)
105+
block_idx = CompileEnvironment.current().get_block_id(size)
106106
num_rdims += block_idx == self.rdim.block_id
107107
if num_rdims > 1:
108108
raise NotImplementedError(
@@ -113,7 +113,7 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
113113
for item in val:
114114
if isinstance(item, torch.Tensor):
115115
for size in item.size():
116-
block_idx = TileStrategy.get_block_index(size)
116+
block_idx = CompileEnvironment.current().get_block_id(size)
117117
num_rdims += block_idx == self.rdim.block_id
118118
if num_rdims > 1:
119119
raise NotImplementedError(
@@ -263,7 +263,7 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool:
263263
val = input_node.meta.get("val", None)
264264
if isinstance(val, torch.Tensor):
265265
for size in val.size():
266-
block_idx = TileStrategy.get_block_index(size)
266+
block_idx = CompileEnvironment.current().get_block_id(size)
267267
if block_idx == self.rdim.block_id:
268268
return True
269269
return False

helion/_compiler/tile_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def codegen_device_loop(
125125
def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]:
126126
compacted_shapes = []
127127
for idx, shape in enumerate(shapes):
128-
block_idx = TileStrategy.get_block_index(shape)
128+
block_idx = CompileEnvironment.current().get_block_id(shape)
129129
if block_idx is None:
130130
compacted_shapes.append(
131131
CompactedShape(self.strategies[0].fn.literal_expr(shape), [idx], [])

helion/_compiler/tile_strategy.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from .program_id import ProgramIDs
2727
from .program_id import SharedProgramID
2828
from .program_id import VirtualProgramIDs
29-
from .variable_origin import BlockSizeOrigin
3029

3130
if TYPE_CHECKING:
3231
from collections.abc import Sequence
@@ -119,19 +118,6 @@ def codegen_preamble(self, state: CodegenState) -> None:
119118
def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
120119
raise NotImplementedError
121120

122-
@classmethod
123-
def get_block_index(cls, size: int | torch.SymInt | sympy.Expr) -> int | None:
124-
if isinstance(size, torch.SymInt):
125-
return cls.get_block_index(size._sympy_())
126-
if isinstance(size, sympy.Symbol):
127-
origin_info = HostFunction.current().expr_to_origin.get(size)
128-
if origin_info is not None and isinstance(
129-
origin_info.origin,
130-
BlockSizeOrigin,
131-
):
132-
return origin_info.origin.block_id
133-
return None
134-
135121

136122
class BlockSizeTileStrategy(TileStrategy):
137123
def __init__(
@@ -260,7 +246,8 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
260246
block_size_var, offsets_var, total_numel, statements = self._codegen_common(
261247
state
262248
)
263-
dtype = CompileEnvironment.current().triton_index_type()
249+
env = CompileEnvironment.current()
250+
dtype = env.triton_index_type()
264251
state.add_statement(
265252
f"{offsets_var} = tl.program_id(0) * ({block_size_var}) + tl.arange(0, {block_size_var}).to({dtype})"
266253
)
@@ -276,7 +263,7 @@ def codegen_grid(self) -> ast.AST:
276263

277264
end_var_name = {}
278265
for block_id in self.block_ids:
279-
end_bound = CompileEnvironment.current().block_sizes[block_id].numel
266+
end_bound = env.block_sizes[block_id].numel
280267
end_var_name[block_id] = state.device_function.sympy_expr(end_bound)
281268
return DeviceGridState(self, end_var_name=end_var_name)
282269

@@ -313,12 +300,13 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
313300

314301
@classmethod
315302
def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None:
303+
env = CompileEnvironment.current()
316304
used_indices = {}
317305
for i, x in enumerate(shape):
318-
block_idx = cls.get_block_index(x)
306+
block_idx = env.get_block_id(x)
319307
if block_idx is not None:
320308
used_indices[block_idx] = i
321-
flatten_loops = CompileEnvironment.current().config_spec.flatten_loops
309+
flatten_loops = env.config_spec.flatten_loops
322310
for spec in [*flatten_loops]:
323311
block_ids = spec.block_ids
324312
if not (
@@ -405,7 +393,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
405393
)
406394
else:
407395
block_size_var = "1"
408-
dtype = CompileEnvironment.current().triton_index_type()
396+
dtype = env.triton_index_type()
409397
state.add_statement(f"{offset_var} = {pid_var}")
410398
state.add_statement(
411399
f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
@@ -428,7 +416,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
428416
# Extract end_var_name from end bound expressions
429417
end_var_name = {}
430418
for block_id in self.block_ids:
431-
end_bound = CompileEnvironment.current().block_sizes[block_id].numel
419+
end_bound = env.block_sizes[block_id].numel
432420
end_var_name[block_id] = state.device_function.sympy_expr(end_bound)
433421
return DeviceGridState(self, end_var_name=end_var_name)
434422

helion/language/_tracing_ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .._compiler.compile_environment import CompileEnvironment
1414
from .._compiler.host_function import HostFunction
1515
from .._compiler.tile_index_proxy import TileIndexProxy
16-
from .._compiler.tile_strategy import TileStrategy
1716
from ..exc import NotInsideKernel
1817
from . import _decorators
1918

@@ -39,7 +38,7 @@ def _get_symnode(debug_name: str) -> int:
3938
def _(state: CodegenState) -> ast.AST:
4039
val = state.fx_node.meta["val"]
4140
assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val
42-
if (block_idx := TileStrategy.get_block_index(val)) is not None:
41+
if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None:
4342
if state.device_function.block_size_var(block_idx) is None:
4443
# this should be unused
4544
return expr_from_string("block_size_var_optimized_away")
@@ -261,7 +260,7 @@ def _(state: CodegenState) -> ast.AST:
261260
mask_exprs = []
262261
input_sizes = [*tensor.size()]
263262
for dim, size in enumerate(input_sizes):
264-
if (index := TileStrategy.get_block_index(size)) is not None and (
263+
if (index := CompileEnvironment.current().get_block_id(size)) is not None and (
265264
mask_var := state.codegen.mask_var(index)
266265
) is not None:
267266
expand = state.tile_strategy.expand_str(input_sizes, dim)

helion/language/loops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ def _(
211211
elif isinstance(bs, int):
212212
results.append(TileIndexType.allocate_fixed(size, bs, origin))
213213
elif isinstance(bs, torch.SymInt):
214-
from helion._compiler.tile_strategy import TileStrategy
214+
from helion._compiler.compile_environment import CompileEnvironment
215215

216-
index = TileStrategy.get_block_index(bs)
216+
index = CompileEnvironment.current().get_block_id(bs)
217217
if index is None:
218218
results.append(TileIndexType.allocate_fixed(size, bs, origin))
219219
else:

0 commit comments

Comments
 (0)