Commit 7219559

fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: 946788f

File tree: 2 files changed (+15, -5 lines)


vllm/model_executor/layers/fused_moe/config.py

Lines changed: 8 additions & 2 deletions
@@ -90,9 +90,14 @@ def is_grouped(self) -> bool:
     def is_per_tensor(self) -> bool:
         return not self.per_act_token_quant and self.block_shape is None
 
-    def scale_shape(self, max_tokens: int, hidden_dim: int) -> Optional[tuple[int, int]]:
+    def scale_shape(
+        self,
+        max_tokens: int,
+        hidden_dim: int,
+    ) -> Optional[tuple[int, int]]:
         if self.is_quantized:
             if self.is_grouped:
+                assert self.block_shape is not None
                 _, block_k = self.block_shape
                 k_tiles = cdiv(hidden_dim, block_k)
                 return (max_tokens, k_tiles)
@@ -107,10 +112,11 @@ def batched_scale_shape(
         self,
         num_experts: int,
         max_tokens: int,
-        hidden_dim: int
+        hidden_dim: int,
     ) -> Optional[tuple[int, int, int]]:
         if self.is_quantized:
             scale_shape = self.scale_shape(max_tokens, hidden_dim)
+            assert scale_shape is not None
             return (num_experts, *scale_shape)
         else:
             return None
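
Beyond the signature reflow, the substantive change here is the two new asserts, which make the Optional handling explicit: a grouped quantization config implies a block shape, and a quantized config implies a scale shape. The asserts act both as runtime sanity checks and as type narrowing for the checker, so the tuple unpacks below them are provably safe. A minimal standalone sketch of the same pattern, using illustrative names rather than the real vLLM classes:

    from typing import Optional

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, matching the role of the cdiv helper used above.
        return -(a // -b)

    class ScaleConfig:
        # Hypothetical stand-in for the quantization config in config.py.
        def __init__(self, block_shape: Optional[tuple[int, int]] = None) -> None:
            self.block_shape = block_shape

        def scale_shape(
            self,
            max_tokens: int,
            hidden_dim: int,
        ) -> Optional[tuple[int, int]]:
            if self.block_shape is not None:
                # After the None check (or an assert), the type checker knows
                # block_shape is a concrete tuple, so the unpack is safe.
                _, block_k = self.block_shape
                return (max_tokens, cdiv(hidden_dim, block_k))
            return None

        def batched_scale_shape(
            self,
            num_experts: int,
            max_tokens: int,
            hidden_dim: int,
        ) -> Optional[tuple[int, int, int]]:
            scale_shape = self.scale_shape(max_tokens, hidden_dim)
            if scale_shape is None:
                return None
            # Same narrowing idea as the new assert in the diff: the splat
            # below is only reached once scale_shape is known non-None.
            return (num_experts, *scale_shape)

    cfg = ScaleConfig(block_shape=(128, 128))
    print(cfg.scale_shape(64, 4096))             # (64, 32)
    print(cfg.batched_scale_shape(8, 64, 4096))  # (8, 64, 32)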

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 7 additions & 3 deletions
@@ -308,13 +308,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         else:
             self.fused_experts_func = fused_experts
 
-    def select_gemm_impl(self, prepare_finalize):
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
             BatchedTritonExperts)
 
         assert not self.rocm_aiter_moe_enabled and not self.use_marlin
 
-        logger.debug("BatchedTritonExperts(%s)", self.__classname__.__name__)
+        logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
 
         use_batched_format = (prepare_finalize.activation_format ==
                               FusedMoEActivationFormat.BatchedExperts)
@@ -595,7 +599,7 @@ def select_gemm_impl(
         num_experts = (moe.num_local_experts
                        if use_batched_format else moe.num_experts)
 
-        logger.debug("CutlassExpertsFp8(%s)", self.__classname__.__name__)
+        logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
 
         experts = CutlassExpertsFp8(
             num_experts,
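
The two logging fixes matter more than they look: Python objects have no __classname__ attribute, so self.__classname__.__name__ raises AttributeError. Since logger.debug's arguments are evaluated before the call regardless of the configured log level, the old code would crash whenever this path ran, even with debug logging disabled. A quick illustration with a throwaway class:

    class Example:
        pass

    obj = Example()
    print(obj.__class__.__name__)   # "Example" -- the standard way to get a class name

    try:
        print(obj.__classname__)    # no such attribute on any Python object
    except AttributeError as e:
        print(e)                    # 'Example' object has no attribute '__classname__'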
