Skip to content

Commit 5777938

Browse files
committed
Deepseek V3 fix
1 parent c5a9406 commit 5777938

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 19 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -33,26 +33,26 @@ def fused_moe_kernel(
3333
expert_ids_ptr,
3434
num_tokens_post_padded_ptr,
3535
# Matrix dimensions
36-
N,
37-
K,
38-
EM,
39-
num_valid_tokens,
36+
N: tl.int64,
37+
K: tl.int64,
38+
EM: tl.int64,
39+
num_valid_tokens: tl.int64,
4040
# The stride variables represent how much to increase the ptr by when
4141
# moving by 1 element in a particular dimension. E.g. `stride_am` is
4242
# how much to increase `a_ptr` by to get the element one row down
4343
# (A has M rows).
44-
stride_am,
45-
stride_ak,
46-
stride_be,
47-
stride_bk,
48-
stride_bn,
49-
stride_cm,
50-
stride_cn,
51-
stride_asm,
52-
stride_ask,
53-
stride_bse,
54-
stride_bsk,
55-
stride_bsn,
44+
stride_am: tl.int64,
45+
stride_ak: tl.int64,
46+
stride_be: tl.int64,
47+
stride_bk: tl.int64,
48+
stride_bn: tl.int64,
49+
stride_cm: tl.int64,
50+
stride_cn: tl.int64,
51+
stride_asm: tl.int64,
52+
stride_ask: tl.int64,
53+
stride_bse: tl.int64,
54+
stride_bsk: tl.int64,
55+
stride_bsn: tl.int64,
5656
# Block size for block-wise quantization
5757
group_n: tl.constexpr,
5858
group_k: tl.constexpr,
@@ -114,18 +114,16 @@ def fused_moe_kernel(
114114
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
115115
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
116116
return
117-
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
118-
tl.int64)
117+
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
119118
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
120119
token_mask = offs_token < num_valid_tokens
121120

122-
offs_bn = (pid_n * BLOCK_SIZE_N +
123-
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
121+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
124122
offs_k = tl.arange(0, BLOCK_SIZE_K)
125123
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
126124
offs_k[None, :] * stride_ak)
127125

128-
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
126+
off_experts = tl.load(expert_ids_ptr + pid_m)
129127
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
130128
offs_bn[None, :] * stride_bn)
131129
if use_int8_w8a16:

0 commit comments

Comments (0)