
Commit 53f50e6

Layer Norm fwd issue
ghstack-source-id: a9859b5
Pull Request resolved: #170
1 parent b64bf00 commit 53f50e6

4 files changed: +143 / -11 lines changed


examples/layer_norm.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
from __future__ import annotations

import torch

import helion
import helion.language as hl

"""
NOTE: layer_norm_fwd_ideal does not work! I am keeping it around as a reference
for what I believed should have worked in Helion when I first started, before any debugging.

The user experience should be pushed in this direction.
"""
@helion.kernel(static_shapes=True)
def layer_norm_fwd_ideal(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """
    Layer normalization forward pass.

    Args:
        x: Input tensor of shape [batch_size, hidden_size]
        weight: Scale parameter of shape [hidden_size]
        bias: Bias parameter of shape [hidden_size]
        eps: Epsilon for numerical stability

    Returns:
        Normalized tensor of shape [batch_size, hidden_size]
    """
    m = x.size(0)
    out = torch.empty_like(x)

    for tile_b in hl.tile(m):
        row = x[tile_b]
        var, mean = torch.var_mean(row)

        layer_norm_out = (row - mean) / torch.sqrt(var + eps)
        layer_norm_out = layer_norm_out * weight + bias
        out[tile_b, :] = layer_norm_out

    return out

@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty(
        [m, n], dtype=torch.float16, device=x.device
    )

    eps = 1e-5

    for tile_m in hl.tile(m):
        # acc = x[tile_m, :].to(torch.float32) works! We should not have to do this cast
        acc = x[tile_m, :]

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc
    return out


def check(batch_size: int, hidden_size: int) -> None:
    from triton.testing import do_bench

    # Create random input tensors
    x = torch.randn([batch_size, hidden_size], device="cuda", dtype=torch.float16)
    weight = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
    bias = torch.randn([hidden_size], device="cuda", dtype=torch.float16)

    # Run Helion kernel
    result = layer_norm_fwd(x, weight, bias)

    # Run PyTorch layer norm for comparison
    torch_result = torch.nn.functional.layer_norm(
        x, [hidden_size], weight, bias, eps=1e-5
    )

    # Check correctness
    torch.testing.assert_close(result, torch_result, rtol=1e-2, atol=1e-1)

    # Benchmark Helion implementation
    helion_ms = do_bench(lambda: layer_norm_fwd(x, weight, bias))

    # Benchmark PyTorch implementation
    torch_ms = do_bench(lambda: torch.nn.functional.layer_norm(
        x, [hidden_size], weight, bias, eps=1e-5
    ))

    print(
        f"Helion time: {helion_ms:.4f}ms, torch time: {torch_ms:.4f}ms, speedup: {torch_ms / helion_ms:.2f}x"
    )


def main() -> None:
    # Test with different sizes
    print("Testing batch_size=128, hidden_size=768")
    check(128, 768)

    print("\nTesting batch_size=32, hidden_size=1024")
    check(32, 1024)


if __name__ == "__main__":
    main()
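For reference, a minimal sketch of the workaround named in the inline comment inside layer_norm_fwd: the explicit float32 cast on the tile is the only assumed change, everything else matches the loop above.

    # Workaround sketch (per the inline comment in layer_norm_fwd): casting the
    # tile to float32 before the reduction is reported to work, even though it
    # should not be required.
    for tile_m in hl.tile(m):
        acc = x[tile_m, :].to(torch.float32)  # explicit cast; ideally unnecessary

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc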

helion/_compiler/ast_extension.py

Lines changed: 6 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 import ast
 import enum
+import logging
 import threading
 import typing
 from typing import TYPE_CHECKING
@@ -10,6 +11,8 @@
 from .. import exc
 from .source_location import SourceLocation
 from .source_location import current_location
+import sys
+import traceback
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -222,7 +225,9 @@ def visit(self, node: ast.AST) -> ast.AST:
         except exc.Base:
             raise
         except Exception as e:
-            raise exc.InternalError(e) from e
+            logging.error(f"Original Error: {str(e)}")
+            traceback.print_tb(e.__traceback__)
+            raise
 
 
 # Determine whether vanilla ast.unparse keeps parentheses in "(a, b) = c".
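For reference, a standalone sketch of the error-handling pattern this hunk switches to, with risky_visit as a hypothetical stand-in for the wrapped AST visit: log the original exception and print its traceback, then re-raise it unwrapped instead of converting it to exc.InternalError.

    import logging
    import traceback


    def risky_visit() -> None:
        # Hypothetical stand-in for the visit call that may raise.
        raise ValueError("example failure")


    try:
        risky_visit()
    except Exception as e:
        logging.error(f"Original Error: {str(e)}")  # surface the real error message
        traceback.print_tb(e.__traceback__)         # show where it actually failed
        raise                                       # re-raise the original exception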

helion/_compiler/inductor_lowering.py

Lines changed: 25 additions & 10 deletions
@@ -149,6 +149,8 @@ def convert_arg(arg: Node) -> TensorBox:
             # pyre-ignore[6]
             *map_arg((node.args, node.kwargs), convert_arg),
         )
+        # Previously it was (buf3, buf1), associating buf3 with variance and buf1 with mean
+        # Now it is (buf3, buf4), so buf4 (the .to_dtype) is the output node, variance gets lost in buf3
         if not isinstance(result, tuple):
             result = (result,)
         buffer_name_to_output_index = {}
@@ -165,17 +167,18 @@ def convert_arg(arg: Node) -> TensorBox:
             buffer_name_to_output_index[buffer.get_name()] = i
 
         new_buffers = graph_lowering.buffers[prior_buffers:]
+
         assert buffer in new_buffers  # pyre-ignore[61]
         nodes = []
         extra_input_names = []
         new_node: torch.fx.Node
-
+
         # Explicitly track the mapping from node to Inductor buffer name.
         # First, map the original input nodes to their names.
         node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
             zip(node._input_nodes, input_names, strict=True)
         )
-
+
         for i, buffer in enumerate(new_buffers):
             if not isinstance(buffer, ComputedBuffer) or not isinstance(
                 buffer.data, (Pointwise, Reduction)
@@ -207,18 +210,27 @@ def convert_arg(arg: Node) -> TensorBox:
             current_input_names = []
             for inp_node in current_input_nodes:
                 current_input_names.append(node_to_buf_name_mapping[inp_node])
+
 
-            used_input_names = strip_unused_inputs(
-                new_node,
-                buffer.get_read_names(),
-                dict(zip(current_input_nodes, current_input_names, strict=True)),
-            )
+            if i != len(new_buffers) - 1:
+                used_input_names = strip_unused_inputs(
+                    new_node,
+                    buffer.get_read_names(),
+                    dict(zip(current_input_nodes, current_input_names, strict=True)),
+                )
+            else:
+                used_input_names = strip_unused_inputs(
+                    new_node,
+                    set(current_input_names[1:]),
+                    dict(zip(current_input_nodes, current_input_names, strict=True)),
+                )
+
             new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
             new_node.meta["orig_node"] = node
             if isinstance(lowering, ReductionLowering):
                 lowering.add_input_mask(new_node)
+
             nodes.append(new_node)
-            extra_input_names.append(buffer.get_name())
 
             # Add this node to our mapping for future nodes to reference
             node_to_buf_name_mapping[new_node] = buffer.get_name()
@@ -230,6 +242,7 @@ def convert_arg(arg: Node) -> TensorBox:
         for n in nodes:
             if "output_index" in n.meta:
                 output_nodes[n.meta["output_index"]] = n.name
+        import pdb; pdb.set_trace()
         last_node.meta["output_nodes"] = output_nodes
 
 
@@ -489,6 +502,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             assert len(inputs) == 2
             # `inputs[0]` is the original input tensor to var_mean
             repr_input = inputs[0]
+            import pdb; pdb.set_trace()
         else:
             # TODO(jansel): combine multiple inputs into a single fake value
             raise NotImplementedError("reductions with >1 input")
@@ -878,11 +892,11 @@ def _collect_multi_outputs(
         Collect outputs for multi-output operations using metadata.
         """
         # Check if this operation has multiple outputs using the new metadata
-        assert "output_nodes" in node.meta
+        assert "output_nodes" in node.meta, "Output nodes not in node.meta"
         output_nodes = node.meta["output_nodes"]
         outputs = [None] * len(output_nodes)
         all_nodes = {n.name: n for n in self.module.graph.nodes}  # pyre-ignore[16]
-
+        import pdb; pdb.set_trace()
         for idx, node_name in output_nodes.items():
             if node_name == node.name:
                 # This is the last node
@@ -896,6 +910,7 @@ def _collect_multi_outputs(
 
         # Ensure all outputs are found and are ast.Name nodes
         final_outputs = []
+        import pdb; pdb.set_trace()
         for i, result in enumerate(outputs):
             assert result is not None
             if not isinstance(result, ast.Name):
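The comment added in the first hunk of this file says the var_mean lowering now ends with a .to_dtype buffer (buf4) as the output node, leaving the variance behind in buf3. A hedged eager-mode sketch of why such an extra buffer plausibly appears for an fp16 input; this is an assumed illustration, not code from the diff.

    import torch

    # Assumed illustration: with a float16 tile, the var/mean reduction is done
    # in float32 and the result is cast back to float16; that trailing cast is
    # the kind of ".to_dtype" node the comment above refers to as buf4.
    x16 = torch.randn(8, 16, dtype=torch.float16)
    var, mean = torch.var_mean(x16.to(torch.float32), dim=-1, keepdim=True, correction=0)
    var16 = var.to(torch.float16)    # cast-back step, analogous to the extra output buffer
    mean16 = mean.to(torch.float16)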

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 0 deletions
@@ -1780,6 +1780,7 @@ def visit_BinOp(self, node: ast.BinOp) -> TypeInfo:
         except exc.Base:
             raise
         except Exception as e:
+            import pdb; pdb.set_trace()
             raise exc.TorchOpTracingError(e) from e
 
         if isinstance(left, UnknownType):
