
Commit 12cae04

Yikun, dingdingchaomian, Angazenn, liujiaxu, and ApsarasX authored
[quantization] Support w8a8 quantization (#580)
### What this PR does / why we need it?

Add a `VLLMAscendQuantizer` to support w8a8 static (W8A8) and dynamic (W8A8_DYNAMIC) quantization on linear and MoE layers. The quantizer is enabled if a model has a [quantize field](https://huggingface.co/vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8/blob/main/config.json#L27) in its config. If MindIE Turbo is installed, the MindIE Turbo quantizer is applied; otherwise the `VLLMAscendQuantizer` is used directly.

- This patch fixes the installation docs so that installation works
- This patch enables norm quantization by patching `RMSNorm.__init__`, `RMSNorm.forward_oot`, and `NPUModelRunnerBase.load_model`
- Add `AscendW8A8LinearMethod` for W8A8
- Add `AscendW8A8DynamicLinearMethod` and `AscendW8A8DynamicFusedMoEMethod` for W8A8_DYNAMIC
- Add an e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8`

### Does this PR introduce _any_ user-facing change?

Yes, w8a8 quantization is now supported. With this patch, users can run w8a8 models with a command such as:

```
vllm serve /root/.cache/modelscope/hub/Qwen/Qwen2.5-7B-Instruct-w8a8 --served-model-name "qwen2.5-7B"
```

### How was this patch tested?

0. CI passed: an e2e test was added for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8`.
1. From @Yikun: functional testing of Qwen2.5-0.5B-Instruct-w8a8 went well; please refer to #580 (comment).
2. From @dingdingchaomian: tested with the qwen2.5-72b-instruct and deepseek-v2-lite-chat models, both quantized using Ascend's msmodelslim tool:
   - Qwen2.5-72b-instruct was tested twice, once with w8a8 static and once with w8a8 dynamic.
   - Deepseek-v2-lite-chat was tested once because its quantization uses both static and dynamic w8a8.

   Both models were tested with offline inference and online serving, and both work well. The inference code is exactly the same as the examples in https://vllm-ascend.readthedocs.io/en/latest/quick_start.html, with only the model path and tensor parallel size changed.

---------

Signed-off-by: dingdingchaomian <wangce21@huawei.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: dingdingchaomian <wangce21@huawei.com>
Co-authored-by: Angazenn <zengyanjia@huawei.com>
Co-authored-by: liujiaxu <liujiaxu4@huawei.com>
Co-authored-by: ApsarasX <apsarax@outlook.com>
Co-authored-by: ganyi1996ppo <pleaplusone.gy@gmail.com>
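For reference, below is a minimal offline-inference sketch in the spirit of the quick start example referenced above; the model path points at the w8a8 e2e-test checkpoint, while the prompt and sampling settings are illustrative rather than taken from the PR:

```python
# Minimal offline-inference sketch for a w8a8 checkpoint (mirrors the quick
# start example; prompt and sampling settings are illustrative).
import os

# Pull the model from ModelScope, as the updated e2e test does.
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)

# Quantization is picked up from the "quantize" field in the model's config.json.
llm = LLM(model="vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```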
1 parent 1a1f9a6 commit 12cae04

File tree

7 files changed, +843 -16 lines changed


docs/source/installation.md

Lines changed: 8 additions & 1 deletion
````diff
@@ -61,6 +61,7 @@ docker run --rm \
 -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
 -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
 -v /etc/ascend_install.info:/etc/ascend_install.info \
+-v /root/.cache:/root/.cache \
 -it $IMAGE bash
 ```
 
@@ -123,7 +124,7 @@ First install system dependencies:
 
 ```bash
 apt update -y
-apt install -y gcc g++ cmake libnuma-dev
+apt install -y gcc g++ cmake libnuma-dev wget
 ```
 
 Current version depends on a unreleased `torch-npu`, you need to install manually:
@@ -144,6 +145,7 @@ cd pta
 wget https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.5.1/20250320.3/pytorch_v2.5.1_py310.tar.gz
 tar -xvf pytorch_v2.5.1_py310.tar.gz
 pip install ./torch_npu-2.5.1.dev20250320-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
+cd ..
 ```
 
 Then you can install `vllm` and `vllm-ascend` from **pre-built wheel**:
@@ -152,6 +154,8 @@ Then you can install `vllm` and `vllm-ascend` from **pre-built wheel**:
 :substitutions:
 
 # Install vllm-project/vllm from pypi
+# There was a vLLM v0.8.4 installation bug, please use "Build from source code"
+# https://github.com/vllm-project/vllm-ascend/issues/581
 pip install vllm==|pip_vllm_version|
 
 # Install vllm-project/vllm-ascend from pypi.
@@ -168,11 +172,13 @@ or build from **source code**:
 git clone --depth 1 --branch |vllm_version| https://github.com/vllm-project/vllm
 cd vllm
 VLLM_TARGET_DEVICE=empty pip install . --extra-index https://download.pytorch.org/whl/cpu/
+cd ..
 
 # Install vLLM Ascend
 git clone --depth 1 --branch |vllm_ascend_version| https://github.com/vllm-project/vllm-ascend.git
 cd vllm-ascend
 pip install -e . --extra-index https://download.pytorch.org/whl/cpu/
+cd ..
 ```
 :::
 
@@ -216,6 +222,7 @@ docker run --rm \
 -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
 -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
 -v /etc/ascend_install.info:/etc/ascend_install.info \
+-v /root/.cache:/root/.cache \
 -it $IMAGE bash
 ```
 
````

tests/singlecard/test_offline_inference.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -30,7 +30,9 @@
 
 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
+    "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8",
 ]
+os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
```
Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Tuple, Union

import torch
import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod


# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):

    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    return init


# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):

    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if not self.ignore_anti:
            if residual is not None:
                residual += x
                out = torch_npu._npu_quant_rms_norm(
                    residual,
                    self.weight,
                    self.bias,
                    self.input_scale,
                    self.input_offset,
                    self.variance_epsilon,
                )
                return out, residual
            out = torch_npu._npu_quant_rms_norm(
                x,
                self.weight,
                self.bias,
                self.input_scale,
                self.input_offset,
                self.variance_epsilon,
            )
            return out

        if residual is not None:
            x, residual = func(self, x, residual)
            return x.add_(self.bias), residual

        return func(self, x).add_(self.bias)

    return _rmsnorm_forward_oot


MODEL_LAYER_MAPPING = {
    "LlamaModel": {
        "attn": {
            "layer_attr": "self_attn",
            "proj_attr": "qkv_proj",
            "norm_attr": "input_layernorm",
            "unquantized_type": UnquantizedLinearMethod,
        },
        "mlp": {
            "layer_attr": "mlp",
            "proj_attr": "gate_up_proj",
            "norm_attr": "post_attention_layernorm",
            "unquantized_type": UnquantizedLinearMethod,
        },
    },
}


def wrapper_load_model(func):

    def postprocess_loading(self) -> None:
        func(self)

        def process_layer(layer, idx, mapping):

            def process_module(module_cfg, layer_obj):
                if module_cfg is None:
                    return

                module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
                if module_obj is None:
                    return

                proj_attr = module_cfg["proj_attr"]
                if callable(proj_attr):
                    proj = proj_attr(module_obj, idx)
                else:
                    proj = getattr(module_obj, proj_attr, None)

                norm = getattr(layer_obj, module_cfg["norm_attr"], None)

                if proj is None or norm is None:
                    return

                norm.ignore_anti = isinstance(proj.quant_method,
                                              module_cfg["unquantized_type"])
                if not norm.ignore_anti:
                    for param_name in ["input_scale", "input_offset"]:
                        if hasattr(proj, param_name):
                            param = getattr(proj, param_name)
                            norm.register_parameter(
                                param_name,
                                torch.nn.Parameter(param.clone(),
                                                   requires_grad=False))

            process_module(mapping.get("attn"), layer)
            process_module(mapping.get("mlp"), layer)

        model_type = self.model.model.__class__.__name__
        mapping = MODEL_LAYER_MAPPING.get(model_type)

        if not mapping:
            logger.info(
                f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
            )
            return

        for idx, layer in enumerate(self.model.model.layers):
            process_layer(layer, idx, mapping)

        if isinstance(self.model.model.norm, RMSNorm):
            self.model.model.norm.ignore_anti = True

    return postprocess_loading
```
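Per the description above, these wrappers are applied by patching `RMSNorm.__init__`, `RMSNorm.forward_oot`, and `NPUModelRunnerBase.load_model`. Below is a minimal sketch of that wiring, assuming the wrapper functions defined above are in scope; the actual patch site inside vllm-ascend is not shown on this page:

```python
# Sketch only (assumes wrapper_rmsnorm_init / wrapper_rmsnorm_forward_oot from
# the file above are importable); the real patch location may differ.
from vllm.model_executor.layers.layernorm import RMSNorm

RMSNorm.__init__ = wrapper_rmsnorm_init(RMSNorm.__init__)
RMSNorm.forward_oot = wrapper_rmsnorm_forward_oot(RMSNorm.forward_oot)
# NPUModelRunnerBase.load_model is wrapped the same way with wrapper_load_model,
# so postprocess_loading runs once after the weights are loaded.
```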

vllm_ascend/quantization/quant_config.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -306,23 +306,23 @@ def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
         router_logits: torch.Tensor,
+        top_k: int,
         renormalize: bool,
-        global_num_experts: int,
-        expert_map: torch.Tensor,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
-        is_prefill: bool = True,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
-        return self.quant_method.apply(layer, x, use_grouped_topk, top_k,
-                                       router_logits, renormalize, topk_group,
-                                       num_expert_group, global_num_experts,
-                                       expert_map, is_prefill,
+        return self.quant_method.apply(layer, x, router_logits, top_k,
+                                       renormalize, use_grouped_topk,
+                                       topk_group, num_expert_group,
+                                       global_num_experts, expert_map,
                                        custom_routing_function, scoring_func,
                                        e_score_correction_bias)
 
```
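The reordering above matches the wrapper's `apply` signature to the upstream argument order, and the added `**kwargs` lets it absorb arguments it does not handle. Below is a small self-contained toy of that pattern; none of the names are vllm APIs:

```python
# Toy illustration of accepting and ignoring unknown keyword arguments; all
# names here are made up for the example.
class ToyMoEMethod:

    def apply(self, layer, x, router_logits, top_k, renormalize,
              use_grouped_topk=False, **kwargs):
        # Extra arguments introduced by a newer caller (e.g. a hypothetical
        # `is_prefill`) land in kwargs instead of raising TypeError.
        return {"top_k": top_k, "renormalize": renormalize,
                "ignored": sorted(kwargs)}


print(ToyMoEMethod().apply(None, None, None, top_k=2, renormalize=True,
                           is_prefill=True))
# {'top_k': 2, 'renormalize': True, 'ignored': ['is_prefill']}
```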

0 commit comments
