Skip to content

[Bugfix] Fix a couple PPLX+CUTLASS MoE bugs #20825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def prepare(
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
indices=topk_ids.view(dtype=torch.uint32),
bound_m=bound_m,
)

Expand Down Expand Up @@ -249,7 +249,7 @@ def finalize(
topk_weights = torch.ones_like(topk_weights)

self.a2a.combine(out_tokens=output,
indices=topk_ids,
indices=topk_ids.view(dtype=torch.uint32),
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)
Original file line number Diff line number Diff line change
Expand Up @@ -737,10 +737,8 @@ def __init__(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")

from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
self.topk_indices_dtype = None
self.fused_experts = cutlass_moe_fp8 # type: ignore
self.fused_experts = None # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So how does this get set now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is set in init_prepare_finalize() method in layer.py:

self.fused_experts = FusedMoEModularKernel(
                prepare_finalize,
                experts,
            )

This function is called for non-EP parallel runs. If it's never called, self.fused_experts is never set and the condition in CompressedTensorsW8A8Fp8MoECutlassMethod's apply() function results in importing and calling cutlass_moe_fp8().

Before this PR, init_prepare_finalize() would overwrite an existing cutlass_moe_fp8() function and CompressedTensorsW8A8Fp8MoECutlassMethod's apply() would call whatever self.fused_experts was at the time of the call. It was convenient to do so because cutlass_moe_fp8() and FusedMoEModularKernel's experts.apply() were called with the same arguments. This changed in one of the recent PRs resulting in errors in PPLX runs, so now there's an if-else condition required to decide which arguments self.fused_experts should be called with.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should leave a comment for this tbh as it is difficult to know

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, and we should revisit this as well - we need to keep the control flow as simple as possible in the MoE layers given how complicated they are.

self.disable_expert_map = False

def create_weights(self, layer: torch.nn.Module, num_experts: int,
Expand Down Expand Up @@ -936,21 +934,40 @@ def apply(
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)

return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
if self.fused_experts is None:
# If no modular kernel is provided, use cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
else:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)


class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
Expand Down