Commit 57f1f52

Replace RMSNorm Gated with fused triton kernel
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent 0d21b26 commit 57f1f52

File tree

vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/ops/layernorm_gated.py

2 files changed: +176 -12 lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 7 additions & 12 deletions
@@ -22,6 +22,7 @@
     update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
@@ -132,18 +133,12 @@ def forward_cuda(
         if self.tp_size > 1 or self.n_groups != 1:
             return self.forward_native(x, gate)
 
-        from vllm import _custom_ops as ops
-
-        # cast x and gate to float32 before silu
-        out = torch.empty_like(x)
-        y = x * nn.functional.silu(gate.to(torch.float32))
-        ops.rms_norm(
-            out,
-            y.to(x.dtype),
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
+        return rms_norm_gated(x,
+                              self.weight.data,
+                              bias=None,
+                              z=gate,
+                              eps=self.variance_epsilon,
+                              norm_before_gate=False)
 
 
 def extra_groups_for_head_shards(ngroups: int, tp_size: int):
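
For reference, calling rms_norm_gated with norm_before_gate=False applies the SiLU gate to x before RMS-normalizing, which is the same computation the removed eager path performed with ops.rms_norm. Below is a minimal PyTorch sketch of that semantics; the helper name is hypothetical, and it is not claimed to be bit-exact with either path (the old code cast the gated product back to the input dtype before normalizing, while the fused kernel stays in float32 throughout).

import torch
import torch.nn.functional as F

def gated_rms_norm_reference(x, gate, weight, eps):
    # Gate first (norm_before_gate=False): y = x * SiLU(gate), in float32.
    y = x.float() * F.silu(gate.float())
    # RMSNorm: scale by 1/sqrt(mean of squares + eps); no mean subtraction.
    rstd = torch.rsqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Apply the learned per-channel weight and return in the input dtype.
    return (y * rstd * weight.float()).to(x.dtype)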
vllm/model_executor/layers/mamba/ops/layernorm_gated.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layernorm_gated.py

import torch

from vllm.triton_utils import tl, triton


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(x,
                    weight,
                    bias,
                    eps,
                    z=None,
                    out=None,
                    group_size=None,
                    norm_before_gate=True,
                    is_rms_norm=False):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = torch.empty((ngroups * M, ), dtype=torch.float32,
                       device=x.device) if not is_rms_norm else None
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](x,
                                           out,
                                           weight,
                                           bias,
                                           z,
                                           mean,
                                           rstd,
                                           x.stride(0),
                                           out.stride(0),
                                           z.stride(0) if z is not None else 0,
                                           M,
                                           group_size,
                                           eps,
                                           BLOCK_N=BLOCK_N,
                                           NORM_BEFORE_GATE=norm_before_gate,
                                           IS_RMS_NORM=is_rms_norm,
                                           num_warps=num_warps)
    return out, mean, rstd


def rms_norm_gated(x,
                   weight,
                   bias,
                   z=None,
                   eps=1e-6,
                   group_size=None,
                   norm_before_gate=True,
                   upcast=True):
    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.reshape(-1, x.shape[-1])
    if x.stride(-1) != 1:
        x = x.contiguous()
    if z is not None:
        assert z.shape == x_shape_og
        z = z.reshape(-1, z.shape[-1])
        if z.stride(-1) != 1:
            z = z.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y, _, _ = _layer_norm_fwd(x,
                              weight,
                              bias,
                              eps,
                              z=z,
                              group_size=group_size,
                              norm_before_gate=norm_before_gate,
                              is_rms_norm=True)

    return y.reshape(x_shape_og)
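
As a usage sketch, rms_norm_gated takes the activation, the per-channel weight, an optional bias, and the gate z; the shapes, dtype, and eps below are illustrative assumptions, and the call mirrors the one added in forward_cuda above.

import torch

from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated

# Illustrative sizes (assumed): 4 rows (tokens) with a hidden dimension of 1024.
x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
gate = torch.randn_like(x)
weight = torch.ones(1024, dtype=torch.bfloat16, device="cuda")

# Gate with SiLU(z) first, then RMS-normalize (norm_before_gate=False),
# matching the forward_cuda call in mamba_mixer2.py.
out = rms_norm_gated(x, weight, bias=None, z=gate,
                     eps=1e-5, norm_before_gate=False)
assert out.shape == x.shape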
