
Commit c169b05

merge
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 468d166 commit c169b05

File tree: 4 files changed, +17 -62 lines changed


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 52 deletions
@@ -184,7 +184,6 @@ def flatten_tp_across_dp(dp_rank: int):
 # Adapted from pplx-kernels tests/all_to_all_utils.py
 @dataclass
 class MoEConfig:
-    max_num_tokens: int
     num_experts: int
     experts_per_token: int
     hidden_dim: int
@@ -347,33 +346,18 @@ def select_gemm_impl(
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
         if isinstance(prepare_finalize,
                       (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             logger.debug("BatchedTritonExperts %s", self.moe)
-            experts = BatchedTritonExperts(
+            return BatchedTritonExperts(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=all2all_manager.world_size,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
-            )
-        return experts
+            return TritonExperts()

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
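
For context, the select_gemm_impl change above reduces to a type-based dispatch on the prepare/finalize object: the batched or pplx all2all paths get a BatchedTritonExperts sized from the all2all manager, everything else gets a default TritonExperts. A minimal self-contained sketch of that pattern (the classes, the MOE_DP_CHUNK_SIZE value, and the free-standing function are illustrative stand-ins, not the real vLLM definitions):

from dataclasses import dataclass

# Hypothetical stand-ins; the names mirror the diff, the implementations do not.
class BatchedPrepareAndFinalize: ...
class PplxPrepareAndFinalize: ...
class NaivePrepareAndFinalize: ...  # the "anything else" case

@dataclass
class BatchedTritonExperts:
    max_num_tokens: int
    world_size: int
    dp_size: int

class TritonExperts: ...

MOE_DP_CHUNK_SIZE = 256  # assumed value, for illustration only

def select_gemm_impl(prepare_finalize, world_size: int, tp_size: int):
    # Batched/pplx all2all needs the batched experts kernel, sized by the DP
    # chunk and the all2all world/TP sizes; otherwise fall back to the plain
    # Triton experts with default (unquantized) settings.
    if isinstance(prepare_finalize,
                  (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
        return BatchedTritonExperts(max_num_tokens=MOE_DP_CHUNK_SIZE,
                                    world_size=world_size,
                                    dp_size=tp_size)
    return TritonExperts()

print(select_gemm_impl(PplxPrepareAndFinalize(), world_size=8, tp_size=2))
print(select_gemm_impl(NaivePrepareAndFinalize(), world_size=8, tp_size=2))
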
@@ -472,35 +456,6 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input)

-    def set_prepare_finalize(
-        self,
-        dp_size: int,
-        world_size: int,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
-    ) -> bool:
-        assert self.fused_experts == fused_experts
-
-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
-        if isinstance(prepare_finalize,
-                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
-            logger.debug("BatchedTritonExperts %s", self.moe)
-            experts = BatchedTritonExperts(
-                max_num_tokens=MOE_DP_CHUNK_SIZE,
-                world_size=world_size,
-                dp_size=dp_size,
-            )
-        else:
-            logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts()
-
-        self.fused_experts = FusedMoEModularKernel(
-            prepare_finalize,
-            experts,
-        )
-
-        return True
-
     def forward_cuda(
         self,
         layer: torch.nn.Module,
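
The deleted set_prepare_finalize above was where the prepare/finalize object and the experts implementation were combined into a FusedMoEModularKernel; after this commit only the experts selection (select_gemm_impl) remains in this class. A rough sketch of that composition pattern, using placeholder classes rather than the real vLLM types:

class PrepareAndFinalize:
    # Stand-in: the real object would quantize/dispatch before the experts
    # run and combine results afterwards.
    def prepare(self, x):
        return x

    def finalize(self, y):
        return y

class Experts:
    # Stand-in for the fused expert GEMMs.
    def apply(self, x):
        return [v * 2 for v in x]

class ModularKernel:
    """Mirrors the FusedMoEModularKernel(prepare_finalize, experts) call
    removed in this hunk; the body is a guess at the general flow."""
    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, x):
        staged = self.prepare_finalize.prepare(x)
        out = self.experts.apply(staged)
        return self.prepare_finalize.finalize(out)

kernel = ModularKernel(PrepareAndFinalize(), Experts())
print(kernel([1, 2, 3]))  # -> [2, 4, 6]
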
@@ -815,16 +770,14 @@ def __init__(
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

-        logger.debug("Model dtype = %s", vllm_config.model_config.dtype)
-
         moe = MoEConfig(
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            in_dtype=moe.in_dtype,
+            in_dtype=vllm_config.model_config.dtype,
+            max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config
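
Taken together with the first hunk, the constructor call above implies a MoEConfig roughly like the sketch below: max_num_tokens stays a field but moves to the end of the call, and in_dtype is now read from vllm_config.model_config.dtype instead of the not-yet-constructed moe object. Field names come from the diff; the types and example values are assumptions:

from dataclasses import dataclass
from typing import Any

import torch  # assumes torch is available

@dataclass
class MoEConfig:
    num_experts: int
    experts_per_token: int
    hidden_dim: int
    num_local_experts: int
    moe_parallel_config: Any   # real type not shown in this diff
    in_dtype: torch.dtype      # taken from vllm_config.model_config.dtype
    max_num_tokens: int        # MOE_DP_CHUNK_SIZE in the diff

cfg = MoEConfig(num_experts=64, experts_per_token=2, hidden_dim=4096,
                num_local_experts=8, moe_parallel_config=None,
                in_dtype=torch.bfloat16, max_num_tokens=256)
print(cfg.in_dtype, cfg.max_num_tokens)
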
@@ -1281,7 +1234,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             assert (self.batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
-            assert (self.batched_router_logits.size(0) # type: ignore
+            assert (self.batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
             staged_hidden_states = self.batched_hidden_states[:
                                                               chunk_size, :]  # type: ignore

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 0 additions & 2 deletions
@@ -108,8 +108,6 @@ def prepare(
         # There's not much point setting this unless it is != indices.size(0)
         bound_m: Optional[torch.Tensor] = None

-        #print(f"SCALE= {a1q_scale.shape}")
-
         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
             out_expert_x=expert_x,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 12 additions & 3 deletions
@@ -11,7 +11,7 @@

 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
                                                   FusedMoEMethodBase,
@@ -771,17 +771,26 @@ def process_weights_after_loading(self, layer: Module) -> None:
     def select_gemm_impl(self, prepare_finalize):
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts)
+        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+            BatchedPrepareAndFinalize,
+            BatchedTritonExperts)
+        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+            PplxPrepareAndFinalize)

         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")

+        all2all_manager = get_ep_group().device_communicator.all2all_manager
+        assert all2all_manager is not None
+
         if isinstance(prepare_finalize,
                       (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             logger.debug("BatchedTritonExperts(fp8)")
+            self.use_pplx_kernels = True
             return BatchedTritonExperts(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
-                world_size=world_size,
-                dp_size=dp_size,
+                world_size=all2all_manager.world_size,
+                dp_size=all2all_manager.tp_group.world_size,
                 qtype=torch.float8_e4m3fn,
                 block_shape=self.quant_config.weight_block_size,
                 per_act_token=False,  #?
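
On the fp8 path, world_size and dp_size are no longer passed in as parameters; they are read off the EP group's all2all manager, with the TP group size forwarded as dp_size (the layer.py hunk above notes that pplx's "dp_size" actually means the TP size). A hedged sketch of that lookup with a fake manager object; only the attribute names come from the diff:

from types import SimpleNamespace

# Fake all2all manager exposing the attributes used in the diff:
# world_size and tp_group.world_size.
all2all_manager = SimpleNamespace(
    world_size=8,
    tp_group=SimpleNamespace(world_size=2),
)

def batched_expert_sizes(manager):
    # pplx labels the TP size as "dp_size", so it is forwarded under that name.
    return {
        "world_size": manager.world_size,
        "dp_size": manager.tp_group.world_size,
    }

print(batched_expert_sizes(all2all_manager))  # {'world_size': 8, 'dp_size': 2}
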

vllm/v1/engine/core_client.py

Lines changed: 0 additions & 5 deletions
@@ -492,11 +492,6 @@ def _wait_for_engine_startup(self, output_address: str,
                 (e for e in self.core_engines if e.identity == eng_identity),
                 None)
             if engine is None:
-                msg = msgspec.msgpack.decode(ready_msg_bytes)
-                status, local = msg["status"], msg["local"]
-                logger.debug(f"XXXXXX {status} message from "
-                             f"{'local' if local else 'remote'} "
-                             f"engine {eng_index}")
                 raise RuntimeError(f"Message from engine with unexpected data "
                                    f"parallel rank: {eng_index}")
             msg = msgspec.msgpack.decode(ready_msg_bytes)
