Commit a1f6c83: Layer Norm fwd issue

ghstack-source-id: ca3cb11
Pull Request resolved: #170
Parent: b64bf00

4 files changed: +127, −5 lines

examples/layer_norm.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
from __future__ import annotations

import torch

import helion
import helion.language as hl

"""
NOTE: layer_norm_fwd_ideal does not work! I am keeping this around as a
reference for what I believed should have worked in Helion when I first
began, before any debugging.

The user experience should be pushed in this direction.
"""


@helion.kernel(static_shapes=True)
def layer_norm_fwd_ideal(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """
    Layer normalization forward pass.

    Args:
        x: Input tensor of shape [batch_size, hidden_size]
        weight: Scale parameter of shape [hidden_size]
        bias: Bias parameter of shape [hidden_size]
        eps: Epsilon for numerical stability

    Returns:
        Normalized tensor of shape [batch_size, hidden_size]
    """
    m = x.size(0)
    out = torch.empty_like(x)

    for tile_b in hl.tile(m):
        row = x[tile_b]
        # torch.var_mean returns (var, mean), in that order
        var, mean = torch.var_mean(row)

        layer_norm_out = (row - mean) / torch.sqrt(var + eps)
        layer_norm_out = layer_norm_out * weight + bias
        out[tile_b, :] = layer_norm_out

    return out


@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    eps = 1e-5

    for tile_m in hl.tile(m):
        # acc = x[tile_m, :].to(torch.float32) works! We should not have to do this cast
        acc = x[tile_m, :]

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)

        out[tile_m, :] = acc
    return out


def check(batch_size: int, hidden_size: int) -> None:
    from triton.testing import do_bench

    # Create random input tensors
    x = torch.randn([batch_size, hidden_size], device="cuda", dtype=torch.float16)
    weight = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
    bias = torch.randn([hidden_size], device="cuda", dtype=torch.float16)

    # Run Helion kernel
    result = layer_norm_fwd(x, weight, bias)

    # Run PyTorch layer norm for comparison
    torch_result = torch.nn.functional.layer_norm(
        x, [hidden_size], weight, bias, eps=1e-5
    )

    # Check correctness
    torch.testing.assert_close(result, torch_result, rtol=1e-2, atol=1e-1)

    # Benchmark Helion implementation (do_bench reports milliseconds)
    helion_ms = do_bench(lambda: layer_norm_fwd(x, weight, bias))

    # Benchmark PyTorch implementation
    torch_ms = do_bench(
        lambda: torch.nn.functional.layer_norm(
            x, [hidden_size], weight, bias, eps=1e-5
        )
    )

    print(
        f"Helion time: {helion_ms:.4f}ms, torch time: {torch_ms:.4f}ms, "
        f"speedup: {torch_ms / helion_ms:.2f}x"
    )


def main() -> None:
    # Test with different sizes
    print("Testing batch_size=128, hidden_size=768")
    check(128, 768)

    print("\nTesting batch_size=32, hidden_size=1024")
    check(32, 1024)


if __name__ == "__main__":
    main()
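
The inline comment in layer_norm_fwd records that an explicit float32 upcast of the input tile avoids the failure. For reference, a minimal sketch of that workaround loop body, assuming the same kernel signature and setup as layer_norm_fwd (illustrative only, not part of this commit):

# Sketch of the explicit-cast workaround described in the inline comment
# above; everything except the .to(torch.float32) upcast matches
# layer_norm_fwd (illustrative, not part of this commit).
for tile_m in hl.tile(m):
    acc = x[tile_m, :].to(torch.float32)  # explicit upcast reported to work
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    acc = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)
    out[tile_m, :] = acc  # `out` is float16, so the store downcasts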

helion/_compiler/ast_extension.py

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ def visit(self, node: ast.AST) -> ast.AST:
         except exc.Base:
             raise
         except Exception as e:
+            import pdb; pdb.set_trace()
             raise exc.InternalError(e) from e

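This one-line change (and its twin in type_propagation.py below) follows a common debugging pattern: break into pdb at the point where an internal exception is about to be wrapped, while the original traceback and locals are still live. A minimal standalone sketch of the pattern, with hypothetical stand-in names (KnownError for exc.Base, WrappedError for exc.InternalError, do_op for the traced step):

import pdb

class KnownError(Exception): ...    # stands in for exc.Base
class WrappedError(Exception): ...  # stands in for exc.InternalError

def do_op() -> None:                # stands in for the traced visit step
    raise ValueError("boom")

try:
    do_op()
except KnownError:
    raise                           # already-classified errors pass through
except Exception as e:
    pdb.set_trace()                 # inspect `e` and locals before wrapping
    raise WrappedError(e) from e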

helion/_compiler/inductor_lowering.py

Lines changed: 14 additions & 5 deletions
@@ -169,20 +169,25 @@ def convert_arg(arg: Node) -> TensorBox:
     nodes = []
     extra_input_names = []
     new_node: torch.fx.Node
-
+
+    read_buffer_names = set()
     # Explicitly track the mapping from node to Inductor buffer name.
     # First, map the original input nodes to their names.
     node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
         zip(node._input_nodes, input_names, strict=True)
     )
-
+
     for i, buffer in enumerate(new_buffers):
         if not isinstance(buffer, ComputedBuffer) or not isinstance(
             buffer.data, (Pointwise, Reduction)
         ):
             raise InductorLoweringError(
                 f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer(Pointwise|Reduction): {buffer}"
             )
+
+        for name in buffer.get_read_names():
+            read_buffer_names.add(name)
+
         if i == len(new_buffers) - 1:
             new_node = node
             if nodes:
@@ -191,6 +196,7 @@ def convert_arg(arg: Node) -> TensorBox:
             new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes])

         # Store output index if this buffer corresponds to an output
+        import pdb; pdb.set_trace()
         if buffer.get_name() in buffer_name_to_output_index:
             new_node.meta["output_index"] = buffer_name_to_output_index[
                 buffer.get_name()
@@ -207,7 +213,7 @@ def convert_arg(arg: Node) -> TensorBox:
         current_input_names = []
         for inp_node in current_input_nodes:
             current_input_names.append(node_to_buf_name_mapping[inp_node])
-
+
         used_input_names = strip_unused_inputs(
             new_node,
             buffer.get_read_names(),
@@ -230,6 +236,7 @@ def convert_arg(arg: Node) -> TensorBox:
     for n in nodes:
         if "output_index" in n.meta:
             output_nodes[n.meta["output_index"]] = n.name
+    import pdb; pdb.set_trace()
     last_node.meta["output_nodes"] = output_nodes


@@ -254,6 +261,8 @@ def mask_unused_inputs(n: torch.fx.Node) -> torch.fx.Node | None:
             return n
         return None

+    if node.name == "var_mean":
+        import pdb; pdb.set_trace()
     assert len(input_names) == len(node._input_nodes)
     seen_names: dict[str, None] = {}
     node.args = map_arg(node.args, mask_unused_inputs)
@@ -878,11 +887,11 @@ def _collect_multi_outputs(
         Collect outputs for multi-output operations using metadata.
         """
         # Check if this operation has multiple outputs using the new metadata
-        assert "output_nodes" in node.meta
+        assert "output_nodes" in node.meta, "Output nodes not in node.meta"
         output_nodes = node.meta["output_nodes"]
         outputs = [None] * len(output_nodes)
         all_nodes = {n.name: n for n in self.module.graph.nodes}  # pyre-ignore[16]
-
+        import pdb; pdb.set_trace()
         for idx, node_name in output_nodes.items():
             if node_name == node.name:
                 # This is the last node
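
The hunk above consumes the output_nodes metadata written earlier in this file: a mapping from output index to the name of the FX node that produces that output. As a rough sketch of how such a mapping resolves multi-output results (hypothetical condensed helper; the real logic lives in _collect_multi_outputs):

# Hypothetical condensed helper: resolve each output either to the last
# node itself or to one of the extra nodes created during lowering.
def collect_outputs(node, all_nodes):
    output_nodes = node.meta["output_nodes"]   # {output_index: node_name}
    outputs = [None] * len(output_nodes)
    for idx, name in output_nodes.items():
        outputs[idx] = node if name == node.name else all_nodes[name]
    return outputs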

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 0 deletions
@@ -1780,6 +1780,7 @@ def visit_BinOp(self, node: ast.BinOp) -> TypeInfo:
         except exc.Base:
             raise
         except Exception as e:
+            import pdb; pdb.set_trace()
             raise exc.TorchOpTracingError(e) from e

         if isinstance(left, UnknownType):
