Skip to content

Commit dc88062

Browse files
authored
Generalize l2_grouping to support 3+ dimensions (#313)
1 parent 366a3b3 commit dc88062

File tree

9 files changed

+524
-234
lines changed

9 files changed

+524
-234
lines changed

helion/_compiler/program_id.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def codegen(self, state: CodegenState) -> None:
266266
# Note: Persistent kernel setup is handled by ForEachProgramID if needed
267267
assert self.parent_strategy is not None
268268
parent_pids = self.parent_strategy.pid_info
269-
assert len(parent_pids) == 2
269+
assert len(parent_pids) >= 2, "L2 grouping requires at least 2 dimensions"
270270
new_var = state.device_function.new_var
271271

272272
# Use shared_pid_var if we're in a ForEachProgramID context, otherwise use virtual_program_id
@@ -275,26 +275,87 @@ def codegen(self, state: CodegenState) -> None:
275275
else:
276276
pid = self.virtual_program_id
277277

278-
num_pid_m = new_var("num_pid_m")
279-
num_pid_n = new_var("num_pid_n")
280-
num_pid_in_group = new_var("num_pid_in_group")
281-
group_id = new_var("group_id")
282-
first_pid_m = new_var("first_pid_m")
283-
group_size_m = new_var("group_size_m")
284-
285-
assignments = [
286-
(num_pid_m, parent_pids[0].num_pids_expr(is_device=True)),
287-
(num_pid_n, parent_pids[1].num_pids_expr(is_device=True)),
288-
(num_pid_in_group, f"{self.group_size} * {num_pid_n}"),
289-
(group_id, f"{pid} // {num_pid_in_group}"),
290-
(first_pid_m, f"{group_id} * {self.group_size}"),
291-
(group_size_m, f"min({num_pid_m} - {first_pid_m}, {self.group_size})"),
292-
(
293-
parent_pids[0].pid_var,
294-
f"{first_pid_m} + (({pid} % {num_pid_in_group}) % {group_size_m})",
295-
),
296-
(parent_pids[1].pid_var, f"({pid} % {num_pid_in_group}) // {group_size_m}"),
278+
# Apply L2 grouping to the 2 fastest varying dimensions (pid_0, pid_1)
279+
# These are always the first 2 dimensions in the PID decomposition
280+
num_dims = len(parent_pids)
281+
assignments = []
282+
283+
# Generate size variables for all dimensions (except the last which doesn't need one)
284+
num_blocks = []
285+
for i in range(num_dims - 1):
286+
num_block_var = new_var(f"num_blocks_{i}", dce=True)
287+
assignments.append(
288+
(num_block_var, parent_pids[i].num_pids_expr(is_device=True))
289+
)
290+
num_blocks.append(num_block_var)
291+
292+
# Apply L2 grouping to the 2 fastest varying dimensions (pid_0, pid_1)
293+
fastest_m_idx = 0 # pid_0 (fastest varying)
294+
fastest_n_idx = 1 # pid_1 (second fastest varying)
295+
296+
# Extract the 2D portion for the fastest 2 dimensions
297+
inner_2d_size = new_var("inner_2d_size", dce=True)
298+
inner_2d_pid = new_var("inner_2d_pid", dce=True)
299+
300+
num_pid_m = new_var("num_pid_m", dce=True)
301+
num_pid_n = new_var("num_pid_n", dce=True)
302+
num_pid_in_group = new_var("num_pid_in_group", dce=True)
303+
group_id = new_var("group_id", dce=True)
304+
first_pid_m = new_var("first_pid_m", dce=True)
305+
group_size_m = new_var("group_size_m", dce=True)
306+
307+
# Set up L2 grouping for the fastest 2 dimensions
308+
inner_2d_assignments = [
309+
(num_pid_m, parent_pids[fastest_m_idx].num_pids_expr(is_device=True)),
310+
(num_pid_n, parent_pids[fastest_n_idx].num_pids_expr(is_device=True)),
297311
]
312+
313+
# Only add modulo for 3D+ cases where we need to extract the 2D portion
314+
if num_dims > 2:
315+
inner_2d_assignments.extend(
316+
[
317+
(inner_2d_size, f"{num_pid_m} * {num_pid_n}"),
318+
(
319+
inner_2d_pid,
320+
f"{pid} % {inner_2d_size}",
321+
), # Extract fastest 2D portion
322+
]
323+
)
324+
else:
325+
# For 2D case, the entire PID space is the 2D space
326+
inner_2d_assignments.append((inner_2d_pid, pid))
327+
328+
assignments.extend(inner_2d_assignments)
329+
assignments.extend(
330+
[
331+
(num_pid_in_group, f"{self.group_size} * {num_pid_n}"),
332+
(group_id, f"{inner_2d_pid} // {num_pid_in_group}"),
333+
(first_pid_m, f"{group_id} * {self.group_size}"),
334+
(group_size_m, f"min({num_pid_m} - {first_pid_m}, {self.group_size})"),
335+
(
336+
parent_pids[fastest_m_idx].pid_var,
337+
f"{first_pid_m} + (({inner_2d_pid} % {num_pid_in_group}) % {group_size_m})",
338+
),
339+
(
340+
parent_pids[fastest_n_idx].pid_var,
341+
f"({inner_2d_pid} % {num_pid_in_group}) // {group_size_m}",
342+
),
343+
]
344+
)
345+
346+
# Process remaining dimensions (if any) using standard decomposition
347+
for i in range(2, num_dims):
348+
expr = pid
349+
# Add divisor for all faster dimensions
350+
if i > 0:
351+
divisor = " * ".join(num_blocks[:i])
352+
expr = f"({expr}) // ({divisor})"
353+
# Add modulo unless this is the outermost dimension
354+
if i + 1 < num_dims: # Not the outermost dimension
355+
expr = f"({expr}) % {num_blocks[i]}"
356+
357+
assignments.append((parent_pids[i].pid_var, expr))
358+
298359
statements = [
299360
statement_from_string(f"{var} = {expr}") for var, expr in assignments
300361
]

helion/language/loops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ def _add_config_choices(
386386
config_spec.grid_block_ids.extend(
387387
[x for x in block_ids if x not in existing_ids]
388388
)
389-
if len(block_ids) == 2:
390-
# TODO(jansel): support L2 grouping with 3+ dims (and maybe non-grids?)
389+
if len(block_ids) >= 2:
390+
# L2 grouping now supports 3D+ grids by applying to innermost 2 dimensions
391391
config_spec.l2_groupings.append(L2GroupingSpec(block_ids))
392392
if not _allow_use_yz_grid(config_spec, block_ids):
393393
config_spec.disallow_pid_type("xyz")

test/test_autotuner.expected

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8],
1414
helion.Config(block_sizes=[16, 32, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[0, 2], range_warp_specializes=[None, None], range_num_stages=[0, 0], range_multi_buffers=[None, None], range_flattens=[None, True], num_warps=2, num_stages=6, indexing='tensor_descriptor', pid_type='flat')
1515

1616
--- assertExpectedJournal(TestAutotuner.test_config_fragment1)
17-
helion.Config(block_sizes=[8, 16, 16], loop_orders=[[0, 1, 2]], flatten_loops=[False], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat')
18-
helion.Config(block_sizes=[2, 128, 128], loop_orders=[[1, 2, 0]], flatten_loops=[False], range_unroll_factors=[2], range_warp_specializes=[None], range_num_stages=[4], range_multi_buffers=[False], range_flattens=[None], num_warps=8, num_stages=4, indexing='tensor_descriptor', pid_type='persistent_blocked')
19-
helion.Config(block_sizes=[2, 16, 4], loop_orders=[[0, 2, 1]], flatten_loops=[True], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=1, num_stages=2, indexing='tensor_descriptor', pid_type='flat')
20-
helion.Config(block_sizes=[8, 2, 512], loop_orders=[[0, 2, 1]], flatten_loops=[True], range_unroll_factors=[4], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[False], range_flattens=[True], num_warps=8, num_stages=3, indexing='block_ptr', pid_type='persistent_interleaved')
21-
helion.Config(block_sizes=[1, 16, 32], loop_orders=[[0, 2, 1]], flatten_loops=[False], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor', pid_type='persistent_interleaved')
22-
helion.Config(block_sizes=[1, 32, 512], loop_orders=[[0, 2, 1]], flatten_loops=[False], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=2, num_stages=5, indexing='pointer', pid_type='flat')
23-
helion.Config(block_sizes=[1, 32, 32], loop_orders=[[1, 2, 0]], flatten_loops=[True], range_unroll_factors=[4], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=16, num_stages=3, indexing='tensor_descriptor', pid_type='persistent_blocked')
24-
helion.Config(block_sizes=[1, 4, 32], loop_orders=[[1, 0, 2]], flatten_loops=[True], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=16, num_stages=6, indexing='block_ptr', pid_type='flat')
25-
helion.Config(block_sizes=[4, 16, 1], loop_orders=[[2, 1, 0]], flatten_loops=[True], range_unroll_factors=[2], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=4, num_stages=8, indexing='block_ptr', pid_type='persistent_interleaved')
26-
helion.Config(block_sizes=[8, 128, 4], loop_orders=[[1, 0, 2]], flatten_loops=[False], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=2, num_stages=4, indexing='tensor_descriptor', pid_type='flat')
17+
helion.Config(block_sizes=[8, 16, 16], loop_orders=[[0, 1, 2]], flatten_loops=[False], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat')
18+
helion.Config(block_sizes=[2, 128, 128], loop_orders=[[1, 2, 0]], flatten_loops=[False], l2_groupings=[4], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[None], range_flattens=[False], num_warps=16, num_stages=4, indexing='tensor_descriptor', pid_type='persistent_blocked')
19+
helion.Config(block_sizes=[4, 32, 8], loop_orders=[[0, 2, 1]], flatten_loops=[True], l2_groupings=[8], range_unroll_factors=[4], range_warp_specializes=[False], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=1, num_stages=4, indexing='block_ptr', pid_type='persistent_blocked')
20+
helion.Config(block_sizes=[1, 512, 1], loop_orders=[[0, 2, 1]], flatten_loops=[True], l2_groupings=[1], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='pointer', pid_type='persistent_interleaved')
21+
helion.Config(block_sizes=[1, 8, 512], loop_orders=[[1, 0, 2]], flatten_loops=[True], l2_groupings=[8], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=2, indexing='tensor_descriptor', pid_type='flat')
22+
helion.Config(block_sizes=[4, 2, 128], loop_orders=[[0, 1, 2]], flatten_loops=[True], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=4, num_stages=5, indexing='block_ptr', pid_type='persistent_blocked')
23+
helion.Config(block_sizes=[2, 16, 2], loop_orders=[[0, 2, 1]], flatten_loops=[True], l2_groupings=[64], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], num_warps=16, num_stages=4, indexing='block_ptr', pid_type='persistent_blocked')
24+
helion.Config(block_sizes=[4, 4, 1], loop_orders=[[1, 2, 0]], flatten_loops=[False], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=8, num_stages=5, indexing='tensor_descriptor', pid_type='persistent_blocked')
25+
helion.Config(block_sizes=[4, 4, 16], loop_orders=[[1, 2, 0]], flatten_loops=[True], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=8, num_stages=3, indexing='tensor_descriptor', pid_type='persistent_blocked')
26+
helion.Config(block_sizes=[4, 8, 8], loop_orders=[[2, 0, 1]], flatten_loops=[False], l2_groupings=[4], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=8, num_stages=5, indexing='tensor_descriptor', pid_type='flat')
2727

2828
--- assertExpectedJournal(TestAutotuner.test_save_load_config)
2929
{

0 commit comments

Comments (0)