Commit 12e42ea

more lint
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 5e22409 commit 12e42ea

12 files changed, 37 insertions(+), 31 deletions(-)


tests/kernels/moe/test_batched_moe.py

Lines changed: 12 additions & 6 deletions
@@ -8,11 +8,11 @@
 import torch
 import triton.language as tl
 
-from tests.kernels.utils import torch_experts
 from tests.kernels.moe.utils import (batched_moe,
                                      make_quantized_test_activations,
                                      make_test_weights, triton_moe)
 from tests.kernels.quant_utils import native_w8a8_block_matmul
+from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     invoke_moe_batched_triton_kernel)
@@ -265,11 +265,17 @@ def test_fused_moe_batched_experts(
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
                                      w2_s, quant_dtype, per_act_token_quant,
                                      block_shape)
-        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
-                                        w1_scale=w1_s, w2_scale=w2_s,
-                                        quant_dtype=quant_dtype,
-                                        per_act_token_quant=per_act_token_quant,
-                                        block_shape=block_shape)
+        baseline_output = torch_experts(
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape)
         triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
                                    w2_s, quant_dtype, per_act_token_quant,
                                    block_shape)

tests/kernels/moe/test_pplx_moe.py

Lines changed: 7 additions & 2 deletions
@@ -577,8 +577,13 @@ def _pplx_moe(
 
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
-                                     w1_scale=w1_s, w2_scale=w2_s,
+        torch_output = torch_experts(a,
+                                     w1,
+                                     w2,
+                                     topk_weight,
+                                     topk_ids,
+                                     w1_scale=w1_s,
+                                     w2_scale=w2_s,
                                      quant_dtype=qtype,
                                      per_act_token_quant=per_act_token_quant,
                                      block_shape=block_shape)

tests/kernels/moe/utils.py

Lines changed: 0 additions & 3 deletions
@@ -4,14 +4,11 @@
 
 import torch
 
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
-from vllm.model_executor.layers.fused_moe.utils import (
-    moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.utils import round_up

tests/kernels/utils.py

Lines changed: 3 additions & 6 deletions
@@ -14,10 +14,10 @@
 from torch._prims_common import TensorLikeType
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
-
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
 from vllm.platforms.interface import _Backend
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                         STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
@@ -1081,10 +1081,7 @@ def torch_experts(
 
     a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
 
-    out = torch.zeros(M * topk,
-                      w2.shape[1],
-                      dtype=a.dtype,
-                      device=a.device)
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
 
     a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
                                            per_act_token_quant, block_shape)
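
One of the hunks above wraps an over-long "from ... import x" line in parentheses so the imported name can drop to its own indented line. As a hedged aside (not part of the commit), the snippet below shows the same E501-style fix using a standard-library module purely for illustration:

# Illustrative only: a parenthesized import allows a line break without a
# backslash continuation, keeping the line under the 80-character limit.
from concurrent.futures.process import (
    BrokenProcessPool)

print(BrokenProcessPool.__name__)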

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def get_config() -> Optional[dict[str, Any]]:
 import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts)
-from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
+from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
     BatchedTritonOrDeepGemmExperts)
 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
     CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 3 additions & 3 deletions
@@ -74,9 +74,9 @@ def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
         if self.batched_triton_experts is not None:
-            assert (self.batched_deep_gemm_experts is None or
-                    self.batched_deep_gemm_experts.activation_formats ==
-                    self.batched_triton_experts.activation_formats)
+            assert (self.batched_deep_gemm_experts is None
+                    or self.batched_deep_gemm_experts.activation_formats
+                    == self.batched_triton_experts.activation_formats)
             return self.batched_triton_experts.activation_formats
         else:
             assert self.batched_deep_gemm_experts is not None

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 1 addition & 1 deletion
@@ -10,8 +10,8 @@
 
 import vllm.envs as envs
 from vllm.config import ParallelConfig
-from vllm.logger import init_logger
 from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,6 @@
 
 import torch
 
-import vllm.model_executor.layers.quantization.deepgemm
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -16,6 +15,9 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.deepgemm import (
+    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm as
+    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm)
 from vllm.utils import round_up
 
 logger = init_logger(__name__)
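
The new deepgemm import above uses the "import a name as the same name" form. As a hedged aside (not part of the commit), the snippet below sketches that idiom with a stand-in symbol from the standard library: binding an import to its own name is the conventional way to mark it as an intentional re-export, which type checkers such as mypy and pyright honor and which some linters treat as a signal that an otherwise unused-looking import is deliberate.

# Illustrative only; `join` stands in for the real deepgemm kernel symbol.
from os.path import join as join  # redundant alias marks an explicit re-export

__all__ = ["join"]

if __name__ == "__main__":
    print(join("a", "b"))  # the re-exported name remains usable locally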

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ def select_gemm_impl(
         assert all2all_manager is not None
 
         if (prepare_finalize.activation_format ==
-            FusedMoEActivationFormat.BatchedExperts):
+                FusedMoEActivationFormat.BatchedExperts):
             logger.debug("BatchedTritonExperts %s", self.moe)
             assert self.moe.dp_size == all2all_manager.dp_world_size
             return BatchedTritonExperts(

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 3 additions & 3 deletions
@@ -49,9 +49,9 @@ def __init__(
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        assert (self.deep_gemm_expert is None or
-                self.triton_expert.activation_formats ==
-                self.deep_gemm_expert.activation_formats)
+        assert (self.deep_gemm_expert is None
+                or self.triton_expert.activation_formats
+                == self.deep_gemm_expert.activation_formats)
         return self.triton_expert.activation_formats
 
     def supports_chunking(self) -> bool:
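
Both assert rewrites in this commit move the boolean operators to the start of the continuation lines, the "break before binary operator" form that PEP 8 now recommends (the old form is what pycodestyle's W504 check flags when enabled). As a hedged aside with made-up names, a minimal sketch of the same pattern:

# Illustrative only: `a` and `b` are hypothetical expert objects.
def formats_match(a, b) -> bool:
    # Inside parentheses, a leading `or` / `==` on a continuation line is
    # valid Python and keeps each operator visible at the left margin.
    return (a is None
            or b is None
            or a.activation_formats
            == b.activation_formats)

print(formats_match(None, None))  # True: short-circuits on the first clause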
