Commit f7ed720

Add hl.register_reduction_dim(); add support for matmul+layernorm example (#80)

1 parent de9ab5c commit f7ed720
File tree

12 files changed: +798 −28 lines changed

examples/matmul_layernorm.py

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+import helion.language as hl
+
+
+# static_shapes=True gives a performance boost for matmuls
+@helion.kernel(static_shapes=True)
+def matmul_layernorm(
+    x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, k = x.size()
+    k2 = y.size(0)
+    n = hl.register_reduction_dim(y.size(1))
+    assert k == k2, f"size mismatch {k} != {k2}"
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
+    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
+    out = torch.empty(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+    for tile_m in hl.tile(m):
+        acc = hl.zeros([tile_m, n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            mm = torch.matmul(x[tile_m, tile_k], y[tile_k, :])
+            acc = acc + mm
+        eps = 1e-5
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        out[tile_m, :] = acc
+    return out
+
+
+def matmul_layernorm_pytorch(
+    x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    import torch.nn.functional as F
+
+    matmul_out = torch.matmul(x, y)
+
+    ln_out = F.layer_norm(
+        matmul_out.to(torch.float32),
+        normalized_shape=(matmul_out.shape[-1],),
+        weight=weight.to(torch.float32),
+        bias=bias.to(torch.float32),
+    )
+
+    return ln_out.to(torch.promote_types(x.dtype, y.dtype))
+
+
+def check(m: int, k: int, n: int) -> None:
+    from triton.testing import do_bench
+
+    x = torch.randn([m, k], device="cuda", dtype=torch.float16)
+    y = torch.randn([k, n], device="cuda", dtype=torch.float16)
+    weight = torch.randn([n], device="cuda", dtype=torch.float16)
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    result = matmul_layernorm(x, y, weight, bias)
+    expected = matmul_layernorm_pytorch(x, y, weight, bias)
+    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-1)
+    sec = do_bench(lambda: matmul_layernorm(x, y, weight, bias))
+    baseline_sec = do_bench(lambda: matmul_layernorm_pytorch(x, y, weight, bias))
+    print(
+        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}ms, speedup: {baseline_sec / sec:.2f}x"
+    )
+
+
+def main() -> None:
+    # TODO(yf225): n=64 or 128 throws error, need to investigate
+    # check(32, 64, 64)
+    # check(32, 64, 128)
+    check(32, 64, 200)
+    check(128, 256, 400)
+
+
+if __name__ == "__main__":
+    main()
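
A note on the normalization math in the kernel above: with correction=0, torch.var_mean returns the population variance, which is what F.layer_norm uses, so the manual (acc - mean) * rsqrt(var + eps) sequence matches the PyTorch reference exactly. A minimal CPU-only sketch (independent of Helion; shapes made up) confirming the identity:

import torch
import torch.nn.functional as F

acc = torch.randn(32, 200)
weight = torch.randn(200)
bias = torch.randn(200)
eps = 1e-5

# Same sequence as the kernel body: population variance, rsqrt, affine.
var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
manual = (acc - mean) * torch.rsqrt(var + eps) * weight + bias
reference = F.layer_norm(acc, normalized_shape=(200,), weight=weight, bias=bias, eps=eps)
torch.testing.assert_close(manual, reference, rtol=1e-4, atol=1e-5)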

helion/_compiler/compile_environment.py

Lines changed: 15 additions & 0 deletions

@@ -121,9 +121,24 @@ def allocate_block_size(
         return idx
 
     def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
+        # Check if this size is already a registered block size
+        if isinstance(size, torch.SymInt):
+            from .host_function import HostFunction
+
+            expr = size._sympy_()
+            origin_info = HostFunction.current().expr_to_origin.get(expr)
+            if origin_info and isinstance(origin_info.origin, BlockSizeOrigin):
+                block_idx = origin_info.origin.block_id
+                # Return the existing block size if it's a reduction dimension
+                if self.block_sizes[block_idx].reduction:
+                    return self.block_sizes[block_idx]
+
+        # Check for existing reduction dimensions with the same size
         for rdim in self.block_sizes:
             if rdim.reduction and rdim.size == size:
                 return rdim
+
+        # Allocate a new reduction dimension
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
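
The logic above boils down to: reuse the block if the SymInt already originates from a reduction block, else reuse any existing reduction dimension of the same size, else allocate a new one. A toy model of that reuse behavior (plain Python; BlockInfo and Env are hypothetical stand-ins, not Helion's real classes):

from dataclasses import dataclass

@dataclass
class BlockInfo:  # hypothetical stand-in for BlockSizeInfo
    size: int
    reduction: bool

class Env:
    def __init__(self) -> None:
        self.block_sizes: list[BlockInfo] = []

    def allocate_reduction_dimension(self, size: int) -> BlockInfo:
        # Reuse an existing reduction dimension of the same size.
        for rdim in self.block_sizes:
            if rdim.reduction and rdim.size == size:
                return rdim
        # Otherwise allocate a fresh one.
        info = BlockInfo(size, reduction=True)
        self.block_sizes.append(info)
        return info

env = Env()
assert env.allocate_reduction_dimension(200) is env.allocate_reduction_dimension(200)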

helion/_compiler/device_ir.py

Lines changed: 28 additions & 0 deletions

@@ -313,6 +313,21 @@ def build_rolled_reductions(self) -> None:
         for rdim in rdims:
             graph_to_info = {}
             allow_loop = False
+
+            # First, check if any graph contains matmul with rdim
+            # If so, we can't roll any graphs in this reduction dimension
+            can_roll_graphs = True
+            for graph_info in self.graphs:
+                roller = ReductionRoller(self, rdim, {})
+                if roller.has_matmul_with_rdim(graph_info.graph):
+                    can_roll_graphs = False
+                    break
+
+            if not can_roll_graphs:
+                first = False
+                continue
+
+            # Process graphs normally
             for graph_id, graph_info in enumerate([*self.graphs]):
                 assert graph_id == graph_info.graph_id
                 roller = ReductionRoller(self, rdim, graph_to_info)

@@ -705,6 +720,19 @@ def visit_Assign(self, node: ast.Assign) -> None:
             # TODO(jansel): should assert that name is only used on device
             self._assign(target, self.visit(node.value))
             return None
+        if isinstance(target, ast.Tuple):
+            # Handle tuple unpacking
+            value = self.visit(node.value)
+            if not isinstance(value, tuple):
+                raise exc.InvalidAssignment
+            if len(target.elts) != len(value):
+                raise exc.InvalidAssignment
+            for t, v in zip(target.elts, value, strict=True):
+                if isinstance(t, ast.Name):
+                    self._assign(t, v)
+                else:
+                    raise exc.InvalidAssignment
+            return None
         if not isinstance(target, ast.Subscript):
             raise exc.InvalidAssignment
         assert isinstance(node.value, ExtendedAST)
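
The new branch handles statements like `var, mean = torch.var_mean(...)` inside a traced kernel. A standalone stdlib-ast sketch (outside Helion's tracer) of what such an assignment target looks like:

import ast

tree = ast.parse("var, mean = torch.var_mean(acc, dim=-1)")
assign = tree.body[0]
assert isinstance(assign, ast.Assign)
(target,) = assign.targets
assert isinstance(target, ast.Tuple)  # tuple-unpacking target
print([elt.id for elt in target.elts])  # ['var', 'mean']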

helion/_compiler/inductor_lowering.py

Lines changed: 118 additions & 22 deletions
@@ -6,6 +6,7 @@
 import functools
 from operator import getitem
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import ContextManager
 from typing import NamedTuple

@@ -148,21 +149,33 @@ def convert_arg(arg: Node) -> TensorBox:
         # pyre-ignore[6]
         *map_arg((node.args, node.kwargs), convert_arg),
     )
-    result.realize()
-    if not isinstance(result, TensorBox) or not isinstance(result.data, StorageBox):
-        raise InductorLoweringError(
-            f"Lowering {node.target} returned type(result), expected TensorBox(StorageBox(...)): {result}"
-        )
-    if not isinstance(buffer := result.data.data, ComputedBuffer):
-        raise InductorLoweringError(
-            f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer: {buffer}"
-        )
+    if not isinstance(result, tuple):
+        result = (result,)
+    buffer_name_to_output_index = {}
+    for i, r in enumerate(result):
+        r.realize()
+        if not isinstance(r, TensorBox) or not isinstance(r.data, StorageBox):
+            raise InductorLoweringError(
+                f"Lowering {node.target} returned {type(r)}, expected TensorBox(StorageBox(...)): {r}"
+            )
+        if not isinstance(buffer := r.data.data, ComputedBuffer):
+            raise InductorLoweringError(
+                f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer: {buffer}"
+            )
+        buffer_name_to_output_index[buffer.get_name()] = i
 
     new_buffers = graph_lowering.buffers[prior_buffers:]
-    assert new_buffers[-1] is buffer
+    assert buffer in new_buffers  # pyre-ignore[61]
     nodes = []
     extra_input_names = []
     new_node: torch.fx.Node
+
+    # Explicitly track the mapping from node to Inductor buffer name.
+    # First, map the original input nodes to their names.
+    node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
+        zip(node._input_nodes, input_names, strict=True)
+    )
+
     for i, buffer in enumerate(new_buffers):
         if not isinstance(buffer, ComputedBuffer) or not isinstance(
             buffer.data, (Pointwise, Reduction)
@@ -176,29 +189,49 @@ def convert_arg(arg: Node) -> TensorBox:
             new_node.kwargs = {**new_node.kwargs, "_extra_args": [*nodes]}
         else:
             new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes])
+
+        # Store output index if this buffer corresponds to an output
+        if buffer.get_name() in buffer_name_to_output_index:
+            new_node.meta["output_index"] = buffer_name_to_output_index[
+                buffer.get_name()
+            ]
+
         lowering_cls = (
             PointwiseLowering
             if isinstance(buffer.data, Pointwise)
             else ReductionLowering
         )
         buffer.freeze_layout()
+
+        current_input_nodes = new_node._input_nodes
+        current_input_names = []
+        for inp_node in current_input_nodes:
+            current_input_names.append(node_to_buf_name_mapping[inp_node])
+
         used_input_names = strip_unused_inputs(
             new_node,
             buffer.get_read_names(),
-            dict(
-                zip(
-                    node.all_input_nodes,
-                    [*input_names, *extra_input_names],
-                    strict=True,
-                )
-            ),
+            dict(zip(current_input_nodes, current_input_names, strict=True)),
         )
         new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
+        new_node.meta["orig_node"] = node
         if isinstance(lowering, ReductionLowering):
             lowering.add_input_mask(new_node)
         nodes.append(new_node)
         extra_input_names.append(buffer.get_name())
+
+        # Add this node to our mapping for future nodes to reference
+        node_to_buf_name_mapping[new_node] = buffer.get_name()
+
+    # After all nodes are created, build the output_nodes mapping for multi-output operations
+    if len(result) > 1 and nodes:
+        last_node = nodes[-1]  # The last node is the main node
+        output_nodes = {}
+        for n in nodes:
+            if "output_index" in n.meta:
+                output_nodes[n.meta["output_index"]] = n.name
+        last_node.meta["output_nodes"] = output_nodes
+
 
 def strip_unused_inputs(
     node: torch.fx.Node,
447480
strategy = BlockReductionStrategy(state, self.block_index)
448481

449482
inputs = self.input_fake_tensors(node)
450-
if len(inputs) != 1:
451-
# TODO(jansel): combine multiple inputs into a single fake value
452-
raise NotImplementedError("reductions with >1 input")
483+
484+
repr_input = None
485+
if len(inputs) == 1:
486+
repr_input = inputs[0]
487+
else:
488+
if node.meta["orig_node"].target == torch.ops.aten.var_mean.correction:
489+
assert len(inputs) == 2
490+
# `inputs[0]` is the original input tensor to var_mean
491+
repr_input = inputs[0]
492+
else:
493+
# TODO(jansel): combine multiple inputs into a single fake value
494+
raise NotImplementedError("reductions with >1 input")
453495

454496
# TODO(jansel): find a better way to get dim
455497
(dim,) = [
456498
i
457-
for i, v in enumerate(inputs[0].shape)
499+
for i, v in enumerate(repr_input.shape)
458500
if TileStrategy.get_block_index(v) == self.block_index
459501
]
460502

@@ -463,7 +505,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             output_name,
             reduction.reduction_type,
             dim,
-            inputs[0],
+            repr_input,
             node.meta["val"],
         )
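
var_mean is the motivating case for this repr_input handling and for the multi-output plumbing later in this file: a single FX node yields two results that downstream code reads through getitem. A standalone torch.fx sketch (plain PyTorch, no Helion) showing that shape:

import operator

import torch
import torch.fx

def f(x: torch.Tensor) -> torch.Tensor:
    out = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
    var, mean = out[0], out[1]  # explicit indexing keeps FX tracing happy
    return (x - mean) * torch.rsqrt(var + 1e-5)

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.target is operator.getitem:
        print(node.name, "<- output", node.args[1], "of", node.args[0].name)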

@@ -806,6 +848,14 @@ def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> str:
         name = self.cg.lift(
             expr_from_string(self.cg.device_function.user_sympy_expr(expr))
         ).id
+
+        # If the lifted symbol refers to a `tl.constexpr` kernel
+        # argument (for example a tile/block size constant such as
+        # `_BLOCK_SIZE_1`) the resulting Triton value is not a tensor
+        # and therefore does not expose a `.to` method.
+        if name in self.cg.device_function._constexpr_args:
+            return name
+
         return f"{name}.to({triton_type(dtype)})"
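
Background for the early return above: in Triton, a tl.constexpr kernel argument is a compile-time Python value rather than a tensor, so it can be used in arithmetic but has no .to() method. A minimal hypothetical kernel (names invented) illustrating the distinction:

import triton
import triton.language as tl

@triton.jit
def scale_kernel(out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # offs is a tensor and supports .to(); BLOCK is a plain constant and does not.
    vals = offs.to(tl.float32) / BLOCK
    tl.store(out_ptr + offs, vals, mask=offs < n)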

@@ -821,11 +871,57 @@ def __init__(self, graph: torch.fx.Graph, cg: GenerateAST) -> None:
         super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False)
         self.cg = cg
 
+    def _collect_multi_outputs(
+        self, node: Node, last_node_result: object
+    ) -> tuple[object, ...]:
+        """
+        Collect outputs for multi-output operations using metadata.
+        """
+        # Check if this operation has multiple outputs using the new metadata
+        assert "output_nodes" in node.meta
+        output_nodes = node.meta["output_nodes"]
+        outputs = [None] * len(output_nodes)
+        all_nodes = {n.name: n for n in self.module.graph.nodes}  # pyre-ignore[16]
+
+        for idx, node_name in output_nodes.items():
+            if node_name == node.name:
+                # This is the last node
+                outputs[idx] = last_node_result  # pyre-ignore[6]
+            else:
+                # This is an extra node - get its result from env
+                if node_name in all_nodes:
+                    extra_node = all_nodes[node_name]
+                    if extra_node in self.env:
+                        outputs[idx] = self.env[extra_node]
+
+        # Ensure all outputs are found and are ast.Name nodes
+        final_outputs = []
+        for i, result in enumerate(outputs):
+            assert result is not None
+            if not isinstance(result, ast.Name):
+                var_name = self.cg.device_function.new_var(f"{node.name}_output{i}")
+                self.cg.add_statement(
+                    statement_from_string(f"{var_name} = result", result=result)
+                )
+                result = create(ast.Name, id=var_name, ctx=ast.Load())
+            final_outputs.append(result)
+
+        return tuple(final_outputs)
+
     def run_node(self, n: Node) -> object:
         if n.op == "call_function":
             with self._set_current_node(n), n.meta["location"]:
                 lowering: Lowering = n.meta["lowering"]
                 result = lowering.codegen(self, n)
+                n.meta["codegen"] = result
+
+                # Generic handling for operations with multiple outputs
+                if n.kwargs.get("_extra_args"):
+                    # Check if this node has getitem users, indicating multiple outputs
+                    getitem_users = [user for user in n.users if user.target == getitem]
+                    if len(getitem_users) > 0:
+                        return self._collect_multi_outputs(n, result)
+
                 if result is None:
                     return None
                 if not isinstance(result, ast.AST):
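
Putting the pieces together: codegen_call_with_graph stores an {output_index: node_name} map on the last generated node, and _collect_multi_outputs resolves each index back to a codegen result. A toy dict-level walkthrough (all names hypothetical, not real Helion output):

env = {"var_extra": "var_ast", "var_mean": "mean_ast"}  # node name -> codegen result
output_nodes = {0: "var_extra", 1: "var_mean"}          # output index -> node name

outputs = tuple(env[output_nodes[i]] for i in sorted(output_nodes))
assert outputs == ("var_ast", "mean_ast")  # same order the getitem users expect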
