
Commit 909f234

stuff
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent f851058 commit 909f234

File tree

3 files changed, 111 insertions(+), 126 deletions(-)


tests/kernels/moe/test_batched_moe.py

Lines changed: 46 additions & 59 deletions
@@ -1,30 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
+from typing import Optional
 
 import pytest
 import torch
 import triton.language as tl
-from typing import Optional
 
 import vllm._custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    invoke_moe_batched_triton_kernel,
-    BatchedExperts,
-    BatchedPrepareAndFinalize,
-    BatchedTritonExperts)
-from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
-                                                            get_default_config)
+    BatchedPrepareAndFinalize, BatchedTritonExperts,
+    invoke_moe_batched_triton_kernel)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
-
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]

@@ -80,10 +76,12 @@ def make_tensors(config: BatchedMMConfig):
         return BatchedMMTensors(A, B, C, num_expert_tokens)
 
 
-def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
-                             As: torch.Tensor, Bs: torch.Tensor,
+def native_w8a8_block_matmul(A: torch.Tensor,
+                             B: torch.Tensor,
+                             As: torch.Tensor,
+                             Bs: torch.Tensor,
                              block_size,
-                             output_dtype = torch.bfloat16):
+                             output_dtype=torch.bfloat16):
     """This function performs matrix multiplication with block-wise
     quantization using native torch.
     It is agnostic to the input data type and can be used for both int8 and
@@ -160,16 +158,11 @@ def ref_impl(
         if A.dtype == torch.torch.float8_e4m3fn:
             if False:
                 tmp = native_w8a8_block_matmul(A[e, :, :],
-                                               B[e].transpose(0, 1),
-                                               A_scale,
-                                               B_scale,
-                                               block_shape)
+                                               B[e].transpose(0, 1), A_scale,
+                                               B_scale, block_shape)
             else:
-                tmp = ops.cutlass_scaled_mm(A[e, :, :],
-                                            B[e].transpose(0, 1),
-                                            A_scale,
-                                            B_scale,
-                                            torch.bfloat16)
+                tmp = ops.cutlass_scaled_mm(A[e, :, :], B[e].transpose(0, 1),
+                                            A_scale, B_scale, torch.bfloat16)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         else:
             C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
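Note on the reference path above: native_w8a8_block_matmul conceptually dequantizes each operand with its per-block scales and then performs an ordinary matmul, while the other branch offloads the scaled GEMM to ops.cutlass_scaled_mm. A minimal illustrative sketch of the dequantize-then-matmul idea (not the test's helper itself; the scale layouts As ~ [M, K // block_k] and Bs ~ [N // block_n, K // block_k] are assumptions inferred from the shapes used elsewhere in this file):

import torch

def dequant_block_matmul(A, B, As, Bs, block_size, out_dtype=torch.bfloat16):
    # Expand the per-block scales to full [M, K] / [N, K] grids, dequantize
    # both operands to float32, then run a plain matmul (C = A @ B^T).
    block_n, block_k = block_size
    K, N = A.shape[1], B.shape[0]
    A_deq = A.to(torch.float32) * As.repeat_interleave(block_k, dim=1)[:, :K]
    Bs_full = Bs.repeat_interleave(block_n, dim=0)[:N, :]
    Bs_full = Bs_full.repeat_interleave(block_k, dim=1)[:, :K]
    B_deq = B.to(torch.float32) * Bs_full
    return (A_deq @ B_deq.t()).to(out_dtype)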
@@ -195,7 +188,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     in_dtype = dtype
     out_dtype = dtype
 
-    config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
+    config = BatchedMMConfig(in_dtype, out_dtype, num_experts,
+                             max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)
 
     test_output = tensors.C
@@ -209,7 +203,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     }[test_output.dtype]
 
     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
-    block_shape = [16, 16, 32] # 16 for k if not fp8
+    block_shape = [16, 16, 32]  # 16 for k if not fp8
 
     #print(f"tensors.A {tensors.A.shape}")
     #print(f"tensors.B {tensors.B.shape}")
@@ -250,19 +244,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     ref_output = ref_output.to(dtype=out_dtype)
     ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
-                          tensors.B.to(dtype=out_dtype),
-                          ref_output,
-                          tensors.num_expert_tokens,
-                          A_scale,
-                          B_scale,
+                          tensors.B.to(dtype=out_dtype), ref_output,
+                          tensors.num_expert_tokens, A_scale, B_scale,
                           block_shape[-2:])
 
-    ref_output2 = ref_impl(tensors.A,
-                           tensors.B,
-                           ref_output2,
-                           tensors.num_expert_tokens,
-                           A_scale,
-                           B_scale,
+    ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2,
+                           tensors.num_expert_tokens, A_scale, B_scale,
                            block_shape[-2:])
 
     rtol, atol = {
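The tail of this hunk (the rtol, atol = { line) selects dtype-dependent tolerances before comparing the Triton output with the two references. A self-contained sketch of that comparison pattern; the tolerance values below are illustrative, not the test's own:

import torch

# Illustrative only: pick tolerances per dtype, then compare elementwise,
# mirroring how the test checks the kernel output against its references.
out = torch.randn(4, 8, dtype=torch.bfloat16)
ref = out.clone()
rtol, atol = {
    torch.float16: (1e-2, 1e-2),
    torch.bfloat16: (1e-2, 1e-2),
}.get(out.dtype, (1e-2, 1e-2))
torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)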
@@ -286,11 +273,17 @@ def batched_moe(
     use_fp8_w8a8: bool = False,
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64) # ?
+    max_num_tokens = round_up(a.shape[0], 64)  # ?
     fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0, use_fp8_w8a8=use_fp8_w8a8,
+        BatchedPrepareAndFinalize(max_num_tokens,
+                                  world_size=1,
+                                  dp_size=1,
+                                  rank=0,
+                                  use_fp8_w8a8=use_fp8_w8a8,
                                   block_shape=block_shape),
-        BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1,
+        BatchedTritonExperts(max_num_tokens=max_num_tokens,
+                             dp_size=1,
+                             world_size=1,
                              use_fp8_w8a8=use_fp8_w8a8,
                              block_shape=block_shape))

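For context on the composition just above: FusedMoEModularKernel pairs a prepare/finalize stage (here BatchedPrepareAndFinalize, run with single-rank defaults) with an experts stage (BatchedTritonExperts) that executes the batched Triton GEMMs. The call site of the composed kernel sits outside this hunk; the sketch below shows how batched_moe presumably invokes it, and the argument names are assumptions for illustration only:

# Assumed invocation (not shown in this hunk): the composed kernel is applied
# to the activations, expert weights and the top-k routing from fused_topk.
# Keyword names here are hypothetical.
return fused_experts(a, w1, w2, topk_weight, topk_ids,
                     w1_scale=w1_scale, w2_scale=w2_scale)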
@@ -322,11 +315,13 @@ def torch_moe2(
 
     if use_fp8_w8a8:
         a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
-        #print(f"a_scale {a_scale.shape}")
     else:
         a_scale = None
 
-    out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device)
+    out = torch.zeros(M * topk,
+                      w2.shape[1],
+                      dtype=torch.bfloat16,
+                      device=a.device)
     num_experts = w1.shape[0]
     for i in range(num_experts):
         mask = (topk_ids == i).view(-1)
@@ -341,11 +336,8 @@ def torch_moe2(
             # a_scale[mask],
             # w1_scale[i],
             # torch.bfloat16)
-            tmp1 = native_w8a8_block_matmul(a[mask],
-                                            w1[i],
-                                            a_scale[mask],
-                                            w1_scale[i],
-                                            block_shape,
+            tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
+                                            w1_scale[i], block_shape,
                                             torch.bfloat16)
             tmp2 = SiluAndMul()(tmp1)
             tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
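The fp8 branch above is easier to follow next to its unquantized counterpart: mask the tokens routed to expert i, apply the first projection plus SiluAndMul, then the second projection, and finally weight each expert's contribution by its routing weight. A minimal dense sketch of that reference computation (an illustration using the same tensor layout as this file, not the file's exact non-fp8 branch):

import torch
from vllm.model_executor.layers.activation import SiluAndMul

def dense_moe_reference(a, w1, w2, topk_weight, topk_ids):
    # a: [M, K], w1: [E, 2N, K], w2: [E, K, N], top-k routing per token.
    M, topk = topk_ids.shape
    a_rep = a.repeat_interleave(topk, dim=0)  # [M * topk, K]
    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    for i in range(w1.shape[0]):
        mask = (topk_ids == i).view(-1)  # tokens routed to expert i
        if mask.any():
            tmp = SiluAndMul()(a_rep[mask] @ w1[i].t())  # gate/up proj + act
            out[mask] = tmp @ w2[i].t()  # down projection
    # Weight each expert's output by its routing weight and sum over top-k.
    return (out.view(M, -1, w2.shape[1]) *
            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)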
@@ -355,11 +347,8 @@ def torch_moe2(
             # b_scale,
             # w2_scale[i],
             # torch.bfloat16)
-            out[mask] = native_w8a8_block_matmul(tmp2,
-                                                 w2[i],
-                                                 b_scale,
-                                                 w2_scale[i],
-                                                 block_shape,
+            out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
+                                                 w2_scale[i], block_shape,
                                                  torch.bfloat16)
 
     return (out.view(M, -1, w2.shape[1]) *
@@ -406,23 +395,21 @@ def test_fused_moe_batched_experts(
 
         factor_for_scale = 1e-2
         w1_s = torch.rand(
-            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale
+            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
+            device="cuda") * factor_for_scale
         w2_s = torch.rand(
-            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale
+            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
+            device="cuda") * factor_for_scale
     else:
         w1_s = None
         w2_s = None
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
-        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
-        # batched_output = batched_moe(a,
-        #                              w1.to(torch.bfloat16),
-        #                              w2.to(torch.bfloat16),
-        #                              topk_weight, topk_ids,
-        #                              w1_s, w2_s, False,
-        #                              block_shape)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, use_fp8_w8a8, block_shape)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, use_fp8_w8a8, block_shape)
 
     torch.testing.assert_close(baseline_output,
                                batched_output,

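To exercise these changes locally, the two tests touched by this commit can be run directly with pytest (a CUDA device is assumed, since the test tensors are allocated on "cuda"):

pytest tests/kernels/moe/test_batched_moe.py -k "test_batched_mm or test_fused_moe_batched_experts" -v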