
Commit 0d4891c

[Bug] Fix DeepGemm for EP low latency case (#20833)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent f56d299 commit 0d4891c

File tree (1 file changed, +10 -9 lines):

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 10 additions & 9 deletions
@@ -11,7 +11,8 @@
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
+from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
+                                  is_blackwell_deep_gemm_used)
 
 logger = init_logger(__name__)
 
@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     eps: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
 
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
     y = x * y2
 
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(
+        tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
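
The new use_ue8m0 branch rounds each group's scale up to the next power of two. A minimal NumPy sketch of the same arithmetic, assuming the FP8 E4M3 limit of 448 (group_scale is an illustrative name, not part of this file):

import numpy as np

def group_scale(y: np.ndarray, eps: float = 1e-10,
                fp8_max: float = 448.0, use_ue8m0: bool = False) -> float:
    # Same per-group scale logic as the Triton kernel above.
    absmax = max(float(np.max(np.abs(y))), eps)
    scale_raw = absmax / fp8_max
    if use_ue8m0:
        # Round up to the next power of two (exponent-only, UE8M0-style scale).
        return float(2.0 ** np.ceil(np.log2(scale_raw)))
    return scale_raw

y = np.random.randn(128).astype(np.float32)
s = group_scale(y, use_ue8m0=True)   # power-of-two scale
y_q = np.clip(y / s, -448.0, 448.0)  # would then be cast to FP8 E4M3

Power-of-two scales are what the Blackwell DeepGEMM path expects, which is presumably why the flag is wired to is_blackwell_deep_gemm_used() in the next hunk.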
@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
+        is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
         num_warps=4,
     )
@@ -290,14 +295,10 @@ def apply(
         # may lead to better performance.
         expected_m = max_num_tokens
         fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
-                                     out=workspace1,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+                                     workspace1, expert_num_tokens, expected_m)
 
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
-        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
-                                     out=output,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
+                                     expert_num_tokens, expected_m)
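
For intuition, the masked grouped GEMM performed by both calls can be approximated by the following PyTorch reference. This is a sketch under assumptions (bf16 instead of FP8, per-block scales ignored, and the batched [num_experts, max_tokens, ...] layout of the EP low-latency path; masked_grouped_gemm_ref is an illustrative name):

import torch

def masked_grouped_gemm_ref(a: torch.Tensor,        # [E, max_tokens, K]
                            w: torch.Tensor,        # [E, N, K], "nt" layout
                            out: torch.Tensor,      # [E, max_tokens, N]
                            masked_m: torch.Tensor  # [E] valid tokens per expert
                            ) -> None:
    for e in range(a.shape[0]):
        m = int(masked_m[e])
        # Only the first masked_m[e] rows of expert e's slot hold real tokens;
        # rows past that are padding and are left untouched.
        out[e, :m] = a[e, :m] @ w[e].T

Here expert_num_tokens plays the role of masked_m, while expected_m appears to act only as a performance hint for kernel selection, as the comment above suggests.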
