Commit efc014f

lint
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent a548905 commit efc014f

5 files changed: +44 -25 lines changed

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 22 additions & 9 deletions

@@ -242,20 +242,32 @@ class FusedMoEConfig:
     max_num_tokens: int = MOE_DP_CHUNK_SIZE
 
     @property
-    def quant_dtype(self):
-        return self.quant_config.quant_dtype if self.quant_config is not None else None
+    def quant_dtype(self) -> Optional[torch.dtype]:
+        if self.quant_config is not None:
+            return self.quant_config.quant_dtype
+        else:
+            return None
 
     @property
-    def block_shape(self):
-        return self.quant_config.block_shape if self.quant_config is not None else None
+    def block_shape(self) -> Optional[list[int]]:
+        if self.quant_config is not None:
+            return self.quant_config.block_shape
+        else:
+            return None
 
     @property
-    def per_act_token_quant(self):
-        return self.quant_config.per_act_token_quant if self.quant_config is not None else False
+    def per_act_token_quant(self) -> bool:
+        if self.quant_config is not None:
+            return self.quant_config.per_act_token_quant
+        else:
+            return False
 
     @property
-    def per_out_ch_quant(self):
-        return self.quant_config.per_out_ch_quant if self.quant_config is not None else False
+    def per_out_ch_quant(self) -> bool:
+        if self.quant_config is not None:
+            return self.quant_config.per_out_ch_quant
+        else:
+            return False
 
     @property
     def tp_size(self):
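
The rewritten properties preserve the fallback behaviour of the removed one-liners: with no quant_config attached, quant_dtype and block_shape read as None and the two boolean flags read as False. A minimal stand-in sketch of that contract (illustrative stub classes, not the real FusedMoEConfig, which carries more fields):

```python
# Sketch only: mirrors the None-safe property pattern above with stub types.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _StubQuantConfig:
    quant_dtype: Optional[torch.dtype] = torch.float8_e4m3fn


@dataclass
class _StubMoEConfig:
    quant_config: Optional[_StubQuantConfig] = None

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        if self.quant_config is not None:
            return self.quant_config.quant_dtype
        else:
            return None


assert _StubMoEConfig().quant_dtype is None  # unquantized MoE sees None
assert _StubMoEConfig(_StubQuantConfig()).quant_dtype is torch.float8_e4m3fn
```
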
@@ -338,7 +350,8 @@ def make(
             quant_dtype = torch.float8_e4m3fn
 
             if weight_quant is not None:
-                per_out_ch_quant = weight_quant.strategy == QuantizationStrategy.CHANNEL
+                per_out_ch_quant = (
+                    weight_quant.strategy == QuantizationStrategy.CHANNEL)
 
             assert quant_dtype is not None
 

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ def _do_quant(
 
         assert isinstance(x, torch.Tensor)
 
+        # TODO (bnell):
         # Check if there is a block_shape / or if we can infer the quantization
         # schemes from the scales.
         # per_token_quant = None

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 12 additions & 6 deletions

@@ -303,7 +303,8 @@ def batched_triton_kernel(
     if (group_k > 0 and group_n > 0) or per_channel_quant:
         a_scale_ptr = a_scale_ptr + (expert_id *
                                      stride_ase) + cta_m_start * stride_asm
-        #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) # + cta_n_start * stride_bsn?
+        #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
+        # (?) b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn
     # channel-wise or tensor-wise
     else:
         a_scale_ptr = a_scale_ptr + (expert_id * stride_ase)
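
For reference, the a_scale_ptr arithmetic kept above steps through a per-expert activation-scale buffer: one expert stride plus one row stride per starting row of the tile. A host-side sketch of the equivalent offset computation (the [experts, tokens, k-tiles] layout and the variable names are assumptions for illustration, not the kernel's actual arguments):

```python
# Sketch: how expert_id and cta_m_start turn into a flat element offset,
# mirroring a_scale_ptr + expert_id * stride_ase + cta_m_start * stride_asm.
import torch

num_experts, max_tokens, k_tiles = 4, 8, 2
a_scale = torch.rand(num_experts, max_tokens, k_tiles)

stride_ase, stride_asm, _ = a_scale.stride()
expert_id, cta_m_start = 2, 4

flat_offset = expert_id * stride_ase + cta_m_start * stride_asm

# The flat offset reaches the same element as ordinary indexing.
assert a_scale.flatten()[flat_offset] == a_scale[expert_id, cta_m_start, 0]
```
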
@@ -379,9 +380,10 @@ def invoke_moe_batched_triton_kernel(
     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
             triton.cdiv(B.size(1), BLOCK_N))
 
-    assert A_scale is None or A_scale.ndim == 3, f"{0 if A_scale is None else A_scale.shape}"
-    assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}"
-    #assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else (A.shape, B_scale.shape)}"
+    assert A_scale is None or A_scale.ndim == 3, (
+        f"{0 if A_scale is None else A_scale.shape}")
+    assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, (
+        f"{0 if B_scale is None else B_scale.shape}")
 
     if B_scale is not None:
         if B_scale.ndim == 1:

@@ -522,7 +524,10 @@ def prepare(
                 k_tiles = (hidden_dim + block_k - 1) // block_k
                 scale_shape = (num_local_experts, self.max_num_tokens, k_tiles)
             else:
-                num = self.max_num_tokens if quant_config.per_act_token_quant else 1
+                if quant_config.per_act_token_quant:
+                    num = self.max_num_tokens
+                else:
+                    num = 1
                 scale_shape = (num_local_experts, num, 1)
 
             #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}")
@@ -555,7 +560,8 @@ def prepare(
                     quant_config.per_act_token_quant,
                     quant_config.block_shape,
                 ))
-                if quant_config.block_shape is None and not quant_config.per_act_token_quant:
+                if (quant_config.block_shape is None and
+                    not quant_config.per_act_token_quant):
                     b_a1_scale[idx] = b_s
                 else:
                     #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}")

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 2 deletions

@@ -556,8 +556,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
 
     if use_fp8_w8a8 or use_int8_w8a8:
         assert B_scale is not None
-        # assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
-        # == B_scale.shape[-2]), f"{block_shape[0]} {B.shape[-2]} {B_scale.shape[-2]}"
+        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
+                == B_scale.shape[-2]), (
+                    f"{block_shape[0]} {B.shape[-2]} {B_scale.shape[-2]}")
         assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
                 == B_scale.shape[-1])
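
The re-enabled assert ties block-quantized weight scales to the weight shape: B_scale must carry one scale per [block_shape[0] x block_shape[1]] tile of B, in both the N and K dimensions. A small numeric check of the same relation (the shapes here are illustrative, not taken from a particular model):

```python
# Sketch: for B of shape [E, N, K] with block_shape = [128, 128], B_scale
# needs ceil(N / 128) x ceil(K / 128) entries per expert, which is exactly
# what the pair of asserts above verifies.
import torch
import triton

E, N, K = 8, 2048, 7168
block_shape = [128, 128]

B = torch.empty(E, N, K, dtype=torch.float8_e4m3fn)
B_scale = torch.empty(E, N // 128, K // 128, dtype=torch.float32)  # [8, 16, 56]

assert triton.cdiv(B.shape[-2], block_shape[0]) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_shape[1]) == B_scale.shape[-1]
```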

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 6 additions & 8 deletions

@@ -20,20 +20,18 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig, FusedMoEParallelConfig)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
+    FusedMoEPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel,
-    FusedMoEPermuteExpertsUnpermute,
-    FusedMoEPrepareAndFinalize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import direct_register_custom_op
 
-
 has_pplx = importlib.util.find_spec("pplx_kernels") is not None
 has_deepep = importlib.util.find_spec("deep_ep") is not None

@@ -95,9 +93,9 @@ def init_prepare_finalize(self, moe: FusedMoEConfig,
                 block_shape=moe.block_shape,
             )
 
-            logger.debug(
-                f"All2All {moe.quant_dtype}, {moe.block_shape} = {hidden_dim_bytes}/{hidden_scale_bytes}"
-            )
+            logger.debug("All2All %s, %s = %s/%s", moe.quant_dtype,
+                         moe.block_shape, hidden_dim_bytes,
+                         hidden_scale_bytes)
 
             all_to_all_args = dict(
                 max_num_tokens=moe.max_num_tokens,
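
Switching the f-string to %-style arguments defers formatting to the logging framework: the message is only built if a DEBUG record is actually emitted, and lint checks that flag f-strings inside logging calls are satisfied. A small standalone illustration of the pattern (plain stdlib logging, not vllm's logger setup; the values are made up):

```python
# Sketch: eager f-string formatting vs. lazy %-formatting in logging calls.
import logging

logging.basicConfig(level=logging.INFO)  # DEBUG is disabled here
logger = logging.getLogger("example")

quant_dtype, block_shape = None, [128, 128]
hidden_dim_bytes, hidden_scale_bytes = 14336, 224

# Eager: the f-string is evaluated even though the record is never emitted.
logger.debug(f"All2All {quant_dtype}, {block_shape} = "
             f"{hidden_dim_bytes}/{hidden_scale_bytes}")

# Lazy: the arguments are only interpolated when DEBUG logging is enabled.
logger.debug("All2All %s, %s = %s/%s", quant_dtype, block_shape,
             hidden_dim_bytes, hidden_scale_bytes)
```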
