Skip to content

Layer Norm fwd issue #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/PaulZhang12/1/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions examples/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

import torch

import helion
import helion.language as hl

"""
NOTE: layer_norm_fwd_ideal does not work! I am keeping this around as a reference
to what I believed should have worked in Helion when I first began without debugging.

The user experience should be pushed this direction
"""
@helion.kernel(static_shapes=True)
def layer_norm_fwd_ideal(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
"""
Layer normalization forward pass.

Args:
x: Input tensor of shape [batch_size, hidden_size]
weight: Scale parameter of shape [hidden_size]
bias: Bias parameter of shape [hidden_size]
eps: Epsilon for numerical stability

Returns:
Normalized tensor of shape [batch_size, hidden_size]
"""
m = x.size(0)
out = torch.empty_like(x)

for tile_b in hl.tile(m):
row = x[tile_b]
mean, var = torch.var_mean(row)

layer_norm_out = (row - mean) / torch.sqrt(var + eps)
layer_norm_out = layer_norm_out * weight + bias
out[tile_b, :] = layer_norm_out

return out

@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
m, n = x.size()
assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
out = torch.empty(
[m, n], dtype=torch.float16, device=x.device
)

eps = 1e-5

for tile_m in hl.tile(m):
# acc = x[tile_m, :].to(torch.float32) works! We should not have to do this cast
acc = x[tile_m, :]

var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

normalized = (acc - mean) * torch.rsqrt(var + eps)
acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

out[tile_m, :] = acc
return out


def check(batch_size: int, hidden_size: int) -> None:
from triton.testing import do_bench

# Create random input tensors
x = torch.randn([batch_size, hidden_size], device="cuda", dtype=torch.float16)
weight = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
bias = torch.randn([hidden_size], device="cuda", dtype=torch.float16)

# Run Helion kernel
result = layer_norm_fwd(x, weight, bias)

# # Run PyTorch layer norm for comparison
torch_result = torch.nn.functional.layer_norm(
x, [hidden_size], weight, bias, eps=1e-5
)

# # Check correctness
torch.testing.assert_close(result, torch_result, rtol=1e-2, atol=1e-1)

# Benchmark Helion implementation
helion_sec = do_bench(lambda: layer_norm_fwd(x, weight, bias))

# Benchmark PyTorch implementation
torch_sec = do_bench(lambda: torch.nn.functional.layer_norm(
x, [hidden_size], weight, bias, eps=1e-5
))

print(
f"Helion time: {helion_sec:.4f}ms, torch time: {torch_sec:.4f}, speedup: {torch_sec / helion_sec:.2f}x"
)


def main() -> None:
# Test with different sizes
print("Testing batch_size=128, hidden_size=768")
check(128, 768)

print("\nTesting batch_size=32, hidden_size=1024")
check(32, 1024)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions helion/_compiler/ast_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def visit(self, node: ast.AST) -> ast.AST:
except exc.Base:
raise
except Exception as e:
import pdb; pdb.set_trace()
raise exc.InternalError(e) from e


Expand Down
19 changes: 14 additions & 5 deletions helion/_compiler/inductor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,25 @@ def convert_arg(arg: Node) -> TensorBox:
nodes = []
extra_input_names = []
new_node: torch.fx.Node


read_buffer_names = set()
# Explicitly track the mapping from node to Inductor buffer name.
# First, map the original input nodes to their names.
node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
zip(node._input_nodes, input_names, strict=True)
)

for i, buffer in enumerate(new_buffers):
if not isinstance(buffer, ComputedBuffer) or not isinstance(
buffer.data, (Pointwise, Reduction)
):
raise InductorLoweringError(
f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer(Pointwise|Reduction): {buffer}"
)

for name in buffer.get_read_names():
read_buffer_names.add(name)

if i == len(new_buffers) - 1:
new_node = node
if nodes:
Expand All @@ -191,6 +196,7 @@ def convert_arg(arg: Node) -> TensorBox:
new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes])

# Store output index if this buffer corresponds to an output
import pdb; pdb.set_trace()
if buffer.get_name() in buffer_name_to_output_index:
new_node.meta["output_index"] = buffer_name_to_output_index[
buffer.get_name()
Expand All @@ -207,7 +213,7 @@ def convert_arg(arg: Node) -> TensorBox:
current_input_names = []
for inp_node in current_input_nodes:
current_input_names.append(node_to_buf_name_mapping[inp_node])

used_input_names = strip_unused_inputs(
new_node,
buffer.get_read_names(),
Expand All @@ -230,6 +236,7 @@ def convert_arg(arg: Node) -> TensorBox:
for n in nodes:
if "output_index" in n.meta:
output_nodes[n.meta["output_index"]] = n.name
import pdb; pdb.set_trace()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import pdb; pdb.set_trace()
breakpoint()

same thing but shorter.

last_node.meta["output_nodes"] = output_nodes


Expand All @@ -254,6 +261,8 @@ def mask_unused_inputs(n: torch.fx.Node) -> torch.fx.Node | None:
return n
return None

if node.name == "var_mean":
import pdb; pdb.set_trace()
assert len(input_names) == len(node._input_nodes)
seen_names: dict[str, None] = {}
node.args = map_arg(node.args, mask_unused_inputs)
Expand Down Expand Up @@ -878,11 +887,11 @@ def _collect_multi_outputs(
Collect outputs for multi-output operations using metadata.
"""
# Check if this operation has multiple outputs using the new metadata
assert "output_nodes" in node.meta
assert "output_nodes" in node.meta, "Output nodes not in node.meta"
output_nodes = node.meta["output_nodes"]
outputs = [None] * len(output_nodes)
all_nodes = {n.name: n for n in self.module.graph.nodes} # pyre-ignore[16]

import pdb; pdb.set_trace()
for idx, node_name in output_nodes.items():
if node_name == node.name:
# This is the last node
Expand Down
1 change: 1 addition & 0 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,6 +1780,7 @@ def visit_BinOp(self, node: ast.BinOp) -> TypeInfo:
except exc.Base:
raise
except Exception as e:
import pdb; pdb.set_trace()
raise exc.TorchOpTracingError(e) from e

if isinstance(left, UnknownType):
Expand Down
Loading