Commit cb5c853

config stuff + add more tests
Signed-off-by: Bill Nell <bnell@redhat.com>
Parent: 6e85ea1

33 files changed: +2111, -1197 lines

tests/kernels/moe/deepep_utils.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -138,9 +138,7 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
         rank=pgi.rank,
         dp_size=dp_size,
         rank_expert_offset=pgi.rank *
-        ht_args.num_local_experts,
-        quant_dtype=q_dtype,
-        block_shape=block_shape)
+        ht_args.num_local_experts)
 
 
 def make_deepep_ll_a2a(pg: ProcessGroup,
@@ -168,8 +166,6 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
         world_size=pgi.world_size,
         dp_size=dp_size,
         max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
-        quant_dtype=q_dtype,
-        block_shape=block_shape,
         use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
     )
 
```
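The removed `quant_dtype`/`block_shape` arguments mean the DeepEP all-to-all test helpers no longer receive quantization details at construction time. As a rough, hypothetical sketch of the pattern this points toward (the `QuantArgs` name below is an illustration for this note, not an API from this commit), quantization settings can be bundled into one config object handed to whichever component actually quantizes or dequantizes:

```python
from dataclasses import dataclass
from typing import Optional

import torch


# Hypothetical illustration only: group the quantization parameters that were
# previously threaded through the a2a helpers into a single config object.
@dataclass
class QuantArgs:
    quant_dtype: Optional[torch.dtype] = None   # e.g. torch.float8_e4m3fn
    block_shape: Optional[list[int]] = None     # e.g. [128, 128] for block quant
    per_act_token_quant: bool = False


# Callers construct the all-to-all buffers without quantization details and
# pass the quantization config separately.
quant_args = QuantArgs(quant_dtype=torch.float8_e4m3fn, block_shape=[128, 128])
```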

tests/kernels/moe/test_batched_moe.py

Lines changed: 182 additions & 26 deletions
```diff
@@ -2,18 +2,38 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
+from typing import Optional
 
 import pytest
 import torch
 import triton.language as tl
 
+from tests.kernels.moe.utils import (
+    batched_moe,
+    make_test_weights,
+    make_quantized_test_activations,
+    torch_moe2,
+    triton_moe)
+from tests.kernels.quant_utils import native_w8a8_block_matmul
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     invoke_moe_batched_triton_kernel)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.platforms import current_platform
+
+NUM_EXPERTS = [8, 64]
+TOP_KS = [1, 2, 6]
+
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
 
 
 @dataclass
 class BatchedMMConfig:
-    dtype: torch.dtype
+    in_dtype: torch.dtype
+    quant_dtype: Optional[torch.dtype]
+    out_dtype: torch.dtype
     num_experts: int
     max_tokens_per_expert: int
     K: int
@@ -32,84 +52,220 @@ def make_tensors(config: BatchedMMConfig):
         A = torch.randn(
             (config.num_experts, config.max_tokens_per_expert, config.K),
             device="cuda",
-            dtype=config.dtype) / 10
+            dtype=config.in_dtype) / 10
         B = torch.randn((config.num_experts, config.N, config.K),
                         device="cuda",
-                        dtype=config.dtype)
+                        dtype=config.in_dtype)
         C = torch.zeros(
             (config.num_experts, config.max_tokens_per_expert, config.N),
             device="cuda",
-            dtype=config.dtype)
+            dtype=config.out_dtype)
+
         num_expert_tokens = torch.randint(low=0,
                                           high=config.max_tokens_per_expert,
                                           size=(config.num_experts, ),
                                           device="cuda",
                                           dtype=torch.int32)
-        return BatchedMMTensors(A, B, C, num_expert_tokens)
 
 
-def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-             num_expert_tokens: torch.Tensor) -> torch.Tensor:
 
+        return BatchedMMTensors(A, B, C, num_expert_tokens)
+
+
+def ref_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    num_expert_tokens: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    block_shape: Optional[list[int]],
+) -> torch.Tensor:
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
     num_experts = num_expert_tokens.size(0)
 
+    f32 = torch.float32
+    bf16 = torch.bfloat16
+
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
-        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
+        if A.dtype.itemsize == 1 and block_shape is not None:
+            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
+                                           block_shape, C.dtype)
+            C[e, :num_tokens, :] = tmp[:num_tokens, :]
+        elif A.dtype.itemsize == 1 and block_shape is None:
+            C[e, :num_tokens, :] = (
+                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16)
+                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16))
+        else:
+            assert A_scale is None
+            assert B_scale is None
+            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
 
     return C
 
 
-@pytest.mark.parametrize("num_experts", [16, 32])
+@pytest.mark.parametrize("num_experts", [8, 16, 32])
 @pytest.mark.parametrize("max_tokens_per_expert",
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("block_shape", [None])
+@pytest.mark.parametrize("per_act_token_quant", [False])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
-                    N: int, dtype: torch.dtype):
+                    N: int, dtype: torch.dtype, block_shape: Optional[list[int]],
+                    per_act_token_quant: bool):
+    current_platform.seed_everything(7)
 
-    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
-    tensors = BatchedMMTensors.make_tensors(config)
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
-    test_output = tensors.C
-    ref_output = test_output.clone()
+    if block_shape is not None and not use_fp8_w8a8:
+        pytest.skip("Don't test blocking for non-quantized types.")
+
+    if dtype.itemsize == 1:
+        act_dtype = torch.bfloat16
+        quant_dtype = dtype
+    else:
+        act_dtype = dtype
+        quant_dtype = None
+
+    num_expert_tokens = torch.randint(low=0,
+                                      high=max_tokens_per_expert,
+                                      size=(num_experts, ),
+                                      device="cuda",
+                                      dtype=torch.int32)
+
+    A, A_q, A_scale = make_quantized_test_activations(
+        num_experts,
+        max_tokens_per_expert,
+        K,
+        in_dtype=act_dtype,
+        quant_dtype=quant_dtype,
+        block_shape=block_shape,
+        per_act_token_quant=per_act_token_quant
+    )
+
+    B, B_q, B_scale, _, _, _ = make_test_weights(
+        num_experts,
+        N // 2,
+        K,
+        quant_dtype=dtype,
+        block_shape=block_shape,
+    )
+
+    out_shape = (num_experts, max_tokens_per_expert, N)
+    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
+    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
+    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
 
     compute_tl_dtype = {
         torch.float16: tl.float16,
         torch.bfloat16: tl.bfloat16,
         torch.float32: tl.float32
     }[test_output.dtype]
+
     invoke_moe_batched_triton_kernel(
-        tensors.A,
-        tensors.B,
+        A_q,
+        B_q,
         test_output,
-        tensors.num_expert_tokens,
+        num_expert_tokens,
         compute_tl_dtype,
         # Quantization data
-        None,
-        None,
+        A_scale,
+        B_scale,
         None,
         # Quantization schemes
-        False,
+        use_fp8_w8a8,
         False,
         False,
         config={
             "BLOCK_SIZE_M": 16,
             "BLOCK_SIZE_N": 16,
             "BLOCK_SIZE_K": 16
-        })
+        },
+        block_shape=block_shape,
+    )
+
+    ref_output = ref_impl(
+        A,
+        B,
+        ref_output,
+        num_expert_tokens,
+        None,
+        None,
+        None,
+    )
 
-    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
-                          tensors.num_expert_tokens)
+    q_ref_output = ref_impl(A_q, B_q, q_ref_output, num_expert_tokens, A_scale,
+                            B_scale, block_shape)
 
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]
 
-    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
+@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("per_act_token_quant", [False])
+@pytest.mark.parametrize("block_shape", [None])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    per_act_token_quant: bool,
+    block_shape: Optional[list[int]],
+):
+    current_platform.seed_everything(7)
+
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
+    quant_type = torch.float8_e4m3fn if use_fp8_w8a8 else None
+
+    if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
+        pytest.skip("Skip quantization test for non-quantized type")
+
+    if per_act_token_quant and block_shape is not None or topk > e:
+        pytest.skip("Skip illegal quantization test")
+
+    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
+    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e, n, k, block_shape=block_shape, quant_dtype=dtype)
+
+    torch.set_printoptions(profile="full")
+
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
+        triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                   w2_s, quant_type, per_act_token_quant,
+                                   block_shape)
+
+    torch.testing.assert_close(triton_output,
+                               baseline_output,
+                               atol=2e-2,
+                               rtol=2e-2)
+
+    torch.testing.assert_close(triton_output,
+                               batched_output,
+                               atol=2e-2,
+                               rtol=2e-2)
```
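The per-tensor fp8 branch added to `ref_impl` boils down to dequantize-then-matmul per expert: scale the quantized operands back up in float32, then multiply. A minimal, self-contained sketch of that idea (float32 stand-ins instead of real fp8 tensors so it runs anywhere; shapes follow the test, with activations `(tokens, K)` and weights `(N, K)`):

```python
import torch


def dequant_matmul(a_q: torch.Tensor, b_q: torch.Tensor,
                   a_scale: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
    """Reference dequantize-then-matmul for one expert:
    C = (a_q * a_scale) @ (b_q * b_scale)^T, computed in float32."""
    a = a_q.to(torch.float32) * a_scale
    b = b_q.to(torch.float32) * b_scale
    return a @ b.transpose(0, 1)


# Stand-ins for quantized activations/weights and their per-tensor scales.
a_q = torch.randn(4, 8)   # (tokens, K)
b_q = torch.randn(16, 8)  # (N, K)
out = dequant_matmul(a_q, b_q, torch.tensor(0.5), torch.tensor(0.25))
print(out.shape)          # torch.Size([4, 16])
```

The block-quantized branch does the same thing block by block via `native_w8a8_block_matmul`, and the tests then compare the Triton kernel output against this reference within the listed tolerances.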
