[V0.9.1] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's performance #1545
Conversation
Force-pushed 68937f8 to a73ed88
vllm_ascend/quantization/w8a8.py
Outdated
@@ -59,6 +60,7 @@ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
    params_dict = {}
    params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
    params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
    AscendW8A8LinearMethod.params_dtype = params_dtype
What if there is an fp16 fallback? Then how does that fallback linear layer do the calculation?
I don't think it's a good solution. Can we write this dtype back into params_dict and inject it into the layer eventually?
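For illustration, a minimal sketch of the suggested direction, assuming a hypothetical `params_dtype` key: the dtype travels with the per-layer params_dict instead of being stored as class-level state on AscendW8A8LinearMethod, so an fp16 fallback layer keeps its own dtype.

```python
from typing import Any, Dict

import torch


def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
    params_dict: Dict[str, Any] = {}
    params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
    params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
    # Sketch only: record the dtype per layer (key name is an assumption)
    # rather than mutating AscendW8A8LinearMethod, so it can later be
    # injected into the owning layer.
    params_dict["params_dtype"] = params_dtype
    return params_dict
```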
I have fixed it.
vllm_ascend/models/qwen3.py
Outdated
if quant_config is not None:
    from vllm_ascend.quantization.quant_config import AscendQuantConfig
    assert isinstance(quant_config, AscendQuantConfig)
    self.input_layernorm = AddRMSNormQuant(config.hidden_size,
Just discussed with @realliujiaxu: this behaviour is not a general way to apply our optimization in modeling. Can we try to leverage the compilation path in vLLM to fuse ops in the FX graph instead? cc @jianzs @Yikun @wangxiyuan
I'm fine with the changes in this PR, but I believe we also need an ultimate solution to handle this kind of problem for good.
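For reference, a minimal sketch of the custom-model pattern under discussion; the import path of AddRMSNormQuant and its constructor arguments beyond hidden_size are assumptions, not taken from this PR's full diff.

```python
import torch.nn as nn
from vllm.model_executor.layers.layernorm import RMSNorm


class CustomQwen3DecoderLayerSketch(nn.Module):
    """Illustrative only: shows the conditional layernorm swap, not the full layer."""

    def __init__(self, config, quant_config=None):
        super().__init__()
        if quant_config is not None:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            # Fused add + RMSNorm + quantize, so the following quantized
            # linear consumes int8 activations directly. Import path and
            # any extra constructor arguments are assumed here.
            from vllm_ascend.models.qwen3 import AddRMSNormQuant
            self.input_layernorm = AddRMSNormQuant(config.hidden_size,
                                                   eps=config.rms_norm_eps)
        else:
            # fp16/bf16 fallback keeps the plain RMSNorm.
            self.input_layernorm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
```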
Force-pushed 3384ed1 to 07736ef
import torch_npu

if residual is not None:
    x, _, residual = torch_npu.npu_add_rms_norm_quant(
Does torch_npu.npu_add_rms_norm_quant require a newer version of torch_npu?
The current version of PTA (torch_npu) already supports it.
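For context, a plain-PyTorch reference sketch of the computation the fused NPU op replaces: residual add, RMSNorm, then static per-tensor int8 quantization. The exact argument order and return values of torch_npu.npu_add_rms_norm_quant are not reproduced here; the function and parameter names below are illustrative.

```python
import torch


def add_rms_norm_quant_reference(x: torch.Tensor,
                                 residual: torch.Tensor,
                                 weight: torch.Tensor,
                                 input_scale: torch.Tensor,
                                 input_offset: torch.Tensor,
                                 eps: float = 1e-6):
    # Residual add feeds both the norm and the next residual connection.
    new_residual = x + residual
    # RMSNorm over the hidden dimension, computed in fp32 for stability.
    variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = new_residual.float() * torch.rsqrt(variance + eps)
    normed = (normed * weight.float()).to(x.dtype)
    # Static per-tensor int8 quantization with the layer's scale/offset.
    quantized = torch.clamp(torch.round(normed / input_scale + input_offset),
                            -128, 127).to(torch.int8)
    return quantized, new_residual
```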
Force-pushed 0137a7d to 1b898e7
…3's performance Signed-off-by: rjg-lyh <1318825571@qq.com>
Force-pushed 1b898e7 to 4def25f
What this PR does / why we need it?
Optimizes the performance of the quantized Qwen3 model by registering a custom model and adding the AddRmsNormQuant operation. Subsequent PRs will build further performance optimizations on top of this custom model.
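For illustration, a hedged sketch of how such a custom model is typically plugged in through vLLM's ModelRegistry; the module path and class name below are assumptions, not verbatim from this PR.

```python
from vllm import ModelRegistry


def register_custom_qwen3():
    # Map the architecture name from the HF config to the custom Ascend
    # implementation so vLLM instantiates it instead of the built-in one.
    # The "module:ClassName" target is hypothetical here.
    ModelRegistry.register_model(
        "Qwen3ForCausalLM",
        "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
```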
Does this PR introduce any user-facing change?
No.
How was this patch tested?
CI passed with existing tests.