Skip to content

Commit 58ad0a6

Browse files
Varun Sundar Rabindranath (varun-sundar-rabindranath)
authored and committed
[Kernel][Performance] Tweak MoE Batched silu_mul_fp8_quant_deep_gemm kernel (vllm-project#21193)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
1 parent 2f65f4c commit 58ad0a6

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _silu_mul_fp8_quant_deep_gemm(
5555

5656
# Meta ---------------------------------------------------------------
5757
BLOCK: tl.constexpr,
58+
NUM_STAGES: tl.constexpr,
5859
):
5960
G = H // GROUP_SIZE
6061

@@ -73,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm(
7374
cols = cols.to(tl.int64)
7475
mask_h = cols < BLOCK
7576

76-
t = tl.zeros([], tl.int64)
77-
while t < n_tokens:
77+
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
7878
base_i_offset = (e * stride_i_e + t * stride_i_t +
7979
g * GROUP_SIZE * stride_i_h)
8080
base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
@@ -102,8 +102,6 @@ def _silu_mul_fp8_quant_deep_gemm(
102102
tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
103103
tl.store(y_s_ptr + base_ys_offset, y_s)
104104

105-
t += 1
106-
107105

108106
def silu_mul_fp8_quant_deep_gemm(
109107
y: torch.Tensor, # (E, T, 2*H) float32
@@ -180,7 +178,8 @@ def silu_mul_fp8_quant_deep_gemm(
180178
fp8_max,
181179
is_blackwell_deep_gemm_used(),
182180
BLOCK=group_size,
183-
num_warps=4,
181+
NUM_STAGES=8,
182+
num_warps=1,
184183
)
185184

186185
return y_q, y_s

0 commit comments

Comments (0)