Commit 42d440c

[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent f45a332 · commit 42d440c

6 files changed: +26 −42 lines

tests/kernels/moe/test_deepgemm.py

Lines changed: 4 additions & 3 deletions
@@ -13,9 +13,10 @@
 
 # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8
 
 BLOCK_SIZE = [128, 128]
 
@@ -81,7 +82,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
     """
     tokens_bf16 = torch.randn(
         m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
-    _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
+    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
 
     # expert weight tensors
     w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
tests/kernels/quantization/test_block_fp8.py

Lines changed: 2 additions & 3 deletions
@@ -15,8 +15,7 @@
     w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -117,7 +116,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
-    A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
+    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
 
     As = As_fp8.to(torch.float32)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 7 additions & 6 deletions
@@ -15,9 +15,10 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.utils import has_deep_gemm, round_up
-from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
 logger = init_logger(__name__)
 
@@ -170,10 +171,10 @@ def apply(
         self.activation(activation, act_out, mm1_out.view(-1, N))
 
         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
-                                                     self.block_shape[1],
-                                                     column_major_scales=True,
-                                                     out_q=quant_out)
+        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
+                                                   self.block_shape[1],
+                                                   column_major_scales=True,
+                                                   out_q=quant_out)
 
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)
vllm/model_executor/layers/fused_moe/utils.py

Lines changed: 1 addition & 6 deletions
@@ -15,8 +15,6 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
-                                  per_token_group_cast_to_fp8)
 
 
 @triton.jit
@@ -119,10 +117,7 @@ def _fp8_quantize(
         assert not per_act_token
         assert len(block_shape) == 2
         _, block_k = block_shape[0], block_shape[1]
-        if is_blackwell_deep_gemm_used():
-            A, A_scale = per_token_group_cast_to_fp8(A, block_k)
-        else:
-            A, A_scale = per_token_group_quant_fp8(A, block_k)
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
 
     return A, A_scale

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 12 additions & 3 deletions
@@ -20,6 +20,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 logger = init_logger(__name__)
 
@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
         # Information for float8
         fp8_min,
         fp8_max,
+        use_ue8m0: tl.constexpr,
         # Meta-parameters
         BLOCK: tl.constexpr,
 ):
@@ -285,7 +287,8 @@
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
         # Information for float8
         fp8_min,
         fp8_max,
+        use_ue8m0: tl.constexpr,
         # Meta-parameters
         BLOCK: tl.constexpr,
 ):
@@ -347,7 +351,8 @@
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -373,9 +378,11 @@ def per_token_group_quant_fp8(
         is supported for now.
         column_major_scales: Outputs scales in column major.
         out_q: Optional output tensor. If not provided, function will create.
-    Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
+        scaling factor.
     """
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
+           use_ue8m0=is_blackwell_deep_gemm_used(),
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
@@ -433,6 +441,7 @@
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
+           use_ue8m0=is_blackwell_deep_gemm_used(),
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
vllm/utils/deep_gemm.py

Lines changed: 0 additions & 21 deletions
@@ -49,7 +49,6 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
     _fp8_gemm_nt_impl: Callable[..., Any] | None = None
     _grouped_impl: Callable[..., Any] | None = None
     _grouped_masked_impl: Callable[..., Any] | None = None
-    _per_token_cast_impl: Callable[..., Any] | None = None
     _per_block_cast_impl: Callable[..., Any] | None = None
 else:
     _dg = importlib.import_module("deep_gemm")  # type: ignore
@@ -74,12 +73,9 @@
     try:
         _math_mod = importlib.import_module(
             "deep_gemm.utils.math")  # type: ignore
-        _per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
-                                       None)
         _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                        None)
     except ModuleNotFoundError:
-        _per_token_cast_impl = None
         _per_block_cast_impl = None
 
 
@@ -101,22 +97,6 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     return _grouped_masked_impl(*args, **kwargs)
 
 
-def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
-    """Wrapper for token-wise FP8 quantisation.
-
-    • If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
-    • Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
-    """
-
-    if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
-        assert group_size == 128, "group_size must be 128 for deepgemm"
-        return _per_token_cast_impl(x)
-
-    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-        per_token_group_quant_fp8 as _ptg)
-    return _ptg(x, group_size, *args, **kwargs)
-
-
 def per_block_cast_to_fp8(x, *args, **kwargs):
     if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
         return _per_block_cast_impl(x)
@@ -146,7 +126,6 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     "fp8_gemm_nt",
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
-    "per_token_group_cast_to_fp8",
     "per_block_cast_to_fp8",
     "is_blackwell_deep_gemm_used",
 ]
