Commit 8e70c20

Revert "Support multistream of shared experts in FusedMoE (#997)"
This reverts commit 7bdc606.

File tree: 11 files changed, +308 −296 lines

.github/workflows/vllm_ascend_test.yaml
Lines changed: 0 additions & 2 deletions

@@ -188,7 +188,6 @@ jobs:
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
-VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
 fi

@@ -219,6 +218,5 @@ jobs:
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
-VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
 fi

docs/source/user_guide/additional_config.md
Lines changed: 2 additions & 2 deletions

@@ -39,11 +39,11 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
-| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_multistream_shared_expert`| bool | `False` | Whether to enable multistream shared expert |

 **ascend_scheduler_config**

@@ -64,7 +64,7 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_moe": false
+        "enable_multistream_shared_expert": false
     },
     "ascend_scheduler_config": {
         "enabled": true,

mypy.ini
Lines changed: 0 additions & 3 deletions

@@ -6,9 +6,6 @@ warn_unused_configs = True
 [mypy-torch_npu.*]
 ignore_missing_imports = True

-[mypy-torchair.*]
-ignore_missing_imports = True
-
 [mypy-transformers.*]
 ignore_missing_imports = True

tests/multicard/test_offline_inference_distributed.py
Lines changed: 1 addition & 18 deletions

@@ -23,7 +23,7 @@
 import os
 from unittest.mock import patch

-from modelscope import snapshot_download  # type: ignore
+import vllm  # noqa: F401
 from vllm import SamplingParams

 from tests.conftest import VllmRunner

@@ -95,20 +95,3 @@ def test_models_distributed_DeepSeek_dbo():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
-
-
-def test_models_distributed_DeepSeek_W8A8():
-    example_prompts = [
-        "Hello, my name is",
-    ]
-    max_tokens = 5
-
-    with VllmRunner(
-            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
-            max_model_len=8192,
-            enforce_eager=True,
-            dtype="auto",
-            tensor_parallel_size=4,
-            quantization="ascend",
-    ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)

tests/singlecard/test_ascend_config.py
Lines changed: 2 additions & 2 deletions

@@ -58,7 +58,7 @@ def test_run_with_ascend_config():
             "use_cached_graph": True,
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
-            "enable_multistream_moe": True,
+            "enable_multistream_shared_expert": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,

@@ -79,7 +79,7 @@ def test_run_with_ascend_config():
            1, 2, 4, 8
        ]
        assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
-       assert ascend_config.torchair_graph_config.enable_multistream_moe
+       assert ascend_config.torchair_graph_config.enable_multistream_shared_expert
        assert ascend_config.ascend_scheduler_config.enabled
        assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
        assert ascend_config.expert_tensor_parallel_size == 1

vllm_ascend/ascend_config.py
Lines changed: 2 additions & 2 deletions

@@ -54,8 +54,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
-        self.enable_multistream_moe = torchair_graph_config.get(
-            "enable_multistream_moe", False)
+        self.enable_multistream_shared_expert = torchair_graph_config.get(
+            "enable_multistream_shared_expert", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
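
Downstream code reads the renamed flag from the parsed config object, as the test above asserts. A minimal sketch of that access pattern follows; it assumes the engine has already initialized the Ascend config, and the local variable name is illustrative:

from vllm_ascend.ascend_config import get_ascend_config

# get_ascend_config() returns the config parsed from additional_config.
ascend_config = get_ascend_config()
use_multistream_shared_expert = (
    ascend_config.torchair_graph_config.enable_multistream_shared_expert)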

vllm_ascend/models/deepseek_dbo.py
Lines changed: 106 additions & 4 deletions

@@ -29,7 +29,7 @@
 import torch
 import torch.distributed as dist
-import torch_npu  # noqa: F401
+import torch_npu
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig

@@ -40,10 +40,13 @@
     get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear)
+                                               RowParallelLinear,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope

@@ -64,7 +67,6 @@

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,

@@ -76,17 +78,117 @@
     make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor

 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


-class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
+class CustomDeepseekDBOMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config,
+                                           reduce_results=reduce_results,
+                                           prefix=f"{prefix}.down_proj")
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
+        self.is_dynamic_quant = not isinstance(
+            self.gate_up_proj.quant_method,
+            UnquantizedLinearMethod) and isinstance(
+                self.gate_up_proj.quant_method.quant_method,
+                AscendW8A8DynamicLinearMethod)
+
+    def forward(self, x):
+        if self.is_dynamic_quant:
+            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.gate_up_proj.weight,
+                self.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=x,
+                weight_scale=self.gate_up_proj.weight_scale_fp32,
+                activation_scale=dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                x = tensor_model_parallel_all_reduce(x)
+            return x
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x

     def _forward_ms_mlp(self, x):
         current_ms_metadata = get_multistream_comm_context()
         assert current_ms_metadata is not None
+        if self.is_dynamic_quant:
+            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.gate_up_proj.weight,
+                self.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=x,
+                weight_scale=self.gate_up_proj.weight_scale_fp32,
+                activation_scale=dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    x = tensor_model_parallel_all_reduce(x)
+                    current_ms_metadata.after_comm_event.record()
+            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         current_ms_metadata.before_comm_event.record()
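
The `_forward_ms_mlp` path restored above overlaps the tensor-parallel all-reduce with other work by recording an event on the compute stream and waiting on it from a separate communication stream. The following is a minimal, illustrative sketch of that event/stream pattern only; it uses the generic torch.cuda stream API and a plain `dist.all_reduce` stand-in rather than the torch_npu/multistream-metadata machinery, and all names are hypothetical:

import torch
import torch.distributed as dist

def overlapped_all_reduce(x: torch.Tensor,
                          comm_stream: torch.cuda.Stream,
                          before_evt: torch.cuda.Event,
                          after_evt: torch.cuda.Event) -> torch.Tensor:
    # Mark the point on the current (compute) stream where x is ready.
    before_evt.record()
    with torch.cuda.stream(comm_stream):
        # The communication stream waits only for x, not for later compute work,
        # so the all-reduce can run concurrently with subsequent kernels.
        before_evt.wait()
        dist.all_reduce(x)
        # Consumers on other streams wait on after_evt before reading x.
        after_evt.record()
    return x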
