```diff
@@ -11,7 +11,8 @@
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
+from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
+                                  is_blackwell_deep_gemm_used)
 
 logger = init_logger(__name__)
 
@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     eps: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
 
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
     y = x * y2
 
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(
+        tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
```
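The new branch replaces each group's scale with the smallest power of two that is at least `absmax / fp8_max`, which is exactly what a UE8M0 scale (unsigned, 8 exponent bits, 0 mantissa bits) can represent. Below is a minimal PyTorch sketch of the same arithmetic, assuming fp8 e4m3 as the target dtype; the names (`group_scales`, `EPS`, `FP8_MAX`) are illustrative, not taken from the kernel:

```python
import torch

# Illustrative constants; the real kernel receives eps/fp8_min/fp8_max as
# arguments rather than hard-coding them.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
EPS = 1e-10


def group_scales(y: torch.Tensor, use_ue8m0: bool) -> torch.Tensor:
    """Per-row absmax scales, optionally rounded up to a power of two."""
    absmax = y.abs().amax(dim=-1).clamp_min(EPS)
    scale_raw = absmax / FP8_MAX
    if use_ue8m0:
        # exp2(ceil(log2(x))) is the smallest power of two >= x, so the
        # rounded scale never shrinks and y / scale stays in fp8 range.
        return torch.exp2(torch.ceil(torch.log2(scale_raw)))
    return scale_raw


y = torch.randn(4, 128)
print(group_scales(y, use_ue8m0=True))   # powers of two only
print(group_scales(y, use_ue8m0=False))  # raw absmax / FP8_MAX
```

Rounding up rather than to-nearest matters here: the quotient `y / y_s` can only shrink relative to the raw-scale path, so the subsequent clamp to `[fp8_min, fp8_max]` never saturates more than it would have before.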
```diff
@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
+        is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
         num_warps=4,
     )
```
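Because `use_ue8m0` is declared `tl.constexpr`, Triton specializes the kernel on the value passed at this launch site: the power-of-two rounding is compiled in only when `is_blackwell_deep_gemm_used()` reports that the Blackwell DeepGEMM path is active, and on other hardware the raw-scale path compiles with no runtime branch.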
```diff
@@ -290,14 +295,10 @@ def apply(
         # may lead to better performance.
         expected_m = max_num_tokens
         fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
-                                     out=workspace1,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+                                     workspace1, expert_num_tokens, expected_m)
 
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
-        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
-                                     out=output,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
+                                     expert_num_tokens, expected_m)
```