diff --git a/examples/layer_norm.py b/examples/layer_norm.py
new file mode 100644
index 00000000..7d067670
--- /dev/null
+++ b/examples/layer_norm.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+import helion.language as hl
+
+"""
+    NOTE: layer_norm_fwd_ideal does not work! I am keeping it around as a reference
+    for what I believed should have worked in Helion when I first began, before debugging.
+
+    The user experience should be pushed in this direction.
+"""
+@helion.kernel(static_shapes=True)
+def layer_norm_fwd_ideal(
+    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
+) -> torch.Tensor:
+    """
+    Layer normalization forward pass.
+
+    Args:
+        x: Input tensor of shape [batch_size, hidden_size]
+        weight: Scale parameter of shape [hidden_size]
+        bias: Bias parameter of shape [hidden_size]
+        eps: Epsilon for numerical stability
+
+    Returns:
+        Normalized tensor of shape [batch_size, hidden_size]
+    """
+    m = x.size(0)
+    out = torch.empty_like(x)
+
+    for tile_b in hl.tile(m):
+        row = x[tile_b]
+        var, mean = torch.var_mean(row, dim=-1, keepdim=True, correction=0)
+
+        layer_norm_out = (row - mean) / torch.sqrt(var + eps)
+        layer_norm_out = layer_norm_out * weight + bias
+        out[tile_b, :] = layer_norm_out
+
+    return out
+
+@helion.kernel(static_shapes=True, use_default_config=True)
+def layer_norm_fwd(
+    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, n = x.size()
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
+    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
+    out = torch.empty(
+        [m, n], dtype=torch.float16, device=x.device
+    )
+
+    eps = 1e-5
+
+    for tile_m in hl.tile(m):
+        # acc = x[tile_m, :].to(torch.float32) works! We should not have to do this cast.
+        acc = x[tile_m, :]
+
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+
+        out[tile_m, :] = acc
+    return out
+
+
+def check(batch_size: int, hidden_size: int) -> None:
+    from triton.testing import do_bench
+
+    # Create random input tensors
+    x = torch.randn([batch_size, hidden_size], device="cuda", dtype=torch.float16)
+    weight = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
+    bias = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
+
+    # Run Helion kernel
+    result = layer_norm_fwd(x, weight, bias)
+
+    # Run PyTorch layer norm for comparison
+    torch_result = torch.nn.functional.layer_norm(
+        x, [hidden_size], weight, bias, eps=1e-5
+    )
+
+    # Check correctness
+    torch.testing.assert_close(result, torch_result, rtol=1e-2, atol=1e-1)
+
+    # Benchmark Helion implementation (do_bench reports milliseconds)
+    helion_ms = do_bench(lambda: layer_norm_fwd(x, weight, bias))
+
+    # Benchmark PyTorch implementation
+    torch_ms = do_bench(lambda: torch.nn.functional.layer_norm(
+        x, [hidden_size], weight, bias, eps=1e-5
+    ))
+
+    print(
+        f"Helion time: {helion_ms:.4f}ms, torch time: {torch_ms:.4f}ms, speedup: {torch_ms / helion_ms:.2f}x"
+    )
+
+
+def main() -> None:
+    # Test with different sizes
+    print("Testing batch_size=128, hidden_size=768")
+    check(128, 768)
+
+    print("\nTesting batch_size=32, hidden_size=1024")
+    check(32, 1024)
+
+
+if __name__ == "__main__":
+    main()
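A note on the torch.var_mean calls above: torch.var_mean returns a (var, mean) tuple, in that order, and its default correction=1 computes the unbiased variance, while layer norm normalizes with the biased (correction=0) variance; that is why the kernels pass correction=0 explicitly. A quick eager-mode sanity check of those conventions, independent of Helion (the shapes here are illustrative):

    import torch

    x = torch.randn(4, 8)

    # var_mean returns (var, mean), in that order.
    var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)

    # Normalizing with the biased variance matches F.layer_norm
    # with no affine parameters.
    manual = (x - mean) * torch.rsqrt(var + 1e-5)
    reference = torch.nn.functional.layer_norm(x, [8], eps=1e-5)
    torch.testing.assert_close(manual, reference)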
diff --git a/helion/_compiler/ast_extension.py b/helion/_compiler/ast_extension.py
index 41b73086..2f0ce9c3 100644
--- a/helion/_compiler/ast_extension.py
+++ b/helion/_compiler/ast_extension.py
@@ -2,6 +2,7 @@
 import ast
 import enum
+import logging
 import threading
 import typing
 from typing import TYPE_CHECKING
@@ -10,6 +11,7 @@
 from .. import exc
 from .source_location import SourceLocation
 from .source_location import current_location
+import traceback

 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -222,7 +225,9 @@ def visit(self, node: ast.AST) -> ast.AST:
         except exc.Base:
             raise
         except Exception as e:
-            raise exc.InternalError(e) from e
+            logging.error("Original error: %s", e)
+            traceback.print_tb(e.__traceback__)
+            raise


 # Determine whether vanilla ast.unparse keeps parentheses in "(a, b) = c".
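The ast_extension.py change above drops the exc.InternalError wrapper in favor of logging the error, printing its traceback, and re-raising. A bare raise propagates the original exception object with its traceback intact, which is what the pdb-driven debugging below relies on. A minimal sketch of the pattern, outside of Helion (the flaky() helper is hypothetical):

    import logging
    import traceback

    def flaky() -> None:
        raise ValueError("original failure site")

    try:
        flaky()
    except Exception as e:
        logging.error("Original error: %s", e)
        traceback.print_tb(e.__traceback__)  # frames down to the failure site
        raise  # re-raises the same exception with its traceback preserved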
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index 0698a89e..4ea9745e 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -149,6 +149,8 @@ def convert_arg(arg: Node) -> TensorBox:
         # pyre-ignore[6]
         *map_arg((node.args, node.kwargs), convert_arg),
     )
+    # Previously the outputs were (buf3, buf1), associating buf3 with the variance and buf1 with the mean.
+    # Now they are (buf3, buf4), so buf4 (the .to_dtype) is the output node and the variance gets lost in buf3.
     if not isinstance(result, tuple):
         result = (result,)
     buffer_name_to_output_index = {}
@@ -165,17 +167,18 @@ def convert_arg(arg: Node) -> TensorBox:
             buffer_name_to_output_index[buffer.get_name()] = i

     new_buffers = graph_lowering.buffers[prior_buffers:]
+    assert buffer in new_buffers  # pyre-ignore[61]
     nodes = []
     extra_input_names = []
     new_node: torch.fx.Node
-
+
     # Explicitly track the mapping from node to Inductor buffer name.
     # First, map the original input nodes to their names.
     node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
         zip(node._input_nodes, input_names, strict=True)
     )
-
+
     for i, buffer in enumerate(new_buffers):
         if not isinstance(buffer, ComputedBuffer) or not isinstance(
             buffer.data, (Pointwise, Reduction)
             current_input_names = []
             for inp_node in current_input_nodes:
                 current_input_names.append(node_to_buf_name_mapping[inp_node])
+
-        used_input_names = strip_unused_inputs(
-            new_node,
-            buffer.get_read_names(),
-            dict(zip(current_input_nodes, current_input_names, strict=True)),
-        )
+        if i != len(new_buffers) - 1:
+            used_input_names = strip_unused_inputs(
+                new_node,
+                buffer.get_read_names(),
+                dict(zip(current_input_nodes, current_input_names, strict=True)),
+            )
+        else:
+            used_input_names = strip_unused_inputs(
+                new_node,
+                set(current_input_names[1:]),
+                dict(zip(current_input_nodes, current_input_names, strict=True)),
+            )
+
         new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
         new_node.meta["orig_node"] = node
         if isinstance(lowering, ReductionLowering):
             lowering.add_input_mask(new_node)
+
         nodes.append(new_node)
-        extra_input_names.append(buffer.get_name())

         # Add this node to our mapping for future nodes to reference
         node_to_buf_name_mapping[new_node] = buffer.get_name()
@@ -230,6 +242,7 @@ def convert_arg(arg: Node) -> TensorBox:
     for n in nodes:
         if "output_index" in n.meta:
             output_nodes[n.meta["output_index"]] = n.name
+    import pdb; pdb.set_trace()

     last_node.meta["output_nodes"] = output_nodes
@@ -489,6 +502,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             assert len(inputs) == 2
             # `inputs[0]` is the original input tensor to var_mean
             repr_input = inputs[0]
+            import pdb; pdb.set_trace()
         else:
             # TODO(jansel): combine multiple inputs into a single fake value
             raise NotImplementedError("reductions with >1 input")
@@ -878,11 +892,11 @@ def _collect_multi_outputs(
         Collect outputs for multi-output operations using metadata.
         """
         # Check if this operation has multiple outputs using the new metadata
-        assert "output_nodes" in node.meta
+        assert "output_nodes" in node.meta, "Output nodes not in node.meta"
         output_nodes = node.meta["output_nodes"]
         outputs = [None] * len(output_nodes)
         all_nodes = {n.name: n for n in self.module.graph.nodes}  # pyre-ignore[16]
-
+        import pdb; pdb.set_trace()
         for idx, node_name in output_nodes.items():
             if node_name == node.name:
                 # This is the last node
@@ -896,6 +910,7 @@ def _collect_multi_outputs(

         # Ensure all outputs are found and are ast.Name nodes
         final_outputs = []
+        import pdb; pdb.set_trace()
         for i, result in enumerate(outputs):
             assert result is not None
             if not isinstance(result, ast.Name):
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index b148ec73..802fa5fc 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -1780,6 +1780,7 @@ def visit_BinOp(self, node: ast.BinOp) -> TypeInfo:
         except exc.Base:
             raise
         except Exception as e:
+            import pdb; pdb.set_trace()
             raise exc.TorchOpTracingError(e) from e

         if isinstance(left, UnknownType):
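For reference, one way to hit the pdb hooks above is to drive the known-failing kernel directly. A minimal driver, assuming the example file is importable as examples.layer_norm and a CUDA device is available (both are assumptions):

    # Minimal driver for the var_mean lowering path being debugged above.
    import torch

    from examples.layer_norm import layer_norm_fwd_ideal

    x = torch.randn([32, 1024], device="cuda", dtype=torch.float16)
    weight = torch.randn([1024], device="cuda", dtype=torch.float16)
    bias = torch.randn([1024], device="cuda", dtype=torch.float16)

    # With the ast_extension.py change, a failure during tracing now logs the
    # original error and traceback instead of raising a wrapped InternalError.
    out = layer_norm_fwd_ideal(x, weight, bias)
    print(out.shape)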