-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
[Bugfix] Fix a couple PPLX+CUTLASS MoE bugs #20825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -737,10 +737,8 @@ def __init__( | |
"For FP8 Fused MoE layer, we require either per tensor or " | ||
"channelwise, dynamic per token quantization.") | ||
|
||
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( | ||
cutlass_moe_fp8) | ||
self.topk_indices_dtype = None | ||
self.fused_experts = cutlass_moe_fp8 # type: ignore | ||
self.fused_experts = None # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So how does this get set now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is set in
This function is called for non-EP parallel runs. If it's never called, Before this PR, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should leave a comment for this tbh as it is difficult to know There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, and we should revisit this as well - we need to keep the control flow as simple as possible in the MoE layers given how complicated they are. |
||
self.disable_expert_map = False | ||
|
||
def create_weights(self, layer: torch.nn.Module, num_experts: int, | ||
|
@@ -936,21 +934,40 @@ def apply( | |
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( | ||
a2_scale.numel() != 1 if a2_scale is not None else False) | ||
|
||
return self.fused_experts( | ||
x, | ||
layer.w13_weight, | ||
layer.w2_weight, | ||
topk_weights, | ||
topk_ids, | ||
per_act_token=per_act_token, | ||
activation=activation, | ||
global_num_experts=global_num_experts, | ||
expert_map=None if self.disable_expert_map else expert_map, | ||
w1_scale=layer.w13_weight_scale, | ||
w2_scale=layer.w2_weight_scale, | ||
a1_scale=a1_scale, | ||
a2_scale=a2_scale, | ||
) | ||
if self.fused_experts is None: | ||
# If no modular kernel is provided, use cutlass_moe_fp8 | ||
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( | ||
cutlass_moe_fp8) | ||
return cutlass_moe_fp8( | ||
x, | ||
layer.w13_weight, | ||
layer.w2_weight, | ||
topk_weights, | ||
topk_ids, | ||
per_act_token=per_act_token, | ||
activation=activation, | ||
global_num_experts=global_num_experts, | ||
expert_map=None if self.disable_expert_map else expert_map, | ||
w1_scale=layer.w13_weight_scale, | ||
w2_scale=layer.w2_weight_scale, | ||
a1_scale=a1_scale, | ||
a2_scale=a2_scale, | ||
) | ||
else: | ||
return self.fused_experts( | ||
x, | ||
layer.w13_weight, | ||
layer.w2_weight, | ||
topk_weights, | ||
topk_ids, | ||
activation=activation, | ||
global_num_experts=global_num_experts, | ||
expert_map=None if self.disable_expert_map else expert_map, | ||
w1_scale=layer.w13_weight_scale, | ||
w2_scale=layer.w2_weight_scale, | ||
a1_scale=layer.w13_input_scale, | ||
a2_scale=layer.w2_input_scale, | ||
) | ||
tlrmchlsmth marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): | ||
|
Uh oh!
There was an error while loading. Please reload this page.