
Commit 2e25bb1

[Bugfix] Fix import of CutlassExpertsFp8 in compressed_tensors_moe.py (#20381)
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 9965c47 commit 2e25bb1

2 files changed: +11 −7 lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 5 additions & 3 deletions
@@ -14,9 +14,9 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
-    CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
-    FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
-    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
+    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
+    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
+    FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter

@@ -570,6 +570,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             del layer.w2_input_scale
             self.fused_experts_func = None
         else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts

     def apply(

@@ -826,6 +827,7 @@ def select_gemm_impl(
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
+        from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8

         use_batched_format = (prepare_finalize.activation_format ==
                               FusedMoEActivationFormat.BatchedExperts)
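
Both hunks apply the same fix: symbols previously imported at module scope (fused_experts, CutlassExpertsFp8) are now imported lazily, inside the function that actually needs them, so importing compressed_tensors_moe.py itself no longer depends on those names being resolvable at module-load time. A minimal sketch of this deferred-import pattern follows; the try/except fallback is an illustrative assumption and not part of the commit, which imports unconditionally inside the method:

# Sketch of the function-local (deferred) import used in this commit.
# The ImportError fallback is an assumption added for illustration only.
def get_cutlass_experts_cls():
    try:
        # Resolved only when this function runs, not when the enclosing
        # module is imported, so module import cannot fail on this symbol.
        from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
    except ImportError:
        return None
    return CutlassExpertsFp8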

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 6 additions & 4 deletions
@@ -14,10 +14,9 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
-    BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
-    FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
-    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
-    TritonOrDeepGemmExperts)
+    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
+    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
+    FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods

@@ -785,6 +784,9 @@ def select_gemm_impl(
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
+        from vllm.model_executor.layers.fused_moe import (
+            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
+
         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")
