
Commit da9cf12

Add rms_norm example and test (#252)
- Implement RMS normalization kernel using Helion
- Add PyTorch reference implementation
- Add unit test in test_examples.py with expected output
1 parent 3ff927d commit da9cf12

File tree

examples/rms_norm.py
test/test_examples.expected
test/test_examples.py

3 files changed: +129 −0 lines changed


examples/rms_norm.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@

from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(static_shapes=True)
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :].to(torch.float32)

        # Compute 1/RMS: rsqrt(mean(x^2) + eps)
        x_squared = x_tile * x_tile
        mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True)
        rms = torch.rsqrt(mean_x_squared + eps)

        # Apply normalization and weight
        normalized = x_tile * rms
        out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)

    return out


def rms_norm_tritonbench(H: int, inp: torch.Tensor) -> torch.Tensor:
    """Wrapper for tritonbench that matches the expected interface."""
    weight = torch.ones(H, device=inp.device, dtype=inp.dtype)
    return rms_norm(inp, weight, eps=1e-6)


def rms_norm_pytorch(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


def check(m: int, n: int) -> None:
    x = torch.randn([m, n], device="cuda", dtype=torch.float16)
    weight = torch.randn([n], device="cuda", dtype=torch.float16)
    run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5))


def main() -> None:
    check(32, 64)
    check(128, 256)
    check(1024, 1024)


if __name__ == "__main__":
    main()
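
For reference, the kernel above computes the standard RMS normalization, with n the row width, w the weight vector, and ε the stabilizing epsilon:

\[
\operatorname{rmsnorm}(x)_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}
\]

The squared mean and rsqrt run in float32, and the result is cast back to the input dtype, which is exactly what rms_norm_pytorch mirrors.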

test/test_examples.expected

Lines changed: 48 additions & 0 deletions
@@ -989,6 +989,54 @@ def _moe_matmul_ogs_make_precompiler(A: torch.Tensor, W: torch.Tensor, expert_to
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_moe_matmul_ogs_kernel)(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_rms_norm)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice

@triton.jit
def _rms_norm_kernel(x, weight, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    load = tl.load(x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None)
    v_0 = load.to(tl.float32)
    v_1 = v_0 * v_0
    mean_x_squared_extra = tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1])
    v_2 = 256
    v_3 = mean_x_squared_extra / v_2.to(tl.float32)
    v_4 = v_3 + eps
    v_5 = libdevice.rsqrt(v_4)
    v_6 = v_0 * v_5
    load_1 = tl.load(weight + indices_1 * 1, None)
    v_7 = load_1.to(tl.float32)
    v_8 = v_7[None, :]
    v_9 = v_6 * v_8
    v_10 = v_9.to(tl.float16)
    tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_10, None)

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05):
    m, n = x.size()
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    _BLOCK_SIZE_0 = 16
    _RDIM_SIZE_1 = 256
    _rms_norm_kernel[triton.cdiv(128, _BLOCK_SIZE_0),](x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return out

def _rms_norm_make_precompiler(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05):
    m, n = x.size()
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    _BLOCK_SIZE_0 = 16
    _RDIM_SIZE_1 = 256
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_rms_norm_kernel)(x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_softmax)
from __future__ import annotations
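
To connect the generated Triton code back to the Helion source, here is a plain-PyTorch rendering of the work a single Triton program performs (an illustrative sketch only; the one_program helper and its parameters are assumptions of this walkthrough, not part of the journal):

import torch

def one_program(x, weight, out, pid_0, block_size=16, n=256, eps=1e-5):
    # indices_0 in the Triton kernel: the block_size rows owned by program pid_0
    rows = slice(pid_0 * block_size, (pid_0 + 1) * block_size)
    v_0 = x[rows, :].to(torch.float32)             # tl.load + cast to fp32 (v_0)
    v_1 = v_0 * v_0                                # elementwise square (v_1)
    mean = v_1.sum(dim=1, keepdim=True) / n        # tl.sum, then divide by 256 (v_2, v_3)
    v_6 = v_0 * torch.rsqrt(mean + eps)            # libdevice.rsqrt path (v_4..v_6)
    v_9 = v_6 * weight.to(torch.float32)[None, :]  # broadcast the weight row (v_7..v_9)
    out[rows, :] = v_9.to(out.dtype)               # cast back to fp16 and tl.store (v_10)

Note that with static_shapes=True the test's [128, 256] shapes are baked into the generated code: the 256-element row strides and the triton.cdiv(128, _BLOCK_SIZE_0) grid are compile-time constants rather than runtime arguments.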

test/test_examples.py

Lines changed: 20 additions & 0 deletions
@@ -237,6 +237,26 @@ def test_softmax_two_pass_block_ptr(self):
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
            torch.randn([256], device=DEVICE, dtype=torch.float16),
            1e-5,
        )
        # Import and use the reference implementation from rms_norm.py
        mod = import_path(EXAMPLES_DIR / "rms_norm.py")
        expected = mod.rms_norm_pytorch(*args)

        self.assertExpectedJournal(
            check_example(
                "rms_norm",
                args,
                expected,
                block_sizes=[16],
                indexing="pointer",
            )
        )

    def test_embedding_pointers(self):
        args = (
            torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),
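
Outside the harness, the same numerical comparison can be reproduced directly; a minimal sketch, assuming a CUDA device and that examples/rms_norm.py is importable (the test itself loads it via import_path), with tolerances that are this sketch's assumption rather than the suite's settings:

import torch
from rms_norm import rms_norm, rms_norm_pytorch  # assumes examples/ is on sys.path

x = torch.randn([128, 256], device="cuda", dtype=torch.float16)
weight = torch.randn([256], device="cuda", dtype=torch.float16)

torch.testing.assert_close(
    rms_norm(x, weight, 1e-5),
    rms_norm_pytorch(x, weight, 1e-5),
    rtol=1e-2,
    atol=1e-2,  # loose float16 tolerances; assumed, not the suite's values
)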
