Commit 337320f

review comments, reduce test combinations, cleanup test code, etc.
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 1c89788 commit 337320f

File tree: 8 files changed (+134 additions, -41 deletions)

tests/kernels/moe/test_batched_moe.py

Lines changed: 24 additions & 3 deletions
@@ -19,6 +19,29 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.platforms import current_platform

+MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 128, 2048),
+    (1, 512, 512),
+    (1, 1024, 128),
+    (1, 1024, 2048),
+    (32, 128, 128),
+    (32, 512, 512),
+    (32, 1024, 2048),
+    (45, 128, 128),
+    (45, 128, 2048),
+    (45, 512, 512),
+    (45, 1024, 128),
+    (45, 1024, 2048),
+    (64, 128, 128),
+    (64, 512, 512),
+    (64, 1024, 2048),
+    (222, 128, 128),
+    (222, 128, 2048),
+    (222, 512, 512),
+    (222, 1024, 128),
+    (222, 1024, 2048),
+]
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]

@@ -182,9 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


-@pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
-@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
+@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.bfloat16])

tests/kernels/moe/test_block_fp8.py

Lines changed: 66 additions & 11 deletions
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import itertools
-
 import pytest
 import torch

@@ -37,10 +35,62 @@
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
-M = [1, 83, 128, 2048, 8192]
-M_dg = [128, 192, 1335, 2048]
-N = [128, 256, 1024, 4608]
-K = [256, 512, 7168]
+MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 512, 512),
+    (1, 128, 7168),
+    (1, 1024, 7168),
+    (1, 4608, 128),
+    (1, 4608, 512),
+    (1, 4608, 7168),
+    (83, 128, 128),
+    (83, 512, 512),
+    (83, 1024, 7168),
+    (83, 4608, 512),
+    (83, 4608, 7168),
+    (128, 128, 128),
+    (128, 512, 512),
+    (128, 1024, 7168),
+    (128, 4608, 512),
+    (128, 4608, 7168),
+    (2048, 128, 128),
+    (2048, 1024, 7168),
+    (2048, 4608, 512),
+    (2048, 4608, 7168),
+    (8192, 128, 128),
+    (8192, 512, 512),
+    (8192, 128, 7168),
+    (8192, 1024, 7168),
+    (8192, 4608, 512),
+    (8192, 4608, 7168),
+]
+
+MNK_FACTORS_DG = [
+    (128, 128, 128),
+    (128, 512, 512),
+    (128, 128, 7168),
+    (128, 1024, 7168),
+    (128, 4608, 128),
+    (128, 4608, 512),
+    (128, 4608, 7168),
+    (192, 128, 128),
+    (192, 512, 512),
+    (192, 1024, 7168),
+    (192, 4608, 512),
+    (192, 4608, 7168),
+    (1335, 128, 128),
+    (1335, 1024, 7168),
+    (1335, 4608, 512),
+    (1335, 4608, 7168),
+    (2048, 128, 128),
+    (2048, 512, 512),
+    (2048, 128, 7168),
+    (2048, 1024, 7168),
+    (2048, 4608, 128),
+    (2048, 4608, 512),
+    (2048, 4608, 7168),
+]
+
 BLOCK_SIZE = [[128, 128]]
 E = [2, 8, 16]  # [128, 256]
 TOP_KS = [1, 2, 6]

@@ -92,9 +142,12 @@ def setup_cuda():
     torch.set_default_device("cuda")


-@pytest.mark.parametrize(
-    "M,N,K,E,topk,block_size,dtype,seed",
-    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
+@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
+@pytest.mark.parametrize("E", E)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZE)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
                                   monkeypatch):

@@ -166,8 +219,10 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)


-@pytest.mark.parametrize("M,N,K,E,topk,seed",
-                         itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
+@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
+@pytest.mark.parametrize("E", E)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
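For scale: the old itertools.product parametrization swept M (5 values) × N (4) × K (3) = 60 shapes for the fused-MoE test and M_dg (4) × N (4) × K (3) = 48 for the DeepGemm variant, whereas the curated lists cover 27 and 23 hand-picked shapes respectively, before the remaining E/topk/block_size/dtype/seed axes are crossed in by the stacked decorators. Running pytest --collect-only -q on the file is a quick way to confirm the generated case count.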

tests/kernels/moe/test_block_int8.py

Lines changed: 38 additions & 8 deletions
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import itertools
-
 import pytest
 import torch

@@ -23,9 +21,38 @@
 vllm_config.scheduler_config.max_model_len = 8192

 DTYPES = [torch.half, torch.bfloat16]
-M = [1, 33, 64, 222]
-N = [128, 1024]
-K = [256, 4096]
+
+MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 512, 512),
+    (1, 128, 7168),
+    (1, 1024, 7168),
+    (1, 4096, 128),
+    (1, 4096, 512),
+    (1, 4096, 7168),
+    (33, 128, 128),
+    (33, 512, 512),
+    (33, 128, 7168),
+    (33, 1024, 7168),
+    (33, 4096, 128),
+    (33, 4096, 512),
+    (33, 4096, 7168),
+    (128, 128, 128),
+    (128, 512, 512),
+    (128, 1024, 7168),
+    (128, 4096, 512),
+    (128, 4096, 7168),
+    (222, 128, 128),
+    (222, 512, 512),
+    (222, 1024, 7168),
+    (222, 4096, 512),
+    (222, 4096, 7168),
+    (2048, 128, 128),
+    (2048, 1024, 7168),
+    (2048, 4096, 512),
+    (2048, 4096, 7168),
+]
+
 E = [8, 24]
 TOP_KS = [2, 6]
 # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]

@@ -76,9 +103,12 @@ def setup_cuda():
     torch.set_default_device("cuda")


-@pytest.mark.parametrize(
-    "M, N, K, E, topk, block_size, dtype, seed",
-    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
+@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
+@pytest.mark.parametrize("E", E)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZE)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     """Tests the fused_moe kernel with W8A8 INT8 block quantization against a

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 2 additions & 12 deletions
@@ -97,18 +97,8 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
     n_b_scales = 2 * n if per_out_channel else 1
     k_b_scales = k if per_out_channel else 1
     # Get the right scale for tests.
-    if False:
-        _, a_scale = ops.scaled_fp8_quant(
-            moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
-        a_q, _ = ops.scaled_fp8_quant(
-            moe_tensors_fp16.a,
-            a_scale,
-            use_per_token_if_dynamic=per_act_token)
-    else:
-        a_q, a_scale = ops.scaled_fp8_quant(
-            moe_tensors_fp16.a,
-            None,
-            use_per_token_if_dynamic=per_act_token)
+    a_q, a_scale = ops.scaled_fp8_quant(
+        moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)

     w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
     w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
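The deleted if False: arm was dead code that first derived a scale dynamically and then re-quantized with it as a static scale; the surviving call passes scale=None, so scaled_fp8_quant computes the scale itself (per token when use_per_token_if_dynamic is set). A rough usage sketch, assuming a CUDA device and vLLM installed (values illustrative):

    import torch
    from vllm import _custom_ops as ops

    a = torch.randn(64, 7168, device="cuda", dtype=torch.half)

    # scale=None -> dynamic quantization: the kernel derives the scale,
    # per token (row) here because use_per_token_if_dynamic=True.
    a_q, a_scale = ops.scaled_fp8_quant(a, None, use_per_token_if_dynamic=True)

    # a_q is an fp8 tensor; a_scale holds one scale per token row.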

tests/kernels/quant_utils.py

Lines changed: 2 additions & 3 deletions
@@ -241,7 +241,6 @@ def per_block_cast_to_fp8(
     return x_scaled_sub, scales


-# TODO: fix this
 def per_block_cast_to_int8(
     x: torch.Tensor,
     block_shape: list[int] = DEFAULT_BLOCK_SHAPE,

@@ -255,9 +254,9 @@ def per_block_cast_to_int8(
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.int8)
+    x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8)
     x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+    scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2))
     return x_scaled_sub, scales
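The 448.0 constant was carried over from the fp8 helper (448 is the float8_e4m3 max); the int8 variant now uses 256.0. Whichever constant S is chosen, the quant/dequant pair stays self-consistent, since values are scaled by S / amax before the cast and the returned scales are amax / S. As a worked identity:

    q = x * (S / amax),  scales = amax / S,  so  q * scales ≈ x  (up to rounding and clipping at the integer range)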

vllm/_custom_ops.py

Lines changed: 1 addition & 2 deletions
@@ -1275,8 +1275,7 @@ def scaled_fp8_quant(
         torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 and num_token_padding
-                is None), f"{scale.shape} {num_token_padding}"
+        assert scale.numel() == 1, f"{scale.shape}"
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)

     return output, scale
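With the relaxed assert, the static branch only requires that the caller-supplied scale be a single element. A minimal sketch of the static-scale call (assumes CUDA and vLLM; values illustrative):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 4096, device="cuda", dtype=torch.half)
    scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)  # numel() == 1

    # Passing an explicit scale routes to static_scaled_fp8_quant.
    x_q, x_scale = ops.scaled_fp8_quant(x, scale)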

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

     # DeepEP low-latency kernels are compiled only for certain
     # specific hidden sizes.
-    SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
+    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168]

     def __init__(self,
                  buffer: deep_ep.Buffer,
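Adding 2048 lets models with that hidden size take the DeepEP low-latency path. A hypothetical guard of the kind a caller might use (illustrative only, not code from this commit; assumes the deep_ep package is installed so the module imports):

    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
        DeepEPLLPrepareAndFinalize)

    hidden_size = 2048
    if hidden_size not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES:
        raise ValueError(
            f"DeepEP low-latency kernels do not support hidden_size={hidden_size}; "
            f"supported: {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}")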

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 0 additions & 1 deletion
@@ -1038,7 +1038,6 @@ def inplace_fused_experts_fake(
     pass


-# TODO: get rid of these? replace with modular op?
 direct_register_custom_op(
     op_name="inplace_fused_experts",
     op_func=inplace_fused_experts,
