Commit 23f26c9

lint
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent bd9bd37 commit 23f26c9

File tree

7 files changed: +70 -50 lines changed


tests/kernels/moe/test_moe.py

Lines changed: 38 additions & 12 deletions
@@ -5,14 +5,14 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
 import functools
+from typing import Callable, Optional, Union
+
 import pytest
 import torch
-
 from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-from typing import Callable, Optional, Union

 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
@@ -56,13 +56,19 @@ def run_moe_test(
     padding: bool = False,
     use_compile: bool = False,
     use_cudagraph: bool = False,
-    atol:float=2e-2,
-    rtol:float=0,
+    atol: float = 2e-2,
+    rtol: float = 0,
 ) -> torch.Tensor:
     if isinstance(baseline, torch.Tensor):
         baseline_output = baseline
     else:
-        baseline_output = baseline(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map)
+        baseline_output = baseline(a,
+                                   w1,
+                                   w2,
+                                   score,
+                                   topk,
+                                   global_num_experts=global_num_experts,
+                                   expert_map=expert_map)

     # Pad the weight if moe padding is enabled
     if padding:
@@ -96,7 +102,10 @@ def run_moe_test(
         graph.replay()
         torch.cuda.synchronize()

-    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(test_output,
+                               baseline_output,
+                               atol=atol,
+                               rtol=rtol)

     return baseline_output

@@ -167,7 +176,7 @@ def m_fused_moe(
         score: torch.Tensor,
         topk: int,
         global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor]= None,
+        expert_map: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
         return m_fused_moe_fn(a,
@@ -195,13 +204,20 @@ def m_fused_moe(
         padding=padding,
     )

-    use_compile = m >= chunk_size and n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
+    use_compile = (m >= chunk_size and n >= 1024 and k >= 1024
+                   and current_platform.is_cuda_alike())
    use_cudagraph = use_compile

    with set_current_vllm_config(vllm_config):
        baseline_output = runner(torch_moe, iterative_moe)
-        runner(baseline_output, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph)
-        runner(baseline_output, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph)
+        runner(baseline_output,
+               fused_moe_fn,
+               use_compile=use_compile,
+               use_cudagraph=use_cudagraph)
+        runner(baseline_output,
+               m_fused_moe,
+               use_compile=use_compile,
+               use_cudagraph=use_cudagraph)


 @pytest.mark.parametrize("m", [1, 32, 222])
@@ -311,7 +327,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                               w1_zp=w1_qzeros if has_zp else None,
                               w2_zp=w2_qzeros if has_zp else None,
                               block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)
+    torch_output = torch_moe(a,
+                             w1_ref,
+                             w2_ref,
+                             score,
+                             topk,
+                             expert_map=e_map)

     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

@@ -619,7 +640,12 @@ def test_fused_marlin_moe(
     topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

     with set_current_vllm_config(vllm_config):
-        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
+        torch_output = torch_moe(a,
+                                 w_ref1,
+                                 w_ref2,
+                                 score,
+                                 topk,
+                                 expert_map=e_map)

         marlin_output = torch.ops.vllm.fused_marlin_moe(
             a,
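
Note on the pattern above: `run_moe_test` compares each candidate MoE kernel against a baseline under the default tolerances atol=2e-2, rtol=0. A stripped-down sketch of that comparison flow (it omits the padding, torch.compile and CUDA graph handling of the real helper; `candidate_fn` and `baseline` are placeholder names):

import torch

def compare_against_baseline(candidate_fn, baseline, *args,
                             atol: float = 2e-2, rtol: float = 0):
    # Accept either a precomputed baseline tensor or a callable producing one,
    # mirroring the run_moe_test signature.
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
        baseline_output = baseline(*args)
    test_output = candidate_fn(*args)
    # rtol=0 means only the absolute tolerance is applied.
    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
    return baseline_output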

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 1 addition & 3 deletions
@@ -6,17 +6,15 @@
 import pytest
 import torch

+from tests.kernels.utils import torch_experts
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.platforms import current_platform

-from tests.kernels.utils import torch_experts
-
 from .deepep_utils import ProcessGroupInfo, parallel_launch

 try:

tests/kernels/moe/test_pplx_moe.py

Lines changed: 1 addition & 3 deletions
@@ -18,8 +18,8 @@
 except ImportError:
     has_pplx = False

+from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import override_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
@@ -29,8 +29,6 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform

-from tests.kernels.utils import torch_experts
-
 from .deepep_utils import ProcessGroupInfo, parallel_launch

 requires_pplx = pytest.mark.skipif(

tests/kernels/quantization/test_block_fp8.py

Lines changed: 8 additions & 4 deletions
@@ -403,7 +403,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
                         itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
 @torch.inference_mode()
-def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
+def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
+                                            monkeypatch):
     if topk > E:
         pytest.skip(f"Skipping test: topk={topk} > E={E}")

@@ -455,7 +456,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
         w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
         w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

-    use_compile = M > chunk_size and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
+    use_compile = (chunk_size < M and N >= 1024 and K >= 1024
+                   and current_platform.is_cuda_alike())
     use_cudagraph = use_compile

     # Set the context to avoid lots of warning spam.
@@ -477,14 +479,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
     else:
         deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

-    out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
+    out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
+                               topk_ids)

     if use_cudagraph:
         out.fill_(0)
         stream = torch.cuda.Stream()
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
-            out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
+            out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
+                                       topk_ids)
         torch.cuda.synchronize()
         graph.replay()
         torch.cuda.synchronize()
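
For reference, the capture-and-replay sequence in this hunk is the standard torch.cuda.CUDAGraph pattern: run the kernel once eagerly to warm up, capture it on a side stream, then replay the graph. A minimal sketch with a placeholder `fn` standing in for `deep_gemm_moe_fp8_fn` (assumes `fn` uses static shapes and is capture-safe):

import torch

def run_with_cudagraph(fn, *args):
    out = fn(*args)                    # eager warm-up run
    out.fill_(0)                       # make the replayed result observable
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        out = fn(*args)                # recorded into the graph, not executed
    torch.cuda.synchronize()
    graph.replay()                     # runs the captured kernels
    torch.cuda.synchronize()
    return out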

tests/kernels/utils.py

Lines changed: 19 additions & 22 deletions
@@ -1054,18 +1054,16 @@ def compute_max_diff(output, output_ref):
         torch.abs(output_ref))


-def torch_experts(
-    a: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weight: torch.Tensor,
-    topk_ids: torch.Tensor,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None
-) -> torch.Tensor:
-    assert (global_num_experts == -1 or
-            (global_num_experts == w1.shape[0] and expert_map is None) or
-            global_num_experts == expert_map.shape[0])
+def torch_experts(a: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weight: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  global_num_experts: int = -1,
+                  expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
+    assert (global_num_experts == -1
+            or (global_num_experts == w1.shape[0] and expert_map is None)
+            or global_num_experts == expert_map.shape[0])
     topk = topk_ids.shape[1]
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -1083,18 +1081,17 @@ def torch_experts(
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


-def torch_moe(
-    a: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    score: torch.Tensor,
-    topk: int,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None
-) -> torch.Tensor:
+def torch_moe(a: torch.Tensor,
+              w1: torch.Tensor,
+              w2: torch.Tensor,
+              score: torch.Tensor,
+              topk: int,
+              global_num_experts: int = -1,
+              expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
     score = torch.softmax(score, dim=-1, dtype=torch.float32)
     topk_weight, topk_ids = torch.topk(score, topk)
-    return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map)
+    return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
+                         expert_map)


 def torch_moe_single(a, w, score, topk):
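
Usage note: `torch_moe` is the dense reference the fused kernels are checked against; it softmaxes the router scores, picks the top-k experts per token, and forwards to `torch_experts`. A small sketch of driving it directly (shapes, dtype and CPU placement are illustrative; the actual tests run in half precision on CUDA):

import torch
from tests.kernels.utils import torch_moe

m, n, k, e, topk = 32, 256, 128, 8, 2
a = torch.randn(m, k)                 # token activations
w1 = torch.randn(e, 2 * n, k) / 10    # per-expert gate/up projection
w2 = torch.randn(e, k, n) / 10        # per-expert down projection
score = torch.randn(m, e)             # router logits

# Dense reference output; fused kernels are compared to this with
# torch.testing.assert_close(..., atol=2e-2, rtol=0).
ref = torch_moe(a, w1, w2, score, topk)
assert ref.shape == (m, k)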

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 0 additions & 2 deletions
@@ -7,8 +7,6 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-import vllm.model_executor.layers.quantization.deepgemm
-
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_permute)

vllm/model_executor/layers/quantization/deepgemm.py

Lines changed: 3 additions & 4 deletions
@@ -1,11 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 import importlib.util
 import logging
+from typing import Optional

 import torch

-from typing import Optional
-
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils import direct_register_custom_op
@@ -86,7 +85,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm(
     expert_ids: torch.Tensor,
 ) -> None:
     import deep_gemm as dg
-    dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), output, expert_ids)
+    dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale),
+                                                 output, expert_ids)


 direct_register_custom_op(
@@ -97,7 +97,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm(
     dispatch_key=current_platform.dispatch_key,
 )

-
 direct_register_custom_op(
     op_name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm",
     op_func=m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm,
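
Usage note (an assumption based on the registration pattern, not something shown in this diff): once direct_register_custom_op has registered the wrapper, callers would reach it through the torch.ops.vllm namespace, the same way torch.ops.vllm.fused_marlin_moe is invoked in test_moe.py above. A hedged sketch, with the argument order inferred from the wrapper's dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous call:

import torch

def grouped_gemm_via_registered_op(a, a_scale, b, b_scale, output, expert_ids):
    # Assumed call shape: (a, a_scale, b, b_scale, output, expert_ids),
    # matching the registered wrapper's parameters; the op writes its
    # result into `output` in place and returns None.
    torch.ops.vllm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm(
        a, a_scale, b, b_scale, output, expert_ids)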
