From cb1f0072a5a265c089e7aa46e0a5b9f839c61957 Mon Sep 17 00:00:00 2001
From: PaulZhang12
Date: Thu, 12 Jun 2025 18:37:57 -0700
Subject: [PATCH 1/2] Layer Norm fwd issue

[ghstack-poisoned]
---
 examples/layer_norm.py                | 111 ++++++++++++++++++++++++++
 helion/_compiler/ast_extension.py     |   1 +
 helion/_compiler/inductor_lowering.py |  19 +++-
 helion/_compiler/type_propagation.py  |   1 +
 4 files changed, 127 insertions(+), 5 deletions(-)
 create mode 100644 examples/layer_norm.py

diff --git a/examples/layer_norm.py b/examples/layer_norm.py
new file mode 100644
index 00000000..7d067670
--- /dev/null
+++ b/examples/layer_norm.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+import helion.language as hl
+
+"""
+    NOTE: layer_norm_fwd_ideal does not work! It is kept here as a reference for
+    what I initially believed should work in Helion, before any debugging.
+
+    The user experience should be pushed in this direction.
+"""
+@helion.kernel(static_shapes=True)
+def layer_norm_fwd_ideal(
+    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
+) -> torch.Tensor:
+    """
+    Layer normalization forward pass.
+
+    Args:
+        x: Input tensor of shape [batch_size, hidden_size]
+        weight: Scale parameter of shape [hidden_size]
+        bias: Bias parameter of shape [hidden_size]
+        eps: Epsilon for numerical stability
+
+    Returns:
+        Normalized tensor of shape [batch_size, hidden_size]
+    """
+    m = x.size(0)
+    out = torch.empty_like(x)
+
+    for tile_b in hl.tile(m):
+        row = x[tile_b]
+        var, mean = torch.var_mean(row)
+
+        layer_norm_out = (row - mean) / torch.sqrt(var + eps)
+        layer_norm_out = layer_norm_out * weight + bias
+        out[tile_b, :] = layer_norm_out
+
+    return out
+
+@helion.kernel(static_shapes=True, use_default_config=True)
+def layer_norm_fwd(
+    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, n = x.size()
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
+    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
+    out = torch.empty(
+        [m, n], dtype=torch.float16, device=x.device
+    )
+
+    eps = 1e-5
+
+    for tile_m in hl.tile(m):
+        # acc = x[tile_m, :].to(torch.float32) works! We should not have to do this cast.
+        acc = x[tile_m, :]
+
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+
+        out[tile_m, :] = acc
+    return out
+
+
+def check(batch_size: int, hidden_size: int) -> None:
+    from triton.testing import do_bench
+
+    # Create random input tensors
+    x = torch.randn([batch_size, hidden_size], device="cuda", dtype=torch.float16)
+    weight = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
+    bias = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
+
+    # Run Helion kernel
+    result = layer_norm_fwd(x, weight, bias)
+
+    # Run PyTorch layer norm for comparison
+    torch_result = torch.nn.functional.layer_norm(
+        x, [hidden_size], weight, bias, eps=1e-5
+    )
+
+    # Check correctness
+    torch.testing.assert_close(result, torch_result, rtol=1e-2, atol=1e-1)
+
+    # Benchmark Helion implementation (do_bench reports milliseconds)
+    helion_ms = do_bench(lambda: layer_norm_fwd(x, weight, bias))
+
+    # Benchmark PyTorch implementation
+    torch_ms = do_bench(lambda: torch.nn.functional.layer_norm(
+        x, [hidden_size], weight, bias, eps=1e-5
+    ))
+
+    print(
+        f"Helion time: {helion_ms:.4f}ms, torch time: {torch_ms:.4f}ms, speedup: {torch_ms / helion_ms:.2f}x"
+    )
+
+
+def main() -> None:
+    # Test with different sizes
+    print("Testing batch_size=128, hidden_size=768")
+    check(128, 768)
+
+    print("\nTesting batch_size=32, hidden_size=1024")
+    check(32, 1024)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/helion/_compiler/ast_extension.py b/helion/_compiler/ast_extension.py
index 41b73086..c5e89984 100644
--- a/helion/_compiler/ast_extension.py
+++ b/helion/_compiler/ast_extension.py
@@ -222,6 +222,7 @@ def visit(self, node: ast.AST) -> ast.AST:
         except exc.Base:
             raise
         except Exception as e:
+            import pdb; pdb.set_trace()
             raise exc.InternalError(e) from e
 
 
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index 0698a89e..e4e16a1d 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -169,13 +169,14 @@ def convert_arg(arg: Node) -> TensorBox:
     nodes = []
     extra_input_names = []
     new_node: torch.fx.Node
-
+
+    read_buffer_names = set()
     # Explicitly track the mapping from node to Inductor buffer name.
     # First, map the original input nodes to their names.
    node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
         zip(node._input_nodes, input_names, strict=True)
     )
-
+
     for i, buffer in enumerate(new_buffers):
         if not isinstance(buffer, ComputedBuffer) or not isinstance(
             buffer.data, (Pointwise, Reduction)
@@ -183,6 +184,10 @@ def convert_arg(arg: Node) -> TensorBox:
             raise InductorLoweringError(
                 f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer(Pointwise|Reduction): {buffer}"
             )
+
+        for name in buffer.get_read_names():
+            read_buffer_names.add(name)
+
         if i == len(new_buffers) - 1:
             new_node = node
             if nodes:
@@ -191,6 +196,7 @@ def convert_arg(arg: Node) -> TensorBox:
             new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes])
 
         # Store output index if this buffer corresponds to an output
+        import pdb; pdb.set_trace()
         if buffer.get_name() in buffer_name_to_output_index:
             new_node.meta["output_index"] = buffer_name_to_output_index[
                 buffer.get_name()
@@ -207,7 +213,7 @@ def convert_arg(arg: Node) -> TensorBox:
         current_input_names = []
         for inp_node in current_input_nodes:
             current_input_names.append(node_to_buf_name_mapping[inp_node])
-
+
         used_input_names = strip_unused_inputs(
             new_node,
             buffer.get_read_names(),
@@ -230,6 +236,7 @@ def convert_arg(arg: Node) -> TensorBox:
         for n in nodes:
             if "output_index" in n.meta:
                 output_nodes[n.meta["output_index"]] = n.name
+        import pdb; pdb.set_trace()
 
         last_node.meta["output_nodes"] = output_nodes
 
@@ -254,6 +261,8 @@ def mask_unused_inputs(n: torch.fx.Node) -> torch.fx.Node | None:
             return n
         return None
 
+    if node.name == "var_mean":
+        import pdb; pdb.set_trace()
     assert len(input_names) == len(node._input_nodes)
     seen_names: dict[str, None] = {}
     node.args = map_arg(node.args, mask_unused_inputs)
@@ -878,11 +887,11 @@ def _collect_multi_outputs(
         Collect outputs for multi-output operations using metadata.
         """
         # Check if this operation has multiple outputs using the new metadata
-        assert "output_nodes" in node.meta
+        assert "output_nodes" in node.meta, "Output nodes not in node.meta"
         output_nodes = node.meta["output_nodes"]
         outputs = [None] * len(output_nodes)
         all_nodes = {n.name: n for n in self.module.graph.nodes}  # pyre-ignore[16]
-
+        import pdb; pdb.set_trace()
         for idx, node_name in output_nodes.items():
             if node_name == node.name:
                 # This is the last node
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index b148ec73..802fa5fc 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -1780,6 +1780,7 @@ def visit_BinOp(self, node: ast.BinOp) -> TypeInfo:
         except exc.Base:
             raise
         except Exception as e:
+            import pdb; pdb.set_trace()
             raise exc.TorchOpTracingError(e) from e
 
         if isinstance(left, UnknownType):

From 955b4e0b4c1c2dae1ecc0a2b8c95b830175b6752 Mon Sep 17 00:00:00 2001
From: PaulZhang12
Date: Fri, 13 Jun 2025 12:35:29 -0700
Subject: [PATCH 2/2] Update on "Layer Norm fwd issue"

[ghstack-poisoned]
---
 helion/_compiler/ast_extension.py     |  7 +++--
 helion/_compiler/inductor_lowering.py | 36 ++++++++++++++++-----------
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/helion/_compiler/ast_extension.py b/helion/_compiler/ast_extension.py
index c5e89984..2f0ce9c3 100644
--- a/helion/_compiler/ast_extension.py
+++ b/helion/_compiler/ast_extension.py
@@ -2,6 +2,7 @@
 
 import ast
 import enum
+import logging
 import threading
 import typing
 from typing import TYPE_CHECKING
@@ -10,6 +11,7 @@
 from .. import exc
 from .source_location import SourceLocation
 from .source_location import current_location
+import traceback
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -222,8 +224,9 @@ def visit(self, node: ast.AST) -> ast.AST:
         except exc.Base:
             raise
         except Exception as e:
-            import pdb; pdb.set_trace()
-            raise exc.InternalError(e) from e
+            logging.error(f"Original Error: {e}")
+            traceback.print_tb(e.__traceback__)
+            raise
 
 
 # Determine whether vanilla ast.unparse keeps parentheses in "(a, b) = c".
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index e4e16a1d..4ea9745e 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -149,6 +149,8 @@ def convert_arg(arg: Node) -> TensorBox:
         # pyre-ignore[6]
         *map_arg((node.args, node.kwargs), convert_arg),
     )
+    # Previously the result was (buf3, buf1), pairing buf3 with the variance and buf1 with the mean.
+    # Now it is (buf3, buf4), so buf4 (the .to_dtype) becomes the output node and the variance is lost in buf3.
     if not isinstance(result, tuple):
         result = (result,)
     buffer_name_to_output_index = {}
@@ -165,12 +167,12 @@ def convert_arg(arg: Node) -> TensorBox:
             buffer_name_to_output_index[buffer.get_name()] = i
 
     new_buffers = graph_lowering.buffers[prior_buffers:]
+    assert buffer in new_buffers  # pyre-ignore[61]
     nodes = []
     extra_input_names = []
     new_node: torch.fx.Node
 
-    read_buffer_names = set()
     # Explicitly track the mapping from node to Inductor buffer name.
     # First, map the original input nodes to their names.
     node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
@@ -184,10 +186,6 @@ def convert_arg(arg: Node) -> TensorBox:
             raise InductorLoweringError(
                 f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer(Pointwise|Reduction): {buffer}"
             )
-
-        for name in buffer.get_read_names():
-            read_buffer_names.add(name)
-
         if i == len(new_buffers) - 1:
             new_node = node
             if nodes:
@@ -196,7 +194,6 @@ def convert_arg(arg: Node) -> TensorBox:
             new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes])
 
         # Store output index if this buffer corresponds to an output
-        import pdb; pdb.set_trace()
         if buffer.get_name() in buffer_name_to_output_index:
             new_node.meta["output_index"] = buffer_name_to_output_index[
                 buffer.get_name()
@@ -213,18 +210,27 @@ def convert_arg(arg: Node) -> TensorBox:
         current_input_names = []
         for inp_node in current_input_nodes:
             current_input_names.append(node_to_buf_name_mapping[inp_node])
-
-        used_input_names = strip_unused_inputs(
-            new_node,
-            buffer.get_read_names(),
-            dict(zip(current_input_nodes, current_input_names, strict=True)),
-        )
+
+
+        if i != len(new_buffers) - 1:
+            used_input_names = strip_unused_inputs(
+                new_node,
+                buffer.get_read_names(),
+                dict(zip(current_input_nodes, current_input_names, strict=True)),
+            )
+        else:
+            used_input_names = strip_unused_inputs(
+                new_node,
+                set(current_input_names[1:]),
+                dict(zip(current_input_nodes, current_input_names, strict=True)),
+            )
+
         new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
         new_node.meta["orig_node"] = node
         if isinstance(lowering, ReductionLowering):
             lowering.add_input_mask(new_node)
+
         nodes.append(new_node)
-        extra_input_names.append(buffer.get_name())
 
         # Add this node to our mapping for future nodes to reference
         node_to_buf_name_mapping[new_node] = buffer.get_name()
@@ -261,8 +267,6 @@ def mask_unused_inputs(n: torch.fx.Node) -> torch.fx.Node | None:
             return n
         return None
 
-    if node.name == "var_mean":
-        import pdb; pdb.set_trace()
     assert len(input_names) == len(node._input_nodes)
     seen_names: dict[str, None] = {}
     node.args = map_arg(node.args, mask_unused_inputs)
@@ -498,6 +502,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             assert len(inputs) == 2
             # `inputs[0]` is the original input tensor to var_mean
             repr_input = inputs[0]
+            import pdb; pdb.set_trace()
         else:
             # TODO(jansel): combine multiple inputs into a single fake value
             raise NotImplementedError("reductions with >1 input")
@@ -905,6 +910,7 @@ def _collect_multi_outputs(
 
         # Ensure all outputs are found and are ast.Name nodes
         final_outputs = []
+        import pdb; pdb.set_trace()
         for i, result in enumerate(outputs):
             assert result is not None
             if not isinstance(result, ast.Name):
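
For reference, the failing path can be exercised without the benchmark harness. The sketch below mirrors what check() in examples/layer_norm.py does; it is a minimal sketch, assuming a CUDA device, and the import path is an assumption for illustration, not part of the patch:

    # Minimal repro sketch; mirrors check() in examples/layer_norm.py.
    # Assumes a CUDA device; the import path below is hypothetical.
    import torch

    from examples.layer_norm import layer_norm_fwd

    x = torch.randn([128, 768], device="cuda", dtype=torch.float16)
    weight = torch.randn([768], device="cuda", dtype=torch.float16)
    bias = torch.randn([768], device="cuda", dtype=torch.float16)

    out = layer_norm_fwd(x, weight, bias)  # exercises the var_mean lowering path
    ref = torch.nn.functional.layer_norm(x, [768], weight, bias, eps=1e-5)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1)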