
Commit 5e22409

fixes

Signed-off-by: Bill Nell <bnell@redhat.com>

1 parent 7a95679

25 files changed: +235 / -315 lines changed

requirements/test.txt
Lines changed: 20 additions & 2 deletions

@@ -31,6 +31,10 @@ argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0
     # via isoduration
+async-timeout==5.0.1
+    # via
+    #   aiohttp
+    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -141,6 +145,11 @@ eval-type-backport==0.2.2
     # via mteb
 evaluate==0.4.3
     # via lm-eval
+exceptiongroup==1.3.0
+    # via
+    #   anyio
+    #   hypothesis
+    #   pytest
 fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
@@ -690,7 +699,6 @@ setuptools==77.0.3
     # via
     #   mamba-ssm
     #   pytablewriter
-    #   torch
     #   triton
 shellingham==1.5.4
     # via typer
@@ -753,8 +761,13 @@ tokenizers==0.21.1
     # via
     #   -r requirements/test.in
     #   transformers
+toml==0.10.2
+    # via datamodel-code-generator
 tomli==2.2.1
-    # via schemathesis
+    # via
+    #   black
+    #   pytest
+    #   schemathesis
 tomli-w==1.2.0
     # via schemathesis
 torch==2.7.0+cu128
@@ -828,13 +841,18 @@ types-python-dateutil==2.9.0.20241206
     # via arrow
 typing-extensions==4.12.2
     # via
+    #   anyio
+    #   black
+    #   exceptiongroup
     #   huggingface-hub
     #   librosa
     #   mistral-common
     #   mteb
+    #   multidict
     #   pqdm
     #   pydantic
     #   pydantic-core
+    #   rich
     #   torch
     #   typer
     #   typing-inspection

tests/kernels/moe/test_batched_moe.py
Lines changed: 18 additions & 14 deletions

@@ -8,12 +8,10 @@
 import torch
 import triton.language as tl

-from tests.kernels.moe.utils import (
-    batched_moe,
-    make_test_weights,
-    make_quantized_test_activations,
-    torch_moe2,
-    triton_moe)
+from tests.kernels.utils import torch_experts
+from tests.kernels.moe.utils import (batched_moe,
+                                     make_quantized_test_activations,
+                                     make_test_weights, triton_moe)
 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@@ -109,11 +107,13 @@ def ref_impl(
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("block_shape", [None])
 @pytest.mark.parametrize("per_act_token_quant", [False])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
-                    N: int, dtype: torch.dtype, block_shape: Optional[list[int]],
+                    N: int, dtype: torch.dtype,
+                    block_shape: Optional[list[int]],
                     per_act_token_quant: bool):
     current_platform.seed_everything(7)

@@ -144,8 +144,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
-        per_act_token_quant=per_act_token_quant
-    )
+        per_act_token_quant=per_act_token_quant)

     B, B_q, B_scale, _, _, _ = make_test_weights(
         num_experts,
@@ -252,7 +251,10 @@ def test_fused_moe_batched_experts(
         act_dtype = dtype
         quant_dtype = None

-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e, n, k, block_shape=block_shape,
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
+                                                 n,
+                                                 k,
+                                                 block_shape=block_shape,
                                                  in_dtype=act_dtype,
                                                  quant_dtype=quant_dtype)

@@ -263,9 +265,11 @@ def test_fused_moe_batched_experts(
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
                                      w2_s, quant_dtype, per_act_token_quant,
                                      block_shape)
-        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_dtype, per_act_token_quant,
-                                     block_shape)
+        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
+                                        w1_scale=w1_s, w2_scale=w2_s,
+                                        quant_dtype=quant_dtype,
+                                        per_act_token_quant=per_act_token_quant,
+                                        block_shape=block_shape)
         triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
                                    w2_s, quant_dtype, per_act_token_quant,
                                    block_shape)
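
Note: the last hunk replaces the torch_moe2 baseline with torch_experts and passes the scale and quantization arguments by keyword. The pattern this relies on can be shown with a small, self-contained sketch (reference_experts and its placeholder body are hypothetical, not vLLM code); keyword-only parameters stop the optional quantization arguments from being transposed at the call site.

from typing import Optional

import torch


def reference_experts(a: torch.Tensor,
                      w1: torch.Tensor,
                      w2: torch.Tensor,
                      topk_weight: torch.Tensor,
                      topk_ids: torch.Tensor,
                      *,  # everything below must be passed by keyword
                      w1_scale: Optional[torch.Tensor] = None,
                      w2_scale: Optional[torch.Tensor] = None,
                      quant_dtype: Optional[torch.dtype] = None,
                      per_act_token_quant: bool = False,
                      block_shape: Optional[list[int]] = None) -> torch.Tensor:
    # Placeholder body: a real reference would dequantize the weights, apply
    # both expert projections and the activation, then combine the per-token
    # expert outputs with topk_weight.
    return a


out = reference_experts(torch.randn(4, 8), torch.randn(2, 16, 8),
                        torch.randn(2, 8, 8), torch.rand(4, 2),
                        torch.zeros(4, 2, dtype=torch.long),
                        w1_scale=None, w2_scale=None, quant_dtype=None,
                        per_act_token_quant=False, block_shape=None)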

tests/kernels/moe/test_block_fp8.py
Lines changed: 5 additions & 6 deletions

@@ -7,8 +7,8 @@
 import pytest
 import torch

-from tests.kernels.quant_utils import (native_w8a8_block_matmul,
-                                       native_per_token_group_quant_fp8,
+from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
+                                       native_w8a8_block_matmul,
                                        per_block_cast_to_fp8)
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -20,7 +20,7 @@
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    per_token_group_quant_fp8)
 from vllm.platforms import current_platform

 dg_available = False
@@ -261,9 +261,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
     return final_out


-@pytest.mark.parametrize(
-    "M,N,K,E,topk,seed",
-    itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
+@pytest.mark.parametrize("M,N,K,E,topk,seed",
+                         itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
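
Note: the last hunk only reflows the parametrize decorator, which relies on pytest accepting a single comma-separated argname string plus an iterable of tuples. A minimal, self-contained illustration of that pattern (the lists and the assertion below are made up for the example, not values from this test):

import itertools

import pytest

M = [1, 33]
N = [128, 1024]
K = [256]
SEEDS = [0]


@pytest.mark.parametrize("m,n,k,seed", itertools.product(M, N, K, SEEDS))
def test_cartesian_product(m, n, k, seed):
    # itertools.product yields one (m, n, k, seed) tuple per combination,
    # so this test collects 2 * 2 * 1 * 1 = 4 cases.
    assert m > 0 and n % 128 == 0 and k == 256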

tests/kernels/moe/test_block_int8.py
Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@
 import pytest
 import torch

-from tests.kernels.quant_utils import (native_w8a8_block_matmul,
-                                       native_per_token_group_quant_int8)
+from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
+                                       native_w8a8_block_matmul)
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe

tests/kernels/moe/test_deepep_deepgemm_moe.py
Lines changed: 1 addition & 3 deletions

@@ -28,7 +28,6 @@
 has_deep_ep = importlib.util.find_spec("deep_ep") is not None
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

-
 if has_deep_ep:
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
         DeepEPHTPrepareAndFinalize)
@@ -69,8 +68,7 @@ def per_block_cast_to_fp8(
     assert x.dim() == 2
     m, n = x.shape
     x_padded = torch.zeros(
-        (cdiv(m, 128) * 128,
-         cdiv(n, block_size_n) * block_size_n),
+        (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
         dtype=x.dtype,
         device=x.device)
     x_padded[:m, :n] = x
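
Note: the second hunk simply joins the padded-shape tuple onto one line; the round-up-to-block padding it computes can be sketched in a self-contained way as follows (cdiv and pad_block here are local stand-ins, not vLLM's utilities):

import torch


def cdiv(a: int, b: int) -> int:
    # Ceiling division: number of b-sized blocks needed to cover a.
    return -(a // -b)


def pad_block(x: torch.Tensor, block_m: int = 128,
              block_n: int = 128) -> torch.Tensor:
    m, n = x.shape
    # Allocate zeros with both dims rounded up to block multiples, then copy
    # the original values into the top-left corner.
    x_padded = torch.zeros(
        (cdiv(m, block_m) * block_m, cdiv(n, block_n) * block_n),
        dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    return x_padded


print(pad_block(torch.ones(3, 5), block_m=4, block_n=8).shape)  # torch.Size([4, 8])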

tests/kernels/moe/test_pplx_moe.py
Lines changed: 19 additions & 15 deletions

@@ -18,18 +18,16 @@
 except ImportError:
     has_pplx = False

+from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
 from tests.kernels.utils import torch_experts
-from tests.kernels.moe.utils import (make_test_weights, naive_batched_moe)
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe import (
-    override_config,
-    fused_topk)
-from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
+from vllm.model_executor.layers.fused_moe import fused_topk, override_config
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel)
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
+from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import round_up

@@ -579,11 +577,14 @@ def _pplx_moe(

     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
-                                     qtype, per_act_token_quant, block_shape)
-        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a,
-                               w1, w2, topk_weight, topk_ids, w1_s, w2_s, qtype,
-                               per_act_token_quant, block_shape)
+        torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
+                                     w1_scale=w1_s, w2_scale=w2_s,
+                                     quant_dtype=qtype,
+                                     per_act_token_quant=per_act_token_quant,
+                                     block_shape=block_shape)
+        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
+                               a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
+                               qtype, per_act_token_quant, block_shape)
         # TODO (bnell): fix + re-enable
         #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
         #                              topk_ids)
@@ -601,7 +602,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16]) # torch.float8_e4m3fn,
+@pytest.mark.parametrize("dtype", [torch.bfloat16])  # torch.float8_e4m3fn,
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
@@ -634,8 +635,11 @@ def test_pplx_moe(
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

-    _, w1, w1_s, _, w2, w2_s = make_test_weights(
-        e, n, k, quant_dtype=quant_dtype, block_shape=block_shape)
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
+                                                 n,
+                                                 k,
+                                                 quant_dtype=quant_dtype,
+                                                 block_shape=block_shape)

     parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
                     w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
