Commit d2c2790

lint
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 604ab02 commit d2c2790

21 files changed: +379 -477 lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 29 additions & 49 deletions
@@ -8,27 +8,16 @@
 import torch
 import triton.language as tl
 
-import vllm._custom_ops as ops
+from tests.kernels.moe.utils import (batched_moe, make_test_weights,
+                                     torch_moe2, triton_moe)
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize, BatchedTritonExperts,
     invoke_moe_batched_triton_kernel)
-from vllm.model_executor.layers.fused_moe.utils import (
-    moe_kernel_quantize_input)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
-from tests.kernels.quant_utils import native_w8a8_block_matmul
-from tests.kernels.moe.utils import (
-    torch_moe2,
-    triton_moe,
-    batched_moe,
-    make_test_weights,
-)
 
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]
@@ -104,18 +93,13 @@ def ref_impl(
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
         if A.dtype.itemsize == 1 and block_shape is not None:
-            tmp = native_w8a8_block_matmul(A[e],
-                                           B[e],
-                                           A_scale[e],
-                                           B_scale[e],
-                                           block_shape,
-                                           C.dtype)
+            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
+                                           block_shape, C.dtype)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         elif A.dtype.itemsize == 1 and block_shape is None:
             C[e, :num_tokens, :] = (
-                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16) @
-                (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16)
-            )
+                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16)
+                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16))
         else:
             assert A_scale is None
             assert B_scale is None
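Side note on the per-tensor branch reflowed above: it dequantizes the FP8 operands with their scales and performs the product in bfloat16. A minimal, self-contained sketch of that idea (the function name and the scalar-scale assumption are mine, not part of the test):

import torch

def dequant_matmul_ref(a_q: torch.Tensor, b_q: torch.Tensor,
                       a_scale: torch.Tensor,
                       b_scale: torch.Tensor) -> torch.Tensor:
    # Upcast the quantized operands to float32, apply their (per-tensor)
    # scales, then multiply in bfloat16, mirroring the elif branch in
    # ref_impl above. b_q is laid out [N, K], so it is transposed to give
    # an [M, N] result.
    a = (a_q.to(torch.float32) * a_scale).to(torch.bfloat16)
    b = (b_q.to(torch.float32) * b_scale).to(torch.bfloat16)
    return a @ b.transpose(0, 1)
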
@@ -124,7 +108,8 @@ def ref_impl(
     return C
 
 
-def make_quantized_test_activations(E, m, k, dtype, block_shape, per_act_token):
+def make_quantized_test_activations(E, m, k, dtype, block_shape,
+                                    per_act_token):
     assert not per_act_token, "NYI"
 
     a_type = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
@@ -138,9 +123,11 @@ def make_quantized_test_activations(E, m, k, dtype, block_shape, per_act_token):
     a_scale = [None] * E
     for e in range(E):
         if block_shape is not None:
-            a_q[e], a_scale[e] = per_token_group_quant_fp8(a[e], block_shape[1])
+            a_q[e], a_scale[e] = per_token_group_quant_fp8(
+                a[e], block_shape[1])
         else:
-            a_tmp, a_scale[e] = per_token_group_quant_fp8(a[e].view(1, -1), a[e].numel())
+            a_tmp, a_scale[e] = per_token_group_quant_fp8(
+                a[e].view(1, -1), a[e].numel())
             a_q[e] = a_tmp.view(*a[e].shape)
     a_scale = torch.stack(a_scale)
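In this hunk, per_token_group_quant_fp8 is called with group size block_shape[1] in the block-quantized case, and with a single group spanning the whole row in the else branch. As a rough illustration of what group-wise FP8 quantization produces (an illustrative sketch only, not vLLM's implementation; it assumes a PyTorch build with float8_e4m3fn support):

import torch

def group_quant_fp8_sketch(x: torch.Tensor, group_size: int):
    # Split the last dimension into groups of `group_size`, scale each group
    # so its max magnitude fits the FP8 E4M3 range, and return the quantized
    # tensor together with one float32 scale per group.
    m, k = x.shape
    assert k % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    groups = x.to(torch.float32).reshape(m, k // group_size, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q.reshape(m, k), scale.squeeze(-1)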

@@ -173,14 +160,10 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                                      device="cuda",
                                      dtype=torch.int32)
 
-    A, A_q, A_scale = make_quantized_test_activations(
-        num_experts,
-        max_tokens_per_expert,
-        K,
-        dtype,
-        block_shape,
-        per_act_token_quant
-    )
+    A, A_q, A_scale = make_quantized_test_activations(num_experts,
+                                                      max_tokens_per_expert, K,
+                                                      dtype, block_shape,
+                                                      per_act_token_quant)
 
     B_q, _, B_scale, _, B, _ = make_test_weights(
         num_experts,
@@ -206,7 +189,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     #print(f"A {use_fp8_w8a8} {A_q.dtype} {B_q.dtype} {A_scale.shape} {B_scale.shape}")
     if False:
-        from vllm.model_executor.layers.fused_moe.batched_moe2 import fused_moe_kernel2
+        from vllm.model_executor.layers.fused_moe.batched_moe2 import (
+            fused_moe_kernel2)
         fused_moe_kernel2(
             A_q,
             B_q,
@@ -238,7 +222,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             config_block_shape[1],
             config_block_shape[2],
             1,
-            1, # topk hack
+            1,  # topk hack
             compute_tl_dtype,
             use_fp8_w8a8,
             False,
@@ -279,15 +263,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         None,
     )
 
-    q_ref_output = ref_impl(
-        A_q,
-        B_q,
-        q_ref_output,
-        num_expert_tokens,
-        A_scale,
-        B_scale,
-        block_shape
-    )
+    q_ref_output = ref_impl(A_q, B_q, q_ref_output, num_expert_tokens, A_scale,
+                            B_scale, block_shape)
 
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
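The hunk above ends at the start of a dtype-keyed tolerance table; only the float16 entry is visible in this diff, so the other values below are placeholders. A small sketch of how such a lookup can drive torch.testing.assert_close (the helper name and the fallback tolerances are assumptions, not the test's values):

import torch

def assert_close_by_dtype(actual: torch.Tensor, expected: torch.Tensor) -> None:
    # Pick tolerances by output dtype; everything except the float16 pair
    # shown in the diff is a made-up placeholder.
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
    }.get(actual.dtype, (1e-2, 1e-2))
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
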
@@ -393,11 +370,14 @@ def test_fused_moe_batched_experts(
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant, block_shape)
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
         baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant, block_shape)
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
         triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                   w2_s, quant_type, per_act_token_quant, block_shape)
+                                   w2_s, quant_type, per_act_token_quant,
+                                   block_shape)
 
         torch.testing.assert_close(triton_output,
                                    baseline_output,
@@ -411,4 +391,4 @@ def test_fused_moe_batched_experts(
     torch.testing.assert_close(triton_output,
                                batched_output,
                                atol=2e-2,
-                               rtol=2e-2) # 0
+                               rtol=2e-2)  # 0

tests/kernels/moe/test_pplx_moe.py

Lines changed: 13 additions & 26 deletions
@@ -18,37 +18,26 @@
 except ImportError:
     has_pplx = False
 
+#from tests.kernels.quant_utils import native_w8a8_block_matmul
+from tests.kernels.moe.utils import (make_test_weights, naive_batched_moe,
+                                     torch_moe2)
+from tests.pplx_utils import ProcessGroupInfo, parallel_launch
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import override_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    NaiveBatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
+    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                             get_default_config)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
-from vllm.model_executor.layers.fused_moe.utils import (
-    moe_kernel_quantize_input)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
-from .deepep_utils import ProcessGroupInfo, parallel_launch
-
-from tests.kernels.moe.utils import (
-    torch_moe2,
-    naive_batched_moe,
-    make_test_weights,
-)
-
-
 requires_pplx = pytest.mark.skipif(
     not has_pplx,
     reason="Requires PPLX kernels",
 )
 
-
 PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
                        (222, 2048, 1024)]

@@ -430,13 +419,11 @@ def pplx_moe(
         dp_size,
     )
 
-    experts = BatchedTritonExperts(
-        max_num_tokens=max_num_tokens,
-        world_size=world_size,
-        dp_size=dp_size,
-        use_fp8_w8a8=qtype==torch.float8_e4m3fn,
-        block_shape=block_shape
-    )
+    experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
+                                   world_size=world_size,
+                                   dp_size=dp_size,
+                                   use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+                                   block_shape=block_shape)
 
     fused_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -584,8 +571,7 @@ def _pplx_moe(
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
-                                  qtype, per_act_token_quant,
-                                  block_shape)
+                                  qtype, per_act_token_quant, block_shape)
         pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a,
                                w1, w2, topk_weight, topk_ids, w1_s, w2_s, qtype,
                                per_act_token_quant, block_shape)
@@ -633,7 +619,8 @@ def test_pplx_moe(
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
 
-    w1, w2, w1_s, w2_s, w1_16, w2_16 = make_test_weights(e, n, k, block_shape, dtype)
+    w1, w2, w1_s, w2_s, w1_16, w2_16 = make_test_weights(
+        e, n, k, block_shape, dtype)
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
                     w1_s, w2_s, dtype, per_act_token_quant, block_shape,
