@@ -15,13 +15,13 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch_npu
 import torchair as tng  # type: ignore
-from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
+from vllm.distributed import GroupCoordinator
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -77,19 +77,9 @@ def apply_mlp(hidden_states: torch.Tensor,
     shared_experts = kwargs.get('shared_experts', None)
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
-        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=shared_gate_up,
-                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
-                activation_scale=shared_dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
+            tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states)
+            shared_act = shared_experts.act_fn(shared_gate_up)
 
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
@@ -122,16 +112,9 @@ def apply_mlp(hidden_states: torch.Tensor,
 
     if shared_experts:
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, hidden_states)
-            shared_output = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.down_proj.weight,
-                shared_experts.down_proj.weight_scale,
-                pertoken_scale=shared_dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-        if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
-            shared_output = tensor_model_parallel_all_reduce(shared_output)
+            tng.scope.npu_wait_tensor(shared_act[0], hidden_states)
+            shared_output, _ = shared_experts.down_proj(shared_act)
+
     if shared_experts:
         return hidden_states, shared_output
     return hidden_states
@@ -189,17 +172,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         shared_hidden_states = kwargs.get('shared_hidden_states', None)
         with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+            shared_gate_up, _ = shared_experts.gate_up_proj(
                 shared_hidden_states)
-            shared_gate_up = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.gate_up_proj.weight,
-                shared_experts.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
         kwargs.update({
             "shared_gate_up": shared_gate_up,
-            "shared_dynamic_scale": shared_dynamic_scale,
         })
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -532,21 +508,31 @@ def get_perchannel_param(
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        # use ATB quantize
-        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        return torch_npu.npu_quant_matmul(
-            quant_out,
+        config = getattr(layer, "_dynamic_quant_config", {})
+        if not isinstance(x, tuple):
+            output_dtype = config.get("output_dtype", x.dtype)
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            assert "output_dtype" in config.keys(), (
+                f"DynamicLinearMethod needs explicitly specified `output_dtype` "
+                f"for pre-quantized input, got config [{config}]")
+            output_dtype = config["output_dtype"]
+            quantized_x, dynamic_scale = x
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
             layer.weight,
             layer.weight_scale,
             pertoken_scale=dynamic_scale,
             bias=bias,
-            output_dtype=original_dtype,
+            output_dtype=output_dtype,
         )
+        return ((output, dynamic_scale)
+                if config.get("return_scale", False) else output)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
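The new `apply` accepts either a plain activation tensor (quantized on the fly, as before) or an already-quantized `(quantized activations, per-token scale)` tuple, steered by a per-layer `_dynamic_quant_config` dict with `output_dtype` and `return_scale` keys. Below is a CPU-only sketch of that dispatch, with plain-PyTorch stand-ins for `npu_dynamic_quant`/`npu_quant_matmul`; the stand-in functions and the `torch.nn.Linear` layer are illustrative, not part of the PR.

```python
from typing import Dict, Tuple, Union

import torch


def fake_dynamic_quant(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Per-token symmetric int8 quantization; a CPU stand-in for
    # torch_npu.npu_dynamic_quant (illustrative only).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.round(x / scale).to(torch.int8), scale


def dynamic_linear_apply(
    layer: torch.nn.Linear,
    x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    config: Dict,
):
    # Mirrors the tensor-vs-tuple dispatch added to DynamicLinearMethod.apply.
    if not isinstance(x, tuple):
        output_dtype = config.get("output_dtype", x.dtype)
        quantized_x, dynamic_scale = fake_dynamic_quant(x)
    else:
        # A pre-quantized tuple carries no dtype hint, so the caller must set it.
        output_dtype = config["output_dtype"]
        quantized_x, dynamic_scale = x

    # Stand-in for npu_quant_matmul: the real op uses quantized weights plus a
    # per-channel weight scale; here the weight stays floating point for brevity.
    out = (quantized_x.float() @ layer.weight.float().t()) * dynamic_scale
    out = out.to(output_dtype)
    return (out, dynamic_scale) if config.get("return_scale", False) else out


# A plain tensor is quantized on the fly (the pre-existing behaviour)...
layer = torch.nn.Linear(16, 8, bias=False)
y = dynamic_linear_apply(layer, torch.randn(4, 16), config={})

# ...while a (quantized, scale) tuple skips quantization and can also
# return the scale for the next consumer.
q, s = fake_dynamic_quant(torch.randn(4, 16))
y2, s2 = dynamic_linear_apply(
    layer, (q, s), config={"output_dtype": torch.float32, "return_scale": True})
```

Returning the per-token scale alongside the output appears to be what lets the shared expert's `gate_up_proj` hand `act_fn` and `down_proj` a pre-quantized tuple in the hunks above.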