
Commit f851058

tests + fix
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 9cfebf5 commit f851058

2 files changed: +232 −85 lines

tests/kernels/moe/test_batched_moe.py

Lines changed: 182 additions & 5 deletions
@@ -7,8 +7,30 @@
 import triton.language as tl
 from typing import Optional
 
+import vllm._custom_ops as ops
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    invoke_moe_batched_triton_kernel)
+    invoke_moe_batched_triton_kernel,
+    BatchedExperts,
+    BatchedPrepareAndFinalize,
+    BatchedTritonExperts)
+from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
+                                                            get_default_config)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.platforms import current_platform
+from vllm.utils import round_up
+
+
+NUM_EXPERTS = [8, 64]
+TOP_KS = [1, 2, 6]
+
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
 
 
 @dataclass
@@ -141,14 +163,13 @@ def ref_impl(
                 B[e].transpose(0, 1),
                 A_scale,
                 B_scale,
-                [1,1])#block_shape)
+                block_shape)
         else:
-            import vllm._custom_ops as ops
             tmp = ops.cutlass_scaled_mm(A[e, :, :],
                                         B[e].transpose(0, 1),
                                         A_scale,
                                         B_scale,
-                                        C.dtype)
+                                        torch.bfloat16)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
     else:
         C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
@@ -194,8 +215,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     #print(f"tensors.B {tensors.B.shape}")
 
     if use_fp8_w8a8:
-        #A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
+        #A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device)
         #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
+        #quant_block_shape = [N, K]
         A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
         B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
         quant_block_shape = [1, 1]
@@ -251,3 +273,158 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
+
+
+def batched_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    max_num_tokens = round_up(a.shape[0], 64) # ?
+    fused_experts = FusedMoEModularKernel(
+        BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0, use_fp8_w8a8=use_fp8_w8a8,
+                                  block_shape=block_shape),
+        BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1,
+                             use_fp8_w8a8=use_fp8_w8a8,
+                             block_shape=block_shape))
+
+    return fused_experts(a,
+                         w1,
+                         w2,
+                         topk_weight,
+                         topk_ids,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale)
+
+
+# Note: same as torch_moe but with fused_topk factored out.
+def torch_moe2(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+    if use_fp8_w8a8:
+        a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
+        #print(f"a_scale {a_scale.shape}")
+    else:
+        a_scale = None
+
+    out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device)
+    num_experts = w1.shape[0]
+    for i in range(num_experts):
+        mask = (topk_ids == i).view(-1)
+        if mask.sum():
+            if not use_fp8_w8a8:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+            else:
+                #tmp1 = ops.cutlass_scaled_mm(a[mask],
+                #                             w1[i].transpose(0, 1),
+                #                             a_scale[mask],
+                #                             w1_scale[i],
+                #                             torch.bfloat16)
+                tmp1 = native_w8a8_block_matmul(a[mask],
+                                                w1[i],
+                                                a_scale[mask],
+                                                w1_scale[i],
+                                                block_shape,
+                                                torch.bfloat16)
+                tmp2 = SiluAndMul()(tmp1)
+                tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
+
+                # out[mask] = ops.cutlass_scaled_mm(tmp2,
+                #                                   w2[i].transpose(0, 1),
+                #                                   b_scale,
+                #                                   w2_scale[i],
+                #                                   torch.bfloat16)
+                out[mask] = native_w8a8_block_matmul(tmp2,
+                                                     w2[i],
+                                                     b_scale,
+                                                     w2_scale[i],
+                                                     block_shape,
+                                                     torch.bfloat16)
+
+    return (out.view(M, -1, w2.shape[1]) *
+            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+@pytest.mark.parametrize("m", [1, 33, 64, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    current_platform.seed_everything(7)
+    block_shape = [128, 128]
+
+    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
+    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+
+    use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+
+    if use_fp8_w8a8:
+        block_n, block_k = block_shape[0], block_shape[1]
+        n_tiles_w1 = (2 * n + block_n - 1) // block_n
+        n_tiles_w2 = (k + block_n - 1) // block_n
+        k_tiles_w1 = (k + block_k - 1) // block_k
+        k_tiles_w2 = (n + block_k - 1) // block_k
+
+        finfo = torch.finfo(dtype)
+        fp8_min = finfo.min
+        fp8_max = finfo.max
+
+        w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
+        w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
+
+        factor_for_scale = 1e-2
+        w1_s = torch.rand(
+            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale
+        w2_s = torch.rand(
+            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale
+    else:
+        w1_s = None
+        w2_s = None
+
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
+        # batched_output = batched_moe(a,
+        #                              w1.to(torch.bfloat16),
+        #                              w2.to(torch.bfloat16),
+        #                              topk_weight, topk_ids,
+        #                              w1_s, w2_s, False,
+        #                              block_shape)
+
+    torch.testing.assert_close(baseline_output,
+                               batched_output,
+                               atol=2e-2,
+                               rtol=0)
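
As a rough usage sketch (not part of this commit), the new parametrized test added above can be run on its own with pytest; the file path comes from the diff header and the -k filter selects the new test function. It assumes vLLM is installed and a CUDA device is available, since the test allocates its tensors on "cuda".

    # Sketch only: invoke the new batched-experts test from this commit.
    import pytest

    if __name__ == "__main__":
        raise SystemExit(
            pytest.main([
                "-q",
                "tests/kernels/moe/test_batched_moe.py",
                "-k", "test_fused_moe_batched_experts",
            ]))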

0 commit comments