Skip to content

Commit 6da8099

Browse files
authored
Add additional tl.range choices to persistent loop (#287)
1 parent a352384 commit 6da8099

20 files changed

+662
-228
lines changed

helion/_compiler/program_id.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class PIDInfo(NamedTuple):
2424
pid_var: str
2525
block_size_var: str
2626
numel: sympy.Expr
27+
block_id: int
2728

2829
def num_pids_expr(self, *, is_device: bool) -> str:
2930
"""Get the number of PIDs expression for device or host."""
@@ -347,11 +348,16 @@ def __init__(self, is_blocked: bool = False) -> None:
347348
self.block_size_var: str = device_function.new_var("block_size")
348349
self.start_pid_var: str = device_function.new_var("start_pid")
349350
self.end_pid_var: str = device_function.new_var("end_pid")
350-
self.range_expr: str = f"tl.range({self.start_pid_var}, {self.end_pid_var})"
351+
self.range_kwargs: dict[str, str] = {
352+
"begin": self.start_pid_var,
353+
"end": self.end_pid_var,
354+
}
351355
else:
352-
self.range_expr: str = (
353-
f"tl.range(tl.program_id(0), {self.total_pids_var}, {NUM_SM_VAR})"
354-
)
356+
self.range_kwargs: dict[str, str] = {
357+
"begin": "tl.program_id(0)",
358+
"end": self.total_pids_var,
359+
"step": NUM_SM_VAR,
360+
}
355361
if device_function.constexpr_arg(NUM_SM_VAR):
356362
device = CompileEnvironment.current().device
357363
device_function.codegen.host_statements.append(
@@ -402,8 +408,18 @@ def setup_persistent_kernel(
402408
)
403409

404410
device_function.preamble.extend(setup_statements)
411+
# Collect all block IDs from PID info for range configuration
412+
pid_block_ids = []
413+
for pid_info in self.pid_info:
414+
pid_block_ids.append(pid_info.block_id)
415+
416+
from .tile_strategy import TileStrategy
417+
418+
range_expr = TileStrategy.get_range_call_str(
419+
device_function.config, pid_block_ids, **self.range_kwargs
420+
)
405421
return self._setup_persistent_kernel_and_wrap_body(
406-
device_function, self.virtual_pid_var, self.range_expr, total_pids_expr
422+
device_function, self.virtual_pid_var, range_expr, total_pids_expr
407423
)
408424

409425
def _is_persistent(self) -> bool:

helion/_compiler/reduction_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
258258
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
259259
iter=expr_from_string(
260260
self.get_range_call_str(
261-
state,
261+
state.config,
262262
[self.block_index],
263263
begin="0",
264264
end=state.sympy_expr(numel),

helion/_compiler/tile_strategy.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
if TYPE_CHECKING:
3434
from collections.abc import Sequence
3535

36+
from ..runtime.config import Config
3637
from .device_function import DeviceFunction
3738
from .inductor_lowering import CodegenState
3839

@@ -125,45 +126,46 @@ def mask_var(self, block_idx: int) -> str | None:
125126
def block_size_var(self, block_idx: int) -> str | None:
126127
return self.fn.block_size_var_cache.get((block_idx,))
127128

128-
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> list[str]:
129+
@staticmethod
130+
def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]:
129131
"""Get the extra tl.range keyword-argument strings (loop_unroll_factor, warp_specialize, num_stages, disallow_acc_multi_buffer, flatten) for a block based on config."""
130132
env = CompileEnvironment.current()
131133
kwargs = []
132134

133135
range_unroll_factor = env.config_spec.range_unroll_factors.config_get(
134-
state.config.range_unroll_factors, block_idx, 0
136+
config.range_unroll_factors, block_idx, 0
135137
)
136138
if range_unroll_factor > 0:
137139
kwargs.append(f"loop_unroll_factor={range_unroll_factor}")
138140

139141
range_warp_specialize = env.config_spec.range_warp_specialize.config_get(
140-
state.config.range_warp_specializes, block_idx, None
142+
config.range_warp_specializes, block_idx, None
141143
)
142144
if range_warp_specialize is not None:
143145
kwargs.append(f"warp_specialize={range_warp_specialize}")
144146

145147
range_num_stages = env.config_spec.range_num_stages.config_get(
146-
state.config.range_num_stages, block_idx, 0
148+
config.range_num_stages, block_idx, 0
147149
)
148150
if range_num_stages > 0:
149151
kwargs.append(f"num_stages={range_num_stages}")
150152

151153
range_multi_buffer = env.config_spec.range_multi_buffers.config_get(
152-
state.config.range_multi_buffers, block_idx, None
154+
config.range_multi_buffers, block_idx, None
153155
)
154156
if range_multi_buffer is not None:
155157
kwargs.append(f"disallow_acc_multi_buffer={not range_multi_buffer}")
156158

157159
range_flatten = env.config_spec.range_flattens.config_get(
158-
state.config.range_flattens, block_idx, None
160+
config.range_flattens, block_idx, None
159161
)
160162
if range_flatten is not None:
161163
kwargs.append(f"flatten={range_flatten}")
162164
return kwargs
163165

166+
@staticmethod
164167
def get_range_call_str(
165-
self,
166-
state: CodegenState,
168+
config: Config,
167169
block_ids: list[int],
168170
*,
169171
begin: str | None = None,
@@ -173,7 +175,7 @@ def get_range_call_str(
173175
env = CompileEnvironment.current()
174176
use_static_range = all(
175177
env.config_spec.static_ranges.config_get(
176-
state.config.static_ranges, block_idx, None
178+
config.static_ranges, block_idx, None
177179
)
178180
is True
179181
for block_idx in block_ids
@@ -183,12 +185,12 @@ def get_range_call_str(
183185
if begin is not None:
184186
range_args.append(begin)
185187
range_args.append(end)
186-
if step is not None:
187-
range_args.append(f"step={step}")
188+
if step is not None and step != "1":
189+
range_args.append(step)
188190

189191
if use_static_range:
190192
return f"tl.static_range({', '.join(range_args)})"
191-
range_kwargs = self.get_tl_range_kwargs(state, block_ids[0])
193+
range_kwargs = TileStrategy.get_tl_range_kwargs(config, block_ids[0])
192194
return f"tl.range({', '.join(range_args + range_kwargs)})"
193195

194196
def user_size(self, block_index: int) -> sympy.Expr:
@@ -439,7 +441,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
439441
ast.For,
440442
target=create(ast.Name, id=lid, ctx=ast.Store()),
441443
iter=expr_from_string(
442-
self.get_range_call_str(state, self.block_ids, end=end_var)
444+
self.get_range_call_str(state.config, self.block_ids, end=end_var)
443445
),
444446
body=(
445447
body := [
@@ -577,7 +579,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
577579
)
578580
if mask_statement is not None:
579581
state.add_statement(mask_statement)
580-
pid = PIDInfo(pid_var, block_size_var, numel)
582+
pid = PIDInfo(pid_var, block_size_var, numel, block_idx)
581583
pids.append(pid)
582584
pids.codegen(state)
583585
if isinstance(state.device_function.pid, ForEachProgramID):
@@ -656,7 +658,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
656658
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
657659
iter=expr_from_string(
658660
self.get_range_call_str(
659-
state,
661+
state.config,
660662
[block_idx],
661663
begin="begin",
662664
end="end",

helion/autotuner/config_fragment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ def random(self) -> object:
126126
def differential_mutation(self, a: object, b: object, c: object) -> object:
127127
if b == c:
128128
return a
129-
for candidate in random.sample(self.choices, 2):
130-
if candidate != a:
131-
return candidate
132-
return self.random() # only reachable with duplicate choices
129+
choices = [b, c]
130+
if a in choices:
131+
choices.remove(a)
132+
return random.choice(choices)
133133

134134

135135
class BooleanFragment(ConfigSpecFragment):

helion/autotuner/config_spec.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import operator
66
from typing import TYPE_CHECKING
7+
from typing import cast
78

89
from torch._inductor.runtime.runtime_utils import next_power_of_2
910

@@ -95,6 +96,7 @@ class ConfigSpec:
9596
allowed_pid_types: tuple[PidTypeLiteral, ...] = dataclasses.field(
9697
default_factory=functools.partial(tuple, VALID_PID_TYPES)
9798
)
99+
grid_block_ids: list[int] = dataclasses.field(default_factory=list)
98100

99101
@staticmethod
100102
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -165,27 +167,33 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
165167
name, config.get(name, ()), flatten=flatten
166168
)
167169

168-
static_range_block_ids = []
169-
for block_id in self.static_ranges.valid_block_ids():
170-
use_static_range = self.static_ranges.config_get(
171-
config.get( # pyright: ignore[reportArgumentType]
172-
"static_ranges", ()
173-
),
170+
# Disable range_* configs for static ranges
171+
static_range_block_ids = [
172+
block_id
173+
for block_id in self.static_ranges.valid_block_ids()
174+
if self.static_ranges.config_get(
175+
cast("list[bool]", config.get("static_ranges", [])),
174176
block_id,
175177
)
176-
if use_static_range:
177-
static_range_block_ids.append(block_id)
178-
179-
for name, mapping in (
180-
("range_unroll_factors", self.range_unroll_factors),
181-
("range_warp_specializes", self.range_warp_specialize),
182-
("range_num_stages", self.range_num_stages),
183-
("range_multi_buffers", self.range_multi_buffers),
184-
("range_flattens", self.range_flattens),
185-
):
186-
config[name] = mapping._reset_config_to_default(
187-
name, config.get(name, ()), block_ids=static_range_block_ids
188-
)
178+
]
179+
if static_range_block_ids:
180+
for name, mapping in (
181+
("range_unroll_factors", self.range_unroll_factors),
182+
("range_warp_specializes", self.range_warp_specialize),
183+
("range_num_stages", self.range_num_stages),
184+
("range_multi_buffers", self.range_multi_buffers),
185+
("range_flattens", self.range_flattens),
186+
):
187+
config[name] = mapping._reset_config_to_default(
188+
name, config.get(name, ()), block_ids=static_range_block_ids
189+
)
190+
191+
# Only one truthy range_warp_specializes entry is allowed; keep the last one and reset the earlier ones to None
192+
range_warp_specializes = cast(
193+
"list[bool | None]", config.get("range_warp_specializes", [])
194+
)
195+
for i in [j for j, val in enumerate(range_warp_specializes) if val][:-1]:
196+
range_warp_specializes[i] = None
189197

190198
for name in (
191199
"loop_orders",
@@ -218,6 +226,20 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
218226
else:
219227
config[name] = values[0]
220228

229+
# Reset the range_* configs to their defaults for grid block ids when pid_type is "flat" or "xyz"
230+
pid_type = config["pid_type"]
231+
if pid_type in ("flat", "xyz") and self.grid_block_ids:
232+
for name, mapping in (
233+
("range_unroll_factors", self.range_unroll_factors),
234+
("range_warp_specializes", self.range_warp_specialize),
235+
("range_num_stages", self.range_num_stages),
236+
("range_multi_buffers", self.range_multi_buffers),
237+
("range_flattens", self.range_flattens),
238+
):
239+
config[name] = mapping._reset_config_to_default(
240+
name, config.get(name, ()), block_ids=self.grid_block_ids
241+
)
242+
221243
# Allow tunable parameter keys in addition to VALID_KEYS
222244
allowed_keys = VALID_KEYS | {*self.user_defined_tunables.keys()}
223245
if invalid_keys := ({*config} - allowed_keys):
@@ -263,6 +285,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
263285
):
264286
if not config[name]:
265287
config.pop(name)
288+
self.normalize(config)
266289
return helion.Config(**config)
267290

268291

0 commit comments

Comments
 (0)