Commit 00f5303

[Model] Fused rope implementation for DeepSeek-v2 (#3105)
This PR fuses the rope operation with the matrix transposes that were previously applied before it: the de-interleaving reshape/permute_dims is folded into the rope index computation instead of being materialized. The model has been tested (after compiling) with these changes.
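
For intuition, the fusion rests on a simple index equivalence: reading the original interleaved tensor at d2 * 2 + d1 (with d1 = d // (D // 2) and d2 = d % (D // 2)) returns the same value as first materializing the reshape / permute_dims de-interleave and then reading index d. A minimal NumPy sketch of that equivalence (illustrative only, not part of the commit; shapes are hypothetical):

import numpy as np

b, s, h, D = 1, 2, 3, 8  # hypothetical small shapes
x = np.random.rand(b, s, h, D)

# Old path: materialize the de-interleaved tensor via reshape -> permute -> reshape.
deinterleaved = (
    x.reshape(b, s, h, D // 2, 2).transpose(0, 1, 2, 4, 3).reshape(b, s, h, D)
)

# Fused path: read the original (interleaved) tensor at a remapped index,
# as the new compute() does with d1 = d // (D // 2) and d2 = d % (D // 2).
half = D // 2
fused = np.empty_like(x)
for d in range(D):
    d1, d2 = d // half, d % half
    fused[..., d] = x[..., d2 * 2 + d1]

assert np.allclose(deinterleaved, fused)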
1 parent 2c1001b commit 00f5303

File tree

1 file changed: +25 -19 lines changed

python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py

Lines changed: 25 additions & 19 deletions
@@ -137,38 +137,44 @@ def forward(
         k: Tensor,
         positions: Tensor,
     ):
-        def _rope(x: te.Tensor, positions: te.Tensor):
+        def _rope_fused(x: te.Tensor, positions: te.Tensor):
+            _, _, _, d_dim = x.shape
+            d_dim_half = d_dim // 2
             dtype = x.dtype

             def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
+                d1 = d // d_dim_half
+                d2 = d % d_dim_half
+
                 cos_freq, sin_freq, var_map = self.rope_fn(
                     positions[s], d, self.rotary_dim, self.theta, dtype
                 )
-                cos = cos_freq * x[b, s, h, d]
-                sin = sin_freq * tir.if_then_else(
+                cos = x[b, s, h, d2 * 2 + d1] * cos_freq
+
+                partner_d = tir.if_then_else(
                     d < self.rotary_dim // 2,
-                    -x[b, s, h, d + self.rotary_dim // 2],
-                    x[b, s, h, d - self.rotary_dim // 2],
+                    d + self.rotary_dim // 2,
+                    d - self.rotary_dim // 2,
+                )
+
+                partner_d1 = partner_d // d_dim_half
+                partner_d2 = partner_d % d_dim_half
+                sin = (
+                    x[b, s, h, partner_d2 * 2 + partner_d1]
+                    * sin_freq
+                    * tir.if_then_else(
+                        d < self.rotary_dim // 2, tir.const(-1, dtype), tir.const(1, dtype)
+                    )
                 )
                 expr = cos + sin
-                for var, value in var_map.items():
-                    expr = tir.Let(var, value, expr)
+                for var, val in var_map.items():
+                    expr = tir.Let(var, val, expr)
                 return expr

             return te.compute(x.shape, compute, name="yarn_rope")

-        b, s, h, d = q.shape
-        q = op.reshape(
-            op.permute_dims(op.reshape(q, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d)
-        )
-
-        b, s, h, d = k.shape
-        k = op.reshape(
-            op.permute_dims(op.reshape(k, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d)
-        )
-
-        q_embed = op.tensor_expr_op(_rope, "rope", [q, positions])
-        k_embed = op.tensor_expr_op(_rope, "rope", [k, positions])
+        q_embed = op.tensor_expr_op(_rope_fused, "rope", [q, positions])
+        k_embed = op.tensor_expr_op(_rope_fused, "rope", [k, positions])
         return q_embed, k_embed
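
For orientation, the per-element expression that the fused compute() builds is the standard rotate-half RoPE applied through the remapped indices above. A minimal NumPy sketch for a single head vector (illustrative only; it ignores the YaRN-specific scaling applied by self.rope_fn and assumes rotary_dim equals the head dimension D):

import numpy as np

def rotate_half_rope(x, pos, theta=10000.0):
    # x: one head vector of even length D; pos: integer token position.
    D = x.shape[-1]
    half = D // 2
    freqs = pos / theta ** (np.arange(half) / half)  # one angle per (d, d + D/2) pair
    cos = np.cos(np.concatenate([freqs, freqs]))
    sin = np.sin(np.concatenate([freqs, freqs]))
    # Partner element with its sign: -x[d + D/2] for the first half and
    # x[d - D/2] for the second, mirroring the tir.if_then_else branches above.
    rotated = np.concatenate([-x[half:], x[:half]])
    return x * cos + rotated * sin

out = rotate_half_rope(np.arange(8, dtype=np.float64), pos=5)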