
Commit 45e33e4

[0.9.1] Refactoring w4a8 and w8a8 and supporting w4a8 graph mode (#1480)
### What this PR does / why we need it?

1. Refactor `w4a8_dynamic` and `w8a8_dynamic`.
2. Support w4a8 graph mode.

### Does this PR introduce _any_ user-facing change?

#### 1. How to get weights using Modelslim

##### Installation steps

Use the master branch at commit 298e175d69b3b855111a1e09bbe2fcd12fdb4e24:

    git clone https://gitee.com/ascend/msit.git
    cd msit/msmodelslim
    bash install.sh

##### Required transformers environment

    pip install transformers==4.48.2

##### Generate w4a8 weights

    cd /example/DeepSeek

Command reference: msmodelslim/example/DeepSeek/README.md. Execute the [pre-check](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#运行前必检) and [DeepSeek-R1 w4a8 mixed quantization](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-混合量化前三层-mlpw8a8-dynamic-量化mla共享专家w8a8量化路由专家w4a8-dynamic量化) chapters.

Reference command:

    python3 quant_deepseek_w4a8.py --model_path {original weight path} --save_path {generated weight path} --mindie_format

##### Adapt to vllm-ascend

Since `--mindie_format` produces weights in MindIE format, a few adaptations are needed before vllm-ascend can use them (a scripted sketch of these edits follows this description):

- Rename `quant_model_description_w8a8_dynamic.json` to `quant_model_description.json` and add `"group_size": 256`.
- In `config.json`: change `"model_type": "deepseekv2"` to `"model_type": "deepseek_v3"` and remove `quantization_config`.

#### 2. How to run w4a8

##### a. How to run eager mode

    export VLLM_USE_V1=1  # v1

TP + EP:

    python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 --enable_expert_parallel --quantization ascend --port $3 --max-model-len $4 --max-num-seqs $5 --enforce-eager

Example:

    python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 2048 --max-num-seqs 128 --enforce-eager

DP + TP + EP:

    python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6 --enforce-eager

Example:

    python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 2 -dp 2 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 2048 --max-num-seqs 128 --enforce-eager

##### b. How to run graph mode

    export VLLM_USE_V1=1       # v1
    export HCCL_BUFFSIZE=1024

Note: when using graph mode, keep tp <= 4.

DP + TP + EP:

    python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'

Example:

    python -m vllm.entrypoints.openai.api_server --model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'

---------

Signed-off-by: pichangping <1337510399@qq.com>
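The manual "Adapt to vllm-ascend" edits above can also be scripted. Below is a minimal sketch, assuming only the file names and JSON keys listed in the description; `WEIGHT_DIR` is a hypothetical placeholder for the directory produced by `quant_deepseek_w4a8.py`, not something defined by this PR.

```python
import json
from pathlib import Path

# Hypothetical location of the weights generated with --mindie_format.
WEIGHT_DIR = Path("/path/to/generated/w4a8/weights")

# 1. Rename quant_model_description_w8a8_dynamic.json to
#    quant_model_description.json and add "group_size": 256.
src = WEIGHT_DIR / "quant_model_description_w8a8_dynamic.json"
dst = WEIGHT_DIR / "quant_model_description.json"
desc = json.loads(src.read_text())
desc["group_size"] = 256
dst.write_text(json.dumps(desc, indent=2))
src.unlink()

# 2. Patch config.json: switch model_type to deepseek_v3 and
#    drop the quantization_config section.
cfg_path = WEIGHT_DIR / "config.json"
cfg = json.loads(cfg_path.read_text())
cfg["model_type"] = "deepseek_v3"   # was "deepseekv2"
cfg.pop("quantization_config", None)
cfg_path.write_text(json.dumps(cfg, indent=2))
```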
1 parent 0c99cf7 commit 45e33e4

3 files changed: +89 additions, -255 deletions


.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 2 deletions

@@ -197,14 +197,14 @@ jobs:
       else
         pytest -sv tests/multicard/test_ilama_lora_tp2.py
         # To avoid oom, we need to run the test in a single process.
+        VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_w4a8_deepseek.py::test_deepseek_W4A8
         VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
         VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
         VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
         VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
         VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
         VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
-        VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
-        VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_w4a8_deepseek.py::test_deepseek_W4A8
+        VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
       fi

 - name: Run vllm-project/vllm-ascend test on V0 engine

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 45 additions & 234 deletions

@@ -19,225 +19,16 @@

 import numpy as np
 import torch
-import torch.distributed as dist
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import GroupCoordinator, get_ep_group
+from vllm.distributed import get_ep_group
+from vllm.forward_context import get_forward_context

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.utils import dispose_tensor
-
-
-def apply_mlp(hidden_states: torch.Tensor,
-              w1: torch.Tensor,
-              w1_scale: torch.Tensor,
-              w2: torch.Tensor,
-              w2_scale: torch.Tensor,
-              w1_scale_bias: torch.Tensor,
-              w2_scale_bias: torch.Tensor,
-              group_list: torch.Tensor,
-              dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
-    """
-    apply MLP: gate_up_proj -> swiglu -> down_proj
-
-    Args:
-        hidden_states: input hidden states with shape (num_tokens, hidden_size).
-        w1: expert weights1 with shape
-            (num_experts, hidden_size, intermediate_size * 2)
-        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
-        w2: expert weights2 with shape
-            (num_experts, intermediate_size, hidden_size)
-        w2_scale: weights2 scale with shape (num_experts, hidden_size)
-        group_list: number of tokens for each expert, follow cumsum mode, and
-            with shape (num_experts).
-        transpose_weight:
-            w1: (num_experts, intermediate_size * 2, hidden_size) ->
-                (num_experts, hidden_size, intermediate_size * 2)
-            w2: (num_experts, hidden_size, intermediate_size) ->
-                (num_experts, intermediate_size, hidden_size)
-
-    Returns:
-        hidden_states: output hidden states after MLP.
-    """
-
-    if dynamic_scale is None:
-        unquantized_hidden_states = hidden_states
-        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
-            hidden_states)
-        # Dispose the original unquantized hidden states
-        # to save npu memory because they're no longer used.
-        dispose_tensor(unquantized_hidden_states)
-    else:
-        pertoken_scale = dynamic_scale
-
-    # gmm1: gate_up_proj
-    group_list_type = 1
-    group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)])
-
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w1],
-        scale=[w1_scale],
-        bias=[w1_scale_bias],
-        per_token_scale=[pertoken_scale],
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-        output_dtype=torch.bfloat16)[0]
-
-    # act_fn: swiglu
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
-    hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
-        hidden_states)
-
-    # gmm2: down_proj
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w2],
-        scale=[w2_scale],
-        bias=[w2_scale_bias],
-        per_token_scale=[swiglu_out_scale],
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-        output_dtype=torch.bfloat16)[0]
-    return hidden_states
-
-
-# currently expert parallelism implemented with all2all
-# is under-optimized.
-def fused_experts_with_all2all(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2: torch.Tensor,
-    w2_scale: torch.Tensor,
-    w1_scale_bias: torch.Tensor,
-    w2_scale_bias: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    ep_group: GroupCoordinator = None,
-    log2phy: torch.Tensor = None,
-    global_redundant_expert_num: int = 0,
-):
-    if log2phy:
-        topk_ids = log2phy[topk_ids]
-    original_shape = hidden_states.shape
-    if len(original_shape) == 3:
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-
-    num_tokens, _ = hidden_states.shape
-    num_experts = w1.shape[0]
-    device = hidden_states.device
-
-    if expert_map is not None:
-        global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
-    else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            expanded_expert_idx, num_experts)
-        expert_tokens = expert_tokens.to(torch.int64)
-        group_list_type = 0
-
-    # `hidden_states` will be disposed in the `apply_mlp` function
-    hidden_states = apply_mlp(hidden_states,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              w1_scale_bias,
-                              w2_scale_bias,
-                              expert_tokens,
-                              group_list_type=group_list_type)
-
-    if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
-
-        final_hidden_states = torch_npu.npu_moe_finalize_routing(
-            hidden_states,
-            skip1=None,
-            skip2=None,
-            bias=None,
-            scales=topk_weights,
-            expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
-    else:
-        # TODO: Reorder device memory 2 times here, replace the current
-        # implementation here when suitable operators become available.
-        final_hidden_states = torch_npu.npu_moe_finalize_routing(
-            hidden_states,
-            skip1=None,
-            skip2=None,
-            bias=None,
-            scales=topk_weights,
-            expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
-    if len(original_shape) == 3:
-        final_hidden_states = final_hidden_states.view(original_shape)
-    return final_hidden_states
+from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
+                                                    fused_experts_with_mc2)


 class AscendW4A8DynamicLinearMethod:

@@ -483,26 +274,46 @@ def apply(

         topk_weights = topk_weights.to(x.dtype)

-        # The current implementation of deepseek moe splits hidden_states
-        # according to tp_size before they are feed into fused_moe module.
-        # Therefore, all2all is needed no matter how dp/tp is set so as to
-        # dispatch/combine tokens.
-        return fused_experts_with_all2all(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            w1_scale=layer.w13_weight_scale_second,
-            w2_scale=layer.w2_weight_scale_second,
-            w1_scale_bias=layer.w13_scale_bias,
-            w2_scale_bias=layer.w2_scale_bias,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            top_k=top_k,
-            expert_map=expert_map,
-            ep_group=self.ep_group,
-            log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
-        )
+        fused_moe_state = get_forward_context().fused_moe_state
+        if fused_moe_state == FusedMoEState.MC2:
+            return fused_experts_with_mc2(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_scale=layer.w13_weight_scale_second,
+                w2_scale=layer.w2_weight_scale_second,
+                w1_scale_bias=layer.w13_scale_bias,
+                w2_scale_bias=layer.w2_scale_bias,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
+        else:
+            # The current implementation of deepseek moe splits hidden_states
+            # according to tp_size before they are feed into fused_moe module.
+            # Therefore, all2all is needed no matter how dp/tp is set so as to
+            # dispatch/combine tokens.
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_scale=layer.w13_weight_scale_second,
+                w2_scale=layer.w2_weight_scale_second,
+                w1_scale_bias=layer.w13_scale_bias,
+                w2_scale_bias=layer.w2_scale_bias,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=self.ep_group,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+            )

     def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
         group_num, k, n = weight.shape

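For readers skimming the diff above: after this refactor, `w4a8_dynamic` no longer carries its own all2all/grouped-MLP kernels; it imports `fused_experts_with_all2all` and `fused_experts_with_mc2` from `w8a8_dynamic` and picks one per forward pass based on `get_forward_context().fused_moe_state`. The sketch below is a schematic restatement of that branch in the changed `apply()` method, using stand-in stubs; it is not the module code, and names other than `MC2` are hypothetical.

```python
from enum import Enum, auto


class FusedMoEState(Enum):
    # Stand-in for vllm_ascend.ascend_forward_context.FusedMoEState.
    MC2 = auto()
    ALL2ALL = auto()


def fused_experts_with_mc2(**kwargs):
    # Stub for the MC2 dispatch/combine path (used with torchair graph mode).
    return "mc2 path"


def fused_experts_with_all2all(**kwargs):
    # Stub for the all2all path shared with w8a8_dynamic.
    return "all2all path"


def apply_moe(fused_moe_state: FusedMoEState, **kernel_args) -> str:
    # Mirrors the new branch: take the MC2 kernel when the forward context
    # selects it, otherwise fall back to the shared all2all kernel.
    if fused_moe_state == FusedMoEState.MC2:
        return fused_experts_with_mc2(**kernel_args)
    return fused_experts_with_all2all(**kernel_args)


print(apply_moe(FusedMoEState.MC2))      # -> mc2 path
print(apply_moe(FusedMoEState.ALL2ALL))  # -> all2all path
```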