
Commit e5b5e83

Refactor scattered w8a8 dynamic quantization operations
AscendW8A8DynamicLinearMethod is integrated into CustomDeepseekV2MLP in an awkward way, scattering quantization operations all over the model scripts. Refactor to solve this problem.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 4976b48 commit e5b5e83

File tree: 4 files changed, 92 additions and 76 deletions

  .github/workflows/vllm_ascend_test.yaml
  tests/multicard/test_offline_inference_distributed.py
  vllm_ascend/models/deepseek_v2.py
  vllm_ascend/quantization/w8a8_dynamic.py

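In short: CustomDeepseekV2MLP.forward() drops its special-cased quantization branch and always runs gate_up_proj → act_fn → down_proj. The new CustomDeepseekV2SiluAndMul handles the (quantized value, per-token scale) tuple produced under W8A8 dynamic quantization, and AscendW8A8DynamicLinearMethod.apply() learns to both emit and consume such tuples, steered by a per-layer `_dynamic_quant_config`. Hedged sketches of the two dispatch points follow the respective file diffs below.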

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 0 deletions

```diff
@@ -127,6 +127,7 @@ jobs:
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
           fi
 
@@ -157,5 +158,6 @@ jobs:
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
           fi
```

tests/multicard/test_offline_inference_distributed.py

Lines changed: 17 additions & 1 deletion

```diff
@@ -23,7 +23,7 @@
 import os
 from unittest.mock import patch
 
-import vllm  # noqa: F401
+from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 
 from tests.conftest import VllmRunner
@@ -95,3 +95,19 @@ def test_models_distributed_DeepSeek_dbo():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+def test_models_distributed_DeepSeek_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
```
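For reference, the new case mirrors the workflow entries added above and can be run on its own with `VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8`; with `tensor_parallel_size=4` it expects four NPUs to be available.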

vllm_ascend/models/deepseek_v2.py

Lines changed: 48 additions & 36 deletions

```diff
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -75,6 +75,29 @@
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
+class CustomDeepseekV2SiluAndMul(SiluAndMul):
+
+    def __init__(self, *, weight_scale: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight_scale = weight_scale
+
+    def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
+                                                       torch.Tensor]]):
+        if isinstance(x, tuple):
+            assert self.weight_scale is not None
+            # For AscendW8A8DynamicLinearMethod:
+            # a dynamic scale is passed along with the quantized value.
+            quantized_x, dynamic_scale = x
+            return torch_npu.npu_dequant_swiglu_quant(
+                x=quantized_x,
+                weight_scale=self.weight_scale,
+                activation_scale=dynamic_scale,
+                activate_left=True,
+                quant_mode=1)
+        else:
+            return super().forward_oot(x)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -101,44 +124,33 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
 
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
+        quant_method = self.gate_up_proj.quant_method
+        if isinstance(quant_method, UnquantizedLinearMethod):
+            self.act_fn = CustomDeepseekV2SiluAndMul()
+        elif isinstance(quant_method, AscendW8A8DynamicLinearMethod):
+            # TODO(sdmyzlp): Currently preserved as before:
+            # 1. The only quantization supported for silu is W8A8Dynamic
+            # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
+            #
+            # Maybe one can implement a better and more general configuration
+            # scheme, e.g. by somehow passing around the tweaked `quant_config`
+            self.act_fn = CustomDeepseekV2SiluAndMul(
+                weight_scale=self.gate_up_proj.weight_scale_fp32)
+            # To be consumed by AscendW8A8DynamicLinearMethod.apply()
+            self.gate_up_proj._dynamic_quant_config = {
+                "output_dtype": torch.int32,
+                "return_scale": True,
+            }
+            self.down_proj._dynamic_quant_config = {
+                "output_dtype": torch.bfloat16,
+                "return_scale": False,
+            }
+        else:
+            raise NotImplementedError(
+                f"Quantization with [{type(quant_method)}] is NOT supported")
 
     def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
```
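The activation is where the two paths meet: a plain tensor goes through the ordinary SwiGLU, while the `(quantized, per-token scale)` tuple produced by `gate_up_proj` under W8A8 dynamic quantization goes through the fused dequant-swiglu-requant kernel. Below is a condensed, stand-alone sketch of that dispatch; the class name is a hypothetical stand-in for `CustomDeepseekV2SiluAndMul`, the unquantized branch spells out what vLLM's `SiluAndMul` computes so it runs off-NPU, and the NPU branch keeps the exact call from the diff.

```python
from typing import Optional, Tuple, Union

import torch


class TupleAwareSiluAndMul(torch.nn.Module):
    """Sketch of the dispatch in CustomDeepseekV2SiluAndMul (hypothetical stand-in)."""

    def __init__(self, weight_scale: Optional[torch.Tensor] = None):
        super().__init__()
        self.weight_scale = weight_scale  # per-channel scale of gate_up_proj (fp32)

    def forward(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
        if isinstance(x, tuple):
            # W8A8-dynamic path: gate_up_proj returned (quantized matmul output, scale).
            assert self.weight_scale is not None
            quantized_x, dynamic_scale = x
            import torch_npu  # Ascend-only; call copied from the diff above
            return torch_npu.npu_dequant_swiglu_quant(
                x=quantized_x,
                weight_scale=self.weight_scale,
                activation_scale=dynamic_scale,
                activate_left=True,
                quant_mode=1)
        # Unquantized path: plain SwiGLU, i.e. silu(gate) * up on the two halves.
        gate, up = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * up


# Unquantized usage (runs on CPU): 2 * intermediate features in, intermediate out.
act = TupleAwareSiluAndMul()
print(act(torch.randn(4, 2 * 8)).shape)  # torch.Size([4, 8])
```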

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 25 additions & 39 deletions

```diff
@@ -15,13 +15,13 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch_npu
 import torchair as tng  # type: ignore
-from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
+from vllm.distributed import GroupCoordinator
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -77,19 +77,9 @@ def apply_mlp(hidden_states: torch.Tensor,
     shared_experts = kwargs.get('shared_experts', None)
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
-        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=shared_gate_up,
-                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
-                activation_scale=shared_dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
+            tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states)
+            shared_act = shared_experts.act_fn(shared_gate_up)
 
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
@@ -122,16 +112,9 @@ def apply_mlp(hidden_states: torch.Tensor,
 
     if shared_experts:
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, hidden_states)
-            shared_output = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.down_proj.weight,
-                shared_experts.down_proj.weight_scale,
-                pertoken_scale=shared_dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
-                shared_output = tensor_model_parallel_all_reduce(shared_output)
+            tng.scope.npu_wait_tensor(shared_act[0], hidden_states)
+            shared_output, _ = shared_experts.down_proj(shared_act)
+
     if shared_experts:
         return hidden_states, shared_output
     return hidden_states
@@ -189,17 +172,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         shared_hidden_states = kwargs.get('shared_hidden_states', None)
         with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+            shared_gate_up, _ = shared_experts.gate_up_proj(
                 shared_hidden_states)
-            shared_gate_up = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.gate_up_proj.weight,
-                shared_experts.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
         kwargs.update({
             "shared_gate_up": shared_gate_up,
-            "shared_dynamic_scale": shared_dynamic_scale,
         })
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -532,21 +508,31 @@ def get_perchannel_param(
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        # use ATB quantize
-        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        return torch_npu.npu_quant_matmul(
-            quant_out,
+        config = getattr(layer, "_dynamic_quant_config", {})
+        if not isinstance(x, tuple):
+            output_dtype = config.get("output_dtype", x.dtype)
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            assert "output_dtype" in config.keys(), (
+                f"DynamicLinearMethod needs explicitly specified `output_dtype`"
+                f"for pre-quantized input, got config [{config}]")
+            output_dtype = config["output_dtype"]
+            quantized_x, dynamic_scale = x
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
             layer.weight,
             layer.weight_scale,
             pertoken_scale=dynamic_scale,
             bias=bias,
-            output_dtype=original_dtype,
+            output_dtype=output_dtype,
         )
+        return ((output, dynamic_scale)
+                if config.get("return_scale", False) else output)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
```
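On the MoE side, the shared-expert branches now call `shared_experts.gate_up_proj(...)`, `shared_experts.act_fn(...)` and `shared_experts.down_proj(...)` instead of the raw kernels, and the stream syncs wait on `shared_gate_up[0]` / `shared_act[0]` because those layers now return `(quantized tensor, scale)` tuples. The other half of the contract is the reworked `apply()`: it quantizes plain inputs on the fly, passes pre-quantized tuples straight through to `npu_quant_matmul`, and uses `_dynamic_quant_config` to pick the output dtype and decide whether to forward the per-token scale. The following is a CPU-runnable sketch of that dispatch only, with the NPU kernels replaced by plain-torch stand-ins (`dynamic_quant_ref`, float accumulation, mock layer) for illustration; it is not the repository's implementation.

```python
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union

import torch


def dynamic_quant_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for torch_npu.npu_dynamic_quant: symmetric per-token int8 quantization.
    scale = x.abs().amax(dim=-1).clamp(min=1e-8) / 127.0
    return torch.round(x / scale.unsqueeze(-1)).to(torch.int8), scale


def apply_ref(layer,
              x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
              bias: Optional[torch.Tensor] = None):
    """CPU reference for the dispatch in AscendW8A8DynamicLinearMethod.apply():
    plain tensors are quantized on the fly, (quantized, scale) tuples are used
    as-is, and `_dynamic_quant_config` picks the output dtype and whether the
    per-token scale is returned alongside the result."""
    config: Dict = getattr(layer, "_dynamic_quant_config", {})
    if not isinstance(x, tuple):
        output_dtype = config.get("output_dtype", x.dtype)
        quantized_x, dynamic_scale = dynamic_quant_ref(x)
    else:
        output_dtype = config["output_dtype"]  # mandatory for pre-quantized input
        quantized_x, dynamic_scale = x

    # int8 x int8 matmul, accumulated here in fp32 (the NPU kernel accumulates in int32).
    acc = quantized_x.to(torch.float32) @ layer.weight.to(torch.float32)
    if output_dtype.is_floating_point:
        # Dequantize: per-token activation scale times per-channel weight scale.
        output = (acc * dynamic_scale.unsqueeze(-1) * layer.weight_scale).to(output_dtype)
        if bias is not None:
            output = output + bias
    else:
        # Keep the raw accumulators quantized; downstream fused kernels
        # (e.g. the dequant-swiglu path) apply the scales later.
        output = acc.to(output_dtype)
    return (output, dynamic_scale) if config.get("return_scale", False) else output


# Usage mirroring the gate_up_proj configuration from CustomDeepseekV2MLP.__init__:
layer = SimpleNamespace(
    weight=torch.randint(-127, 128, (16, 8), dtype=torch.int8),  # [in, out], int8 mock
    weight_scale=torch.rand(8) * 0.01,                           # per output channel
    _dynamic_quant_config={"output_dtype": torch.int32, "return_scale": True},
)
out, scale = apply_ref(layer, torch.randn(4, 16))  # quantized output + per-token scale
```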
