
Commit 5ffdb6b

sdmyzlp authored and wangxiaoxin (A) committed
Support multistream of shared experts in FusedMoE (#997)
Contains #1111 for completeness.

### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts, where the computation of the shared experts is overlapped with expert token dispatch and combine. In addition, when multi-stream is enabled, the weights of the shared experts are forced to be replicated across all cards, regardless of any tensor parallelism configuration, to avoid AllReduce operations.

The expected overlapping is:

```
| shared gate_up | shared act |                            | shared down |
| dispatch                    | routed gate_up, act, down  | combine     |
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested on a 1x16 910 node, with a tailored 2-layer DSKv2.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent e96a559 commit 5ffdb6b
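
To illustrate the overlap described in the commit message, the sketch below runs the shared-expert MLP on a side NPU stream while token dispatch, routed experts, and combine stay on the default stream. This is a minimal, hypothetical example rather than the PR's actual implementation: `shared_experts`, `routed_experts`, `dispatch`, and `combine` are placeholder callables, and it assumes `torch_npu` exposes the stream API mirroring `torch.cuda`.

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend (assumption: NPU environment)

shared_stream = torch.npu.Stream()


def fused_moe_with_shared_overlap(hidden_states, shared_experts,
                                  routed_experts, dispatch, combine):
    # Make the side stream wait until hidden_states is produced on the
    # default stream before the shared experts start reading it.
    shared_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(shared_stream):
        # Shared gate_up -> act -> down, overlapped with dispatch/combine.
        shared_out = shared_experts(hidden_states)

    # Default stream: dispatch tokens, run routed experts, combine results.
    routed_out = combine(routed_experts(dispatch(hidden_states)))

    # Join the streams before mixing the two outputs.
    torch.npu.current_stream().wait_stream(shared_stream)
    return routed_out + shared_out
```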

File tree

11 files changed: +296 -308 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 0 deletions
@@ -186,6 +186,7 @@ jobs:
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+      VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
     fi
 
@@ -216,5 +217,6 @@ jobs:
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+      VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
       VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
     fi

docs/source/user_guide/additional_config.md

Lines changed: 2 additions & 2 deletions
@@ -40,11 +40,11 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
+| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
-| `enable_multistream_shared_expert`| bool | `False` | Whether to enable multistream shared expert |
 
 **ascend_scheduler_config**
 
@@ -65,7 +65,7 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_shared_expert": false
+        "enable_multistream_moe": false
     },
     "ascend_scheduler_config": {
         "enabled": true,

mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ warn_unused_configs = True
 [mypy-torch_npu.*]
 ignore_missing_imports = True
 
+[mypy-torchair.*]
+ignore_missing_imports = True
+
 [mypy-transformers.*]
 ignore_missing_imports = True
 

tests/multicard/test_offline_inference_distributed.py

Lines changed: 18 additions & 1 deletion
@@ -23,7 +23,7 @@
 import os
 from unittest.mock import patch
 
-import vllm  # noqa: F401
+from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 
 from tests.conftest import VllmRunner
@@ -95,3 +95,20 @@ def test_models_distributed_DeepSeek_dbo():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+def test_models_distributed_DeepSeek_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            max_model_len=8192,
+            enforce_eager=True,
+            dtype="auto",
+            tensor_parallel_size=4,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)

tests/singlecard/test_ascend_config.py

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ def test_run_with_ascend_config():
             "use_cached_graph": True,
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
-            "enable_multistream_shared_expert": True,
+            "enable_multistream_moe": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -79,7 +79,7 @@ def test_run_with_ascend_config():
             1, 2, 4, 8
         ]
         assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
-        assert ascend_config.torchair_graph_config.enable_multistream_shared_expert
+        assert ascend_config.torchair_graph_config.enable_multistream_moe
         assert ascend_config.ascend_scheduler_config.enabled
         assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
         assert ascend_config.expert_tensor_parallel_size == 1

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
-        self.enable_multistream_shared_expert = torchair_graph_config.get(
-            "enable_multistream_shared_expert", False)
+        self.enable_multistream_moe = torchair_graph_config.get(
+            "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
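
A hypothetical sketch of how a layer might consume the parsed flag; only `get_ascend_config()` and the `enable_multistream_moe` attribute come from this commit, while the surrounding class is purely illustrative:

```python
from vllm_ascend.ascend_config import get_ascend_config


class IllustrativeMoELayer:  # placeholder class, not part of the PR

    def __init__(self):
        ascend_config = get_ascend_config()
        self.enable_multistream_moe = \
            ascend_config.torchair_graph_config.enable_multistream_moe

    def forward(self, hidden_states):
        if self.enable_multistream_moe:
            # Overlap shared experts with dispatch/combine on a side stream.
            ...
        else:
            # Run shared and routed experts sequentially on the default stream.
            ...
```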

vllm_ascend/models/deepseek_dbo.py

Lines changed: 4 additions & 106 deletions
@@ -29,7 +29,7 @@
 
 import torch
 import torch.distributed as dist
-import torch_npu
+import torch_npu  # noqa: F401
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -40,13 +40,10 @@
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,6 +64,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -78,117 +76,17 @@
     make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
-class CustomDeepseekDBOMLP(nn.Module):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
-
-    def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
 
     def _forward_ms_mlp(self, x):
         current_ms_metadata = get_multistream_comm_context()
         assert current_ms_metadata is not None
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                current_ms_metadata.before_comm_event.record()
-                with torch.npu.stream(current_ms_metadata.comm_stream):
-                    current_ms_metadata.before_comm_event.wait()
-                    x = tensor_model_parallel_all_reduce(x)
-                    current_ms_metadata.after_comm_event.record()
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         current_ms_metadata.before_comm_event.record()
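
For context on the `_forward_ms_mlp` code retained above, the sketch below shows the event-synchronized all-reduce pattern it relies on, where communication is handed to a separate stream so compute can continue on the default one. `CommMetadata` and the function name are illustrative placeholders, not the repository's multistream metadata classes, and the `torch.npu` stream/event APIs are assumed from their use elsewhere in this diff.

```python
import torch
import torch_npu  # noqa: F401  # provides torch.npu streams/events (assumption: NPU environment)


class CommMetadata:  # placeholder for the PR's multistream metadata object
    def __init__(self):
        self.comm_stream = torch.npu.Stream()
        self.before_comm_event = torch.npu.Event()
        self.after_comm_event = torch.npu.Event()


def overlapped_all_reduce(x, meta, all_reduce):
    # Record the point where x becomes ready on the default stream, then let
    # the communication stream wait on it and perform the all-reduce there.
    meta.before_comm_event.record()
    with torch.npu.stream(meta.comm_stream):
        meta.before_comm_event.wait()
        x = all_reduce(x)
        meta.after_comm_event.record()
    return x
```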
