-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Qwen FP8 ModelOPT support #20734
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jingyu-ml
wants to merge
27
commits into
vllm-project:main
Choose a base branch
from
jingyu-ml:jingyux/dev-qwen-fp8
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+543
−43
Draft
Qwen FP8 ModelOPT support #20734
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
6927b02
resolve conflict
Edwardf0t1 b45972e
bugfix
Edwardf0t1 bf96528
handle language_model. prefix
Edwardf0t1 cb20cd1
fix issue in fused_experts calling
Edwardf0t1 1f61802
minor
Edwardf0t1 0fb23e1
update ModelOptFp8Config, handle prefix in mllama4 weight loading, debug
Edwardf0t1 03d2b3b
debug, handle kv scales
Edwardf0t1 93d7185
fix kv scale name matching issue
Edwardf0t1 d154fe1
update, debug
Edwardf0t1 782c018
cleanup
Edwardf0t1 b78b191
fix format
Edwardf0t1 22745f3
resolve conflict
Edwardf0t1 1c5acec
debug
Edwardf0t1 b10782d
handle eplb in ModelOptFp8MoEMethod
Edwardf0t1 8265287
broadcasting BMM experts scales
Edwardf0t1 59190ea
cleanup
Edwardf0t1 7a6fc84
some refactor and cleanup
Edwardf0t1 47a47a9
refactor Llama4ForConditionalGeneration.load_weights
Edwardf0t1 eec1daf
resolve conflict
Edwardf0t1 770bc24
format and linter error fix
Edwardf0t1 770890a
simplify ModelOptFp8MoEMethod to avoid mypy error
Edwardf0t1 1beecbf
resolve conflict
Edwardf0t1 0b98a7f
format fix
Edwardf0t1 cc44385
fix mypy error
Edwardf0t1 1ece491
Merge remote-tracking branch 'vllm/main' into zhiyu/llama4-fp8-modelopt
767358e
Merge remote-tracking branch 'vllm/main' into jingyux/dev-qwen-fp8
1206f33
add qwen fp8 modelopt support
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -81,6 +81,16 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, | |||||||
params_dtype: torch.dtype, **extra_weight_attrs): | ||||||||
raise NotImplementedError | ||||||||
|
||||||||
def uses_weight_scale_2_pattern(self) -> bool: | ||||||||
""" | ||||||||
Returns True if this quantization method uses 'weight_scale_2' pattern | ||||||||
for per-tensor weight scales (e.g., FP4 variants), False otherwise. | ||||||||
This method should be overridden by subclasses that use the | ||||||||
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. | ||||||||
""" | ||||||||
return False | ||||||||
|
||||||||
def init_prepare_finalize(self, moe: FusedMoEConfig, | ||||||||
quant_config: Optional[QuantizationConfig]): | ||||||||
all2all_manager = get_ep_group().device_communicator.all2all_manager | ||||||||
|
@@ -1049,12 +1059,23 @@ def weight_loader(self, | |||||||
|
||||||||
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern | ||||||||
if "ModelOpt" in quant_method_name: | ||||||||
if ('weight_scale_2' in weight_name | ||||||||
or 'input_scale' in weight_name): | ||||||||
self._load_per_tensor_weight_scale(shard_id=shard_id, | ||||||||
param=param, | ||||||||
loaded_weight=loaded_weight, | ||||||||
expert_id=expert_id) | ||||||||
# Determine per-tensor weight scale patterns based on variant | ||||||||
# Use the dedicated method instead of brittle string matching | ||||||||
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( | ||||||||
) | ||||||||
|
||||||||
# For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale" | ||||||||
per_tensor_conditions = ( | ||||||||
"weight_scale_2" in weight_name if uses_weight_scale_2 else | ||||||||
"weight_scale" in weight_name) or "input_scale" in weight_name | ||||||||
|
||||||||
if per_tensor_conditions: | ||||||||
self._load_per_tensor_weight_scale( | ||||||||
shard_id=shard_id, | ||||||||
param=param, | ||||||||
loaded_weight=loaded_weight, | ||||||||
expert_id=expert_id, | ||||||||
) | ||||||||
elif "weight" in weight_name: | ||||||||
self._load_model_weight_or_group_weight_scale( | ||||||||
shard_id=shard_id, | ||||||||
|
@@ -1526,3 +1547,7 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, | |||||||
dispatch_key=current_platform.dispatch_key, | ||||||||
tags=(torch.Tag.needs_fixed_stride_order, ), | ||||||||
) | ||||||||
|
||||||||
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters | ||||||||
# to avoid expensive runtime reflection in model loading code | ||||||||
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using
Suggested change
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The direct access to
quant_cfg["producer"]["name"]
is unsafe and could raise aKeyError
if the keys are not present in thequant_cfg
dictionary. It's better to use.get()
for safer access. Additionally, the nestedif
statements can be combined for improved readability.