
Commit e2de455

[Feature] Integrate SM100 DeepGEMM support (#20087)
1 parent 5b03235 commit e2de455

16 files changed: +397 additions, -114 deletions

benchmarks/kernels/benchmark_moe.py

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ def benchmark_config(
             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+    if use_deep_gemm:
+        # we use the default block shape for deepgemm
+        block_quant_shape = [128, 128]
     if use_fp8_w8a8:
         if block_quant_shape:
             block_n, block_k = block_quant_shape[0], block_quant_shape[1]
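With the default DeepGEMM block shape of [128, 128], block-wise FP8 weights carry one float32 scale per 128x128 tile. A minimal illustrative sketch of the resulting scale-tensor shape (the helper below is not part of the benchmark):

import math

def blockwise_scale_shape(num_experts: int, n: int, k: int,
                          block_shape=(128, 128)) -> tuple[int, int, int]:
    # One scale per (block_n x block_k) tile of each expert's (n, k) weight.
    block_n, block_k = block_shape
    return (num_experts, math.ceil(n / block_n), math.ceil(k / block_k))

# e.g. 8 experts with weights of shape (8, 4096, 7168) -> scales (8, 32, 56)
print(blockwise_scale_shape(8, 4096, 7168))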

tests/kernels/moe/test_block_fp8.py

Lines changed: 8 additions & 8 deletions
@@ -15,13 +15,13 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
-dg_available = False
-try:
-    import deep_gemm
-    dg_available = True
-except ImportError:
-    pass
+dg_available = has_deep_gemm()
+
+if dg_available:
+    from deep_gemm import get_m_alignment_for_contiguous_layout
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
+@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                             monkeypatch):
@@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     torch.manual_seed(seed)
 
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
-
-    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
+    block_m = get_m_alignment_for_contiguous_layout()
     block_size = [block_m, block_m]
     dtype = torch.bfloat16
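The try/except import probe is replaced by vllm.utils.has_deep_gemm(). Its exact body is outside this diff; a minimal sketch of such an availability check, assuming it only needs to know whether the optional package is importable (the real helper may cache the result or perform extra validation):

import importlib.util
from functools import cache

@cache
def has_deep_gemm() -> bool:
    # True when the optional deep_gemm package is importable.
    return importlib.util.find_spec("deep_gemm") is not None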

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,7 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 from .parallel_utils import ProcessGroupInfo, parallel_launch
 from .utils import make_test_weights
@@ -368,6 +369,8 @@ def _test_deepep_deepgemm_moe(
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
+@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+                    reason="Skipping test for Blackwell DeepGEMM")
 def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                 topk: int, world_dp_size: tuple[int, int]):
     """
@@ -423,6 +426,8 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
+@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+                    reason="Skipping test for Blackwell DeepGEMM")
 def test_ll_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],
     num_experts: int,
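Several tests are now gated on is_blackwell_deep_gemm_used() from vllm.utils.deep_gemm, whose implementation is not shown in this diff. A hypothetical sketch, assuming it combines the availability probe with an SM100 (compute capability 10.x) device check, since that is where DeepGEMM requires E8M0 scales:

import torch

from vllm.utils import has_deep_gemm

def is_blackwell_deep_gemm_used() -> bool:
    # Hypothetical sketch: DeepGEMM is installed and the current GPU is
    # Blackwell-class, so the E8M0-scale DeepGEMM path is in play.
    if not has_deep_gemm() or not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major == 10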

tests/kernels/moe/test_deepgemm.py

Lines changed: 8 additions & 47 deletions
@@ -13,48 +13,18 @@
 
 # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
-from vllm.utils import cdiv
+from vllm.utils import has_deep_gemm
+from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
+                                  per_token_group_cast_to_fp8)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
-if has_deep_gemm:
-    import deep_gemm
-    BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
-    BLOCK_SIZE = [BLOCK_M, BLOCK_M]
+BLOCK_SIZE = [128, 128]
 
 requires_deep_gemm = pytest.mark.skipif(
-    not has_deep_gemm,
+    not has_deep_gemm(),
     reason="Requires deep_gemm kernels",
 )
 
 
-def calc_diff(x: torch.Tensor, y: torch.Tensor):
-    x, y = x.double(), y.double()
-    denominator = (x * x + y * y).sum()
-    sim = 2 * (x * y).sum() / denominator
-    return 1 - sim
-
-
-def per_block_cast_to_fp8(
-        x: torch.Tensor,
-        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros(
-        (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
-        dtype=x.dtype,
-        device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
-
-
 def make_block_quant_fp8_weights(
     e: int,
     n: int,
@@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
     """
     tokens_bf16 = torch.randn(
        m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
-    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
+    _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
 
     # expert weight tensors
     w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
@@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
         block_shape=block_size,
         allow_deep_gemm=True,
     )
-
-    base = out_triton.abs().mean()
-    atol = 0.1 * base.clamp(min=1e-2)  # 10% of mean, but not lower than 1e-3
-    rtol = 0.05
-    # ----- Compare -----
-    torch.testing.assert_close(
-        out_deepgemm.to(torch.float32),
-        out_triton.to(torch.float32),
-        rtol=rtol,
-        atol=float(atol),
-    )
+    diff = calc_diff(out_deepgemm, out_triton)
+    assert diff < 0.001, f"Diff exceeded 1%: {diff}"
 
 
 # Note: W1 has shape (E, 2N, K), so N = 512
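The comparison now uses calc_diff from vllm.utils.deep_gemm instead of torch.testing.assert_close with a hand-tuned atol. The helper removed above documents the metric: one minus a cosine-similarity-style ratio computed in double precision. A self-contained usage sketch built from that removed definition:

import torch

def calc_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Mirrors the helper removed from this test:
    # 1 - 2*<x, y> / (|x|^2 + |y|^2), computed in float64.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim

a = torch.randn(64, 64)
b = a + 1e-4 * torch.randn_like(a)
assert calc_diff(a, b) < 0.001  # near-identical tensors give a diff near 0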

tests/kernels/quantization/test_block_fp8.py

Lines changed: 11 additions & 16 deletions
@@ -8,19 +8,15 @@
 import torch
 
 from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
-                                       native_w8a8_block_matmul,
-                                       per_block_cast_to_fp8)
+                                       native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
-
-dg_available = False
-try:
-    import deep_gemm
-    dg_available = True
-except ImportError:
-    pass
+from vllm.utils import has_deep_gemm
+from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
+                                  per_token_group_cast_to_fp8)
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
     itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
-@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
+@pytest.mark.skipif(not has_deep_gemm(),
+                    reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     # only aligned sizes
@@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
-    _, block_k = block_size[0], block_size[1]
-
-    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
+    A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
 
     As = As_fp8.to(torch.float32)
@@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
                                     out_dtype)
 
     # Transpose earlier so that the testing will not trigger transposing kernels
-    As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
+    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
 
     out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
 
     assert As_fp8.shape == (M, (K + 127) //
                             128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
 
-    deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
+    fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
 
     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
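per_block_cast_to_fp8 now also comes from vllm.utils.deep_gemm. The helper removed from tests/kernels/moe/test_deepgemm.py above shows what this kind of block quantization does; a simplified sketch along the same lines (pad to block multiples, one scale per 128 x block_n tile, E4M3 max of 448):

import math
import torch

def per_block_cast_to_fp8_sketch(
        x: torch.Tensor,
        block_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative only; the real helper lives in vllm.utils.deep_gemm.
    assert x.dim() == 2
    m, n = x.shape
    pad_m = math.ceil(m / 128) * 128
    pad_n = math.ceil(n / block_n) * block_n
    x_padded = torch.zeros((pad_m, pad_n), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    # View as (row-blocks, 128, col-blocks, block_n); one amax per block.
    x_view = x_padded.view(-1, 128, pad_n // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), scales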

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 9 additions & 12 deletions
@@ -11,6 +11,7 @@
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
+from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
 
 logger = init_logger(__name__)
 
@@ -271,7 +272,6 @@ def apply(
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
 
-        import deep_gemm as dg
         assert hidden_states.ndim == 3
         assert self.block_shape is not None
 
@@ -289,18 +289,15 @@ def apply(
         # for the M expectation of each batch, correctly setting this value
         # may lead to better performance.
         expected_m = max_num_tokens
-
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
-                                                 (w1, w1_scale),
-                                                 out=workspace1,
-                                                 masked_m=expert_num_tokens,
-                                                 expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
+                                     out=workspace1,
+                                     masked_m=expert_num_tokens,
+                                     expected_m=expected_m)
 
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
-                                                 (w2, w2_scale),
-                                                 out=output,
-                                                 masked_m=expert_num_tokens,
-                                                 expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
+                                     out=output,
+                                     masked_m=expert_num_tokens,
+                                     expected_m=expected_m)
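The masked grouped GEMM now goes through the vllm.utils.deep_gemm wrapper rather than a direct deep_gemm import. For orientation, a plain-PyTorch reference of what the masked call computes, ignoring FP8 scales (this is the semantics, not how the kernel is implemented):

import torch

def reference_masked_grouped_gemm(a: torch.Tensor, w: torch.Tensor,
                                  masked_m: torch.Tensor) -> torch.Tensor:
    # a: (E, max_tokens, K) batched activations, w: (E, N, K) expert weights
    # in NT layout, masked_m[e]: number of valid rows for expert e; rows past
    # that count are padding and stay zero.
    E, max_tokens, K = a.shape
    out = torch.zeros(E, max_tokens, w.size(1), dtype=torch.bfloat16,
                      device=a.device)
    for e in range(E):
        m = int(masked_m[e])
        out[e, :m] = (a[e, :m].float() @ w[e].float().t()).to(torch.bfloat16)
    return out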

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 11 additions & 11 deletions
@@ -14,9 +14,10 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.utils import has_deep_gemm, round_up
+from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
+                                  per_token_group_cast_to_fp8)
 
 logger = init_logger(__name__)
 
@@ -127,7 +128,6 @@ def apply(
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
-        import deep_gemm as dg
         assert self.block_shape is not None
 
         a1q = hidden_states
@@ -164,19 +164,19 @@ def apply(
                                  (M_sum, N // 2))
         mm2_out = _resize_cache(workspace2, (M_sum, K))
 
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
+        m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
+                                         mm1_out, expert_ids)
 
         self.activation(activation, act_out, mm1_out.view(-1, N))
 
         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
-                                                   self.block_shape[1],
-                                                   column_major_scales=True,
-                                                   out_q=quant_out)
+        a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
+                                                     self.block_shape[1],
+                                                     column_major_scales=True,
+                                                     out_q=quant_out)
 
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
+        m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
+                                         mm2_out, expert_ids)
 
         torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
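The contiguous-layout grouped GEMM is wrapped the same way. A plain-PyTorch reference of its semantics, ignoring FP8 scales and assuming expert_ids holds the owning expert index for each row (with padding rows marked by a negative id), which is how the contiguous grouped GEMM is driven here:

import torch

def reference_grouped_gemm_contiguous(a: torch.Tensor, w: torch.Tensor,
                                      expert_ids: torch.Tensor) -> torch.Tensor:
    # a: (M_sum, K) activations already permuted so rows routed to the same
    # expert are contiguous; w: (E, N, K) expert weights in NT layout.
    out = torch.zeros(a.size(0), w.size(1), dtype=torch.bfloat16,
                      device=a.device)
    for e in torch.unique(expert_ids):
        if e < 0:  # padding rows inserted for block alignment
            continue
        rows = expert_ids == e
        out[rows] = (a[rows].float() @ w[int(e)].float().t()).to(torch.bfloat16)
    return out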

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 9 additions & 3 deletions
@@ -34,6 +34,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
@@ -1171,9 +1172,15 @@ def fused_experts(
         allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
+    # However, on B200, we use DeepGemm for all cases becuase they only support
+    # E8M0 scale, which means we requantize the weight and input to the specific
+    # scale. Fallen back to cutlass or triton for some cases would cause
+    # accuracy issue.
     N = w1.size(1)
-    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
-            and _valid_deep_gemm(hidden_states, w1, w2)):
+    should_use_deep_gemm = ((N > 512
+                             and _valid_deep_gemm(hidden_states, w1, w2))
+                            or is_blackwell_deep_gemm_used())
+    if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
@@ -1363,7 +1370,6 @@ def fused_experts_impl(
 
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
-
         qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
             A=curr_hidden_states,
             A_scale=a1_scale,
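The new dispatch keeps the small-N Triton fallback on pre-Blackwell GPUs but always routes FP8 MoE through DeepGEMM on SM100, because requantizing weights and activations to E8M0 scales and then falling back to CUTLASS or Triton would cost accuracy. Restating the decision from the hunk above as a standalone predicate (shapes_ok stands in for _valid_deep_gemm, on_blackwell for is_blackwell_deep_gemm_used):

def should_dispatch_to_deep_gemm(allow_deep_gemm: bool, use_fp8_w8a8: bool,
                                 N: int, shapes_ok: bool,
                                 on_blackwell: bool) -> bool:
    # Mirrors the condition in fused_experts() after this change.
    should_use_deep_gemm = (N > 512 and shapes_ok) or on_blackwell
    return allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm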

vllm/model_executor/layers/fused_moe/prepare_finalize.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@ def prepare(
             assert topk == 1, \
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))
-
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1, a1_scale, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 5 additions & 2 deletions
@@ -9,6 +9,7 @@
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 
 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -102,7 +103,8 @@ def workspace_shapes(
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
        # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
+        if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K)
+                                     or is_blackwell_deep_gemm_used()):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
                 a, aq, M, N, K, topk, global_num_experts, local_num_experts)
@@ -132,7 +134,8 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         use_deep_gemm = (self.allow_deep_gemm
-                         and _valid_deep_gemm(hidden_states, w1, w2))
+                         and (_valid_deep_gemm(hidden_states, w1, w2)
+                              or is_blackwell_deep_gemm_used()))
 
         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
         assert experts is not None
