@@ -989,6 +989,54 @@ def _moe_matmul_ogs_make_precompiler(A: torch.Tensor, W: torch.Tensor, expert_to
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_moe_matmul_ogs_kernel)(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+ --- assertExpectedJournal(TestExamples.test_rms_norm)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_compat import libdevice
+
+ @triton.jit
+ def _rms_norm_kernel(x, weight, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+ pid_0 = tl.program_id(0)
+ offset_0 = pid_0 * _BLOCK_SIZE_0
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+ load = tl.load(x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None)
+ v_0 = load.to(tl.float32)
+ v_1 = v_0 * v_0
+ mean_x_squared_extra = tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1])
+ v_2 = 256
+ v_3 = mean_x_squared_extra / v_2.to(tl.float32)
+ v_4 = v_3 + eps
+ v_5 = libdevice.rsqrt(v_4)
+ v_6 = v_0 * v_5
+ load_1 = tl.load(weight + indices_1 * 1, None)
+ v_7 = load_1.to(tl.float32)
+ v_8 = v_7[None, :]
+ v_9 = v_6 * v_8
+ v_10 = v_9.to(tl.float16)
+ tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_10, None)
+
+ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05):
+ m, n = x.size()
+ assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
+ out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+ _BLOCK_SIZE_0 = 16
+ _RDIM_SIZE_1 = 256
+ _rms_norm_kernel[triton.cdiv(128, _BLOCK_SIZE_0),](x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+ return out
+
+ def _rms_norm_make_precompiler(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05):
+ m, n = x.size()
+ assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
+ out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+ _BLOCK_SIZE_0 = 16
+ _RDIM_SIZE_1 = 256
+ from helion.runtime.precompile_shim import make_precompiler
+ return make_precompiler(_rms_norm_kernel)(x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+
--- assertExpectedJournal(TestExamples.test_softmax)
from __future__ import annotations