Commit 3ca8322

lint

Signed-off-by: Bill Nell <bnell@redhat.com>

Parent: 03b41b6

File tree

4 files changed, +10 -28 lines changed


tests/kernels/moe/test_batched_moe.py

Lines changed: 0 additions & 6 deletions
@@ -205,13 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
     block_shape = [16, 16, 32] # 16 for k if not fp8
 
-    #print(f"tensors.A {tensors.A.shape}")
-    #print(f"tensors.B {tensors.B.shape}")
-
     if use_fp8_w8a8:
-        #A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device)
-        #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
-        #quant_block_shape = [N, K]
         A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
         B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
         quant_block_shape = [1, 1]
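
In the fp8 branch the test keeps per-tensor unit scales and a [1, 1] quant block rather than the commented-out per-channel variants. As a minimal sketch of what a per-tensor fp8 quantize/dequantize round trip looks like (plain PyTorch only; quantize_fp8_per_tensor is a hypothetical helper, not the kernel under test):

import torch

def quantize_fp8_per_tensor(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: scale down, clamp to the representable e4m3
    # range, then cast to the fp8 storage dtype.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scaled = (x / scale).clamp(min=finfo.min, max=finfo.max)
    return x_scaled.to(torch.float8_e4m3fn)

A = torch.randn(4, 8)
A_scale = torch.ones(1, dtype=torch.float32)  # unit scale, as in the test
A_q = quantize_fp8_per_tensor(A, A_scale)
A_dq = A_q.to(torch.float32) * A_scale        # dequantize with the same scale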

tests/kernels/moe/test_pplx_moe.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@
     reason="Requires PPLX kernels",
 )
 
+
 @dataclasses.dataclass
 class ProcessGroupInfo:
     world_size: int
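
For context: the single added line gives the @dataclasses.dataclass definition the two blank lines PEP 8 requires before top-level definitions (pycodestyle reports one blank line as E302). A minimal illustration, separate from this file:

import dataclasses

requires_pplx = True  # placeholder module-level statement


@dataclasses.dataclass
class Example:  # the two blank lines above satisfy the linter
    value: int = 0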

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 6 additions & 18 deletions
@@ -10,8 +10,7 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache,
-    moe_kernel_quantize_input)
+    _resize_cache, moe_kernel_quantize_input)
 
 
 @triton.jit
@@ -480,8 +479,7 @@ def prepare(
                         self.qtype,
                         self.per_act_token,
                         self.block_shape,
-                    )
-                )
+                    ))
             else:
                 b_a1[idx, :rows, :] = rhs
 
@@ -652,10 +650,8 @@ def batched_moe_kernel_quantize_input(
         if num_tokens > 0:
             A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
                 A[e, :num_tokens],
-                A_scale[e, :num_tokens] if A_scale else None,
-                qtype,
-                per_channel_quant,
-                [block_k, block_n])
+                A_scale[e, :num_tokens] if A_scale else None, qtype,
+                per_channel_quant, [block_k, block_n])
             A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
 
     return A_q, A_q_scale
@@ -812,16 +808,8 @@ def apply(
                         intermediate_cache1.view(-1, N))
 
         qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
-            intermediate_cache2,
-            a2_scale,
-            num_tokens,
-            E,
-            N,
-            expert_num_tokens,
-            self.qtype,
-            self.per_act_token,
-            self.block_shape
-        )
+            intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens,
+            self.qtype, self.per_act_token, self.block_shape)
 
         invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
                                          B=w2,
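
For orientation: batched_moe_kernel_quantize_input walks the expert dimension and quantizes only the first expert_num_tokens[e] rows of each expert's slot, leaving the padded tail untouched. A simplified, self-contained sketch of that loop (the quantize_rows callback is a stand-in for moe_kernel_quantize_input, not its real signature):

import torch

def batched_quantize_sketch(A: torch.Tensor, expert_num_tokens: torch.Tensor,
                            quantize_rows):
    # A is (num_experts, max_tokens_per_expert, hidden); only the first
    # expert_num_tokens[e] rows of each expert slot hold real tokens.
    E = A.shape[0]
    A_q = torch.zeros_like(A)
    scales = []
    for e in range(E):
        n = int(expert_num_tokens[e])
        if n > 0:
            # Quantize just the valid rows; the padded tail stays zero.
            A_q[e, :n, :], s = quantize_rows(A[e, :n])
            scales.append(s)
    return A_q, scales

# Identity "quantizer" with a unit scale, just to exercise the loop.
A_q, scales = batched_quantize_sketch(torch.randn(2, 4, 8),
                                      torch.tensor([3, 0]),
                                      lambda x: (x, torch.ones(1)))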

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 4 deletions
@@ -769,13 +769,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w2_input_scale
 
     def select_gemm_impl(self, prepare_finalize):
-        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
-            TritonOrDeepGemmExperts)
         from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-            BatchedPrepareAndFinalize,
-            BatchedTritonExperts)
+            BatchedPrepareAndFinalize, BatchedTritonExperts)
         from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
             PplxPrepareAndFinalize)
+        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+            TritonOrDeepGemmExperts)
 
         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")
