
Commit ee82a17

Introduce and use CustomDeepseekV2MergedReplicatedLinear
This is the replicated counterpart of MergedColumnParallelLinear, aimed at removing the TP communication of DeepSeek-V2's `gate_up_proj` linear. With the weight replicated, the chunked input hidden_states can also be consumed directly by the shared experts.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 3e74365 commit ee82a17
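
For context, here is a minimal standalone sketch (not part of this commit) of the narrow-and-copy loading that the new `CustomDeepseekV2MergedReplicatedLinear.weight_loader` below performs: each shard of the fused gate/up weight is copied whole into its slice of a replicated parameter, so no rank ever holds a TP-partitioned view. Shapes and helper names here are illustrative only.

```python
import torch

# Illustrative shapes only: two shards ("gate" and "up") fused along dim 0.
output_sizes = [4, 4]                              # e.g. [intermediate_size] * 2
fused_weight = torch.empty(sum(output_sizes), 8)   # replicated on every rank


def load_shard(fused: torch.Tensor, loaded: torch.Tensor, shard_id: int,
               output_sizes: list[int], output_dim: int = 0) -> None:
    # Offset and size of this shard inside the fused output dimension.
    offset = sum(output_sizes[:shard_id])
    size = output_sizes[shard_id]
    shard = fused.narrow(output_dim, offset, size)
    assert shard.size() == loaded.size()
    shard.copy_(loaded)                            # full shard copied, no TP split


load_shard(fused_weight, torch.randn(4, 8), shard_id=0, output_sizes=output_sizes)  # gate_proj
load_shard(fused_weight, torch.randn(4, 8), shard_id=1, output_sizes=output_sizes)  # up_proj
```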

File tree

  vllm_ascend/models/deepseek_v2.py
  vllm_ascend/ops/fused_moe.py
  vllm_ascend/quantization/w8a8_dynamic.py

3 files changed: +99 −58 lines

vllm_ascend/models/deepseek_v2.py

Lines changed: 74 additions & 32 deletions
@@ -98,6 +98,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
         return super().forward_oot(x)


+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # With no support for GGUF format yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()}"
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):

     def __init__(
@@ -107,20 +142,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -181,6 +229,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")

+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -216,6 +270,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -230,12 +285,6 @@ def __init__(

         self.params_dtype = torch.get_default_dtype()

-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -274,27 +323,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )

         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])

         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -309,10 +353,8 @@ def forward(
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)

         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)

         return hidden_states.view(num_tokens, hidden_size)

vllm_ascend/ops/fused_moe.py

Lines changed: 6 additions & 7 deletions
@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py

 import os
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional

 import torch
 import torch.distributed as dist
@@ -1099,8 +1099,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None,
-                **kwargs):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[Any] = None):
         assert self.quant_method is not None

         if top_k:
@@ -1147,14 +1147,13 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            **kwargs)
+            shared_experts=shared_experts,
+        )

-        shared_experts = kwargs.get("shared_experts", None)
-        shared_experts_input = kwargs.get("shared_experts_input", None)
         if shared_experts is not None:
             # Provide dummy implementation of "non-separated" shared experts.
             if not isinstance(e_hidden_states, tuple):
-                return e_hidden_states, shared_experts(shared_experts_input)
+                return e_hidden_states, shared_experts(hidden_states)
             else:
                 return e_hidden_states
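
Usage note: since `forward` now takes `shared_experts` explicitly instead of picking it out of `**kwargs`, the "non-separated" path simply pairs the routed output with the shared experts applied to the same (chunked) `hidden_states`. A toy sketch of that return contract, with placeholder names and under the assumption that the quant method returns a plain tensor when it has not already fused the shared experts:

```python
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def moe_return_contract(
    e_hidden_states: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    hidden_states: torch.Tensor,
    shared_experts: Optional[nn.Module] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Placeholder sketch of the branch above: if a shared-experts module was
    # passed in and only the routed tensor came back, run the shared experts
    # here on the same hidden_states and return a (routed, shared) pair.
    # The caller (CustomDeepseekV2MoE.forward) then combines them as
    #   routed * routed_scaling_factor + shared
    if shared_experts is not None and not isinstance(e_hidden_states, tuple):
        return e_hidden_states, shared_experts(hidden_states)
    return e_hidden_states
```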

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 19 additions & 19 deletions
@@ -105,19 +105,21 @@ def apply_mlp(hidden_states: torch.Tensor,
     return hidden_states


-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           w1_scale: torch.Tensor,
-                           w2_scale: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: str = "",
-                           log2phy: torch.Tensor = None,
-                           global_redundant_expert_num: int = 0,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: str = "",
+    log2phy: torch.Tensor = None,
+    global_redundant_expert_num: int = 0,
+    shared_experts: Optional[Any] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

     topk_ids = log2phy[topk_ids]
     global_bs = 0
@@ -161,13 +163,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

-    shared_experts = kwargs.get("shared_experts", None)
-    shared_experts_input = kwargs.get("shared_experts_input", None)
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_experts_input, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(
-                shared_experts_input)
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
             npu_wait_tensor(shared_gate_up[0], expand_x)
             shared_act = shared_experts.act_fn(shared_gate_up)

@@ -615,6 +614,7 @@ def apply(
         enable_force_load_balance: bool = True,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -671,7 +671,7 @@ def apply(
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
