
Commit 542bd18

Refactor scattered w8a8 dynamic quantization operations
AscendW8A8DynamicLinearMethod is integrated into CustomDeepseekV2MLP in a very awkward way, leaving quantization operations scattered all over the model scripts. Refactor to solve this problem.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 3640c60 commit 542bd18

File tree (3 files changed, +88 −75 lines):

tests/multicard/test_offline_inference_distributed.py
vllm_ascend/models/deepseek_v2.py
vllm_ascend/quantization/w8a8_dynamic.py


tests/multicard/test_offline_inference_distributed.py

Lines changed: 15 additions & 0 deletions
@@ -61,6 +61,21 @@ def test_models_distributed_DeepSeek():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+def test_models_distributed_DeepSeek_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V2-Lite-W8A8",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"})
 def test_models_distributed_topk() -> None:
     example_prompts = [

vllm_ascend/models/deepseek_v2.py

Lines changed: 48 additions & 36 deletions
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -74,6 +74,29 @@
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
+class CustomDeepseekV2SiluAndMul(SiluAndMul):
+
+    def __init__(self, *, weight_scale: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight_scale = weight_scale
+
+    def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
+                                                       torch.Tensor]]):
+        if isinstance(x, tuple):
+            assert self.weight_scale is not None
+            # For AscendW8A8DynamicLinearMethod:
+            # a dynamic scale is passed along with the quantized value.
+            quantized_x, dynamic_scale = x
+            return torch_npu.npu_dequant_swiglu_quant(
+                x=quantized_x,
+                weight_scale=self.weight_scale,
+                activation_scale=dynamic_scale,
+                activate_left=True,
+                quant_mode=1)
+        else:
+            return super().forward_oot(x)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -100,44 +123,33 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
 
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
+        quant_method = self.gate_up_proj.quant_method
+        if isinstance(quant_method, UnquantizedLinearMethod):
+            self.act_fn = CustomDeepseekV2SiluAndMul()
+        elif isinstance(quant_method, AscendW8A8DynamicLinearMethod):
+            # TODO(sdmyzlp): Currently preserved as before:
+            # 1. The only quantization supported for silu is W8A8Dynamic
+            # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
+            #
+            # Maybe one can implement a better and more general configuration
+            # scheme, e.g. by somehow passing around the tweaked `quant_config`
+            self.act_fn = CustomDeepseekV2SiluAndMul(
+                weight_scale=self.gate_up_proj.weight_scale_fp32)
+            # To be consumed by AscendW8A8DynamicLinearMethod.apply()
+            self.gate_up_proj._dynamic_quant_config = {
+                "output_dtype": torch.int32,
+                "return_scale": True,
+            }
+            self.down_proj._dynamic_quant_config = {
+                "output_dtype": torch.bfloat16,
+                "return_scale": False,
+            }
+        else:
+            raise NotImplementedError(
+                f"Quantization with [{type(quant_method)}] is NOT supported")
 
     def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
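
For orientation, here is a minimal sketch of the hand-off this file now sets up: a quantized gate_up_proj returns a (quantized, dynamic_scale) tuple, the activation consumes it and emits a fresh tuple, and down_proj consumes that, so CustomDeepseekV2MLP.forward stays the same three calls for quantized and unquantized models. The Fake* classes below are invented stand-ins (strings instead of tensors, no torch_npu), not the real vllm-ascend layers.

# Illustrative stand-ins only: the point is the (quantized, dynamic_scale)
# tuple handed from gate_up_proj to act_fn to down_proj, as wired up above.

class FakeW8A8DynamicLinear:
    """Mimics a linear layer whose quant method honours _dynamic_quant_config."""

    def __init__(self, output_dtype, return_scale):
        self._dynamic_quant_config = {
            "output_dtype": output_dtype,
            "return_scale": return_scale,
        }

    def __call__(self, x):
        cfg = self._dynamic_quant_config
        if isinstance(x, tuple):  # pre-quantized by the activation
            quantized, scale = x
        else:                     # dynamically quantize the plain input
            quantized, scale = f"quant({x})", "scale0"
        out = f"matmul({quantized}, dtype={cfg['output_dtype']})"
        out = (out, scale) if cfg["return_scale"] else out
        return out, None          # (output, bias), like vLLM linear layers


class FakeSiluAndMul:
    """Mimics CustomDeepseekV2SiluAndMul: a tuple input takes the quantized path."""

    def __call__(self, x):
        if isinstance(x, tuple):
            quantized, _scale = x
            return f"dequant_swiglu_quant({quantized})", "scale1"
        return f"silu_and_mul({x})"


gate_up_proj = FakeW8A8DynamicLinear("int32", return_scale=True)
down_proj = FakeW8A8DynamicLinear("bfloat16", return_scale=False)
act_fn = FakeSiluAndMul()

# CustomDeepseekV2MLP.forward is now the same three calls for quantized and
# unquantized models alike:
gate_up, _ = gate_up_proj("hidden_states")
x = act_fn(gate_up)
output, _ = down_proj(x)
print(output)

Running it prints matmul(dequant_swiglu_quant(matmul(quant(hidden_states), dtype=int32)), dtype=bfloat16), the same quantize, int32 matmul, fused swiglu, bfloat16 matmul chain that the deleted forward() spelled out inline.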

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 25 additions & 39 deletions
@@ -15,13 +15,13 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch_npu
 import torchair as tng  # type: ignore
-from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
+from vllm.distributed import GroupCoordinator
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -77,19 +77,9 @@ def apply_mlp(hidden_states: torch.Tensor,
     shared_experts = kwargs.get('shared_experts', None)
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
-        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=shared_gate_up,
-                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
-                activation_scale=shared_dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
+            tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states)
+            shared_act = shared_experts.act_fn(shared_gate_up)
 
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
@@ -122,16 +112,9 @@ def apply_mlp(hidden_states: torch.Tensor,
 
     if shared_experts:
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, hidden_states)
-            shared_output = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.down_proj.weight,
-                shared_experts.down_proj.weight_scale,
-                pertoken_scale=shared_dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
-                shared_output = tensor_model_parallel_all_reduce(shared_output)
+            tng.scope.npu_wait_tensor(shared_act[0], hidden_states)
+            shared_output, _ = shared_experts.down_proj(shared_act)
+
     if shared_experts:
         return hidden_states, shared_output
     return hidden_states
@@ -189,17 +172,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         shared_hidden_states = kwargs.get('shared_hidden_states', None)
         with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+            shared_gate_up, _ = shared_experts.gate_up_proj(
                 shared_hidden_states)
-            shared_gate_up = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.gate_up_proj.weight,
-                shared_experts.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
         kwargs.update({
             "shared_gate_up": shared_gate_up,
-            "shared_dynamic_scale": shared_dynamic_scale,
         })
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -532,21 +508,31 @@ def get_perchannel_param(
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        # use ATB quantize
-        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        return torch_npu.npu_quant_matmul(
-            quant_out,
+        config = getattr(layer, "_dynamic_quant_config", {})
+        if not isinstance(x, tuple):
+            output_dtype = config.get("output_dtype", x.dtype)
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            assert "output_dtype" in config.keys(), (
+                f"DynamicLinearMethod needs explicitly specified `output_dtype`"
+                f"for pre-quantized input, got config [{config}]")
+            output_dtype = config["output_dtype"]
+            quantized_x, dynamic_scale = x
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
             layer.weight,
             layer.weight_scale,
             pertoken_scale=dynamic_scale,
             bias=bias,
-            output_dtype=original_dtype,
+            output_dtype=output_dtype,
        )
+        return ((output, dynamic_scale)
+                if config.get("return_scale", False) else output)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
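
To spell out the _dynamic_quant_config contract that apply() honours above: a plain tensor is quantized on the spot and the output dtype defaults to the input's dtype, a pre-quantized (value, scale) tuple requires an explicit output_dtype, and return_scale decides whether the per-token scale is handed to the next consumer. The sketch below is a hedged, torch-free mirror of that branching; fake_dynamic_quant, fake_quant_matmul and apply_like are invented placeholders standing in for torch_npu.npu_dynamic_quant, torch_npu.npu_quant_matmul and the real apply().

def fake_dynamic_quant(x):
    # Placeholder for torch_npu.npu_dynamic_quant: per-token int8 value + scale.
    return f"int8({x})", f"scale({x})"


def fake_quant_matmul(quantized_x, pertoken_scale, output_dtype):
    # Placeholder for torch_npu.npu_quant_matmul.
    return f"matmul({quantized_x}, scale={pertoken_scale}, dtype={output_dtype})"


def apply_like(config, x, activation_dtype="float16"):
    # Mirrors the branching of apply(), with layer._dynamic_quant_config
    # passed in directly as `config`.
    if not isinstance(x, tuple):
        # Plain input: quantize here; output dtype defaults to the input dtype.
        output_dtype = config.get("output_dtype", activation_dtype)
        quantized_x, dynamic_scale = fake_dynamic_quant(x)
    else:
        # Pre-quantized (value, scale) input: the config must say what to produce.
        assert "output_dtype" in config, "pre-quantized input needs output_dtype"
        output_dtype = config["output_dtype"]
        quantized_x, dynamic_scale = x
    output = fake_quant_matmul(quantized_x, dynamic_scale, output_dtype)
    return (output, dynamic_scale) if config.get("return_scale", False) else output


# gate_up_proj: plain input in, int32 output plus the per-token scale out ...
print(apply_like({"output_dtype": "int32", "return_scale": True}, "hidden_states"))
# ... down_proj: pre-quantized input in, bfloat16 output out, no scale returned.
print(apply_like({"output_dtype": "bfloat16", "return_scale": False},
                 ("int8(act)", "act_scale")))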
