Commit 9d27fcb

Introduce and use CustomDeepseekV2MergedReplicatedLinear
This is the replicated counterpart of MergedColumnParallelLinear, aimed at removing the TP communication of DeepSeek-V2's `gate_up_proj` linear. With the weight replicated on every rank, the chunked input hidden_states can also be consumed directly by the shared experts.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 922eb59 commit 9d27fcb

3 files changed, +99 -58 lines
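
To illustrate the intent of the change, here is a minimal single-process sketch in plain PyTorch (hypothetical sizes, plain `nn.Linear` instead of the vLLM linear classes): once the shared-expert weights are replicated on every rank, each rank can run the shared-expert MLP on just its own chunk of tokens, with no tensor-parallel collective in the way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes; "rank 0" is simulated, no real TP group is created.
hidden_size, intermediate_size, tp_size = 64, 128, 4
tokens = torch.randn(8, hidden_size)

# Each rank sees only its chunk of the tokens (as after the MC2 dispatch).
chunks = torch.tensor_split(tokens, tp_size, dim=0)
local_tokens = chunks[0]

# Replicated weights: every rank holds the *full* merged gate/up and down
# projections, so no cross-rank collective is needed for the shared experts.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

gate, up = gate_up_proj(local_tokens).chunk(2, dim=-1)
local_out = down_proj(F.silu(gate) * up)   # purely local compute
print(local_out.shape)                     # torch.Size([2, 64])
```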

vllm_ascend/models/deepseek_v2.py

Lines changed: 74 additions & 32 deletions
```diff
@@ -101,6 +101,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
         return super().forward_oot(x)


+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # With no support for GGUF format yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()}"
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):

     def __init__(
@@ -110,20 +145,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -189,6 +237,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")

+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -224,6 +278,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -238,12 +293,6 @@ def __init__(

         self.params_dtype = torch.get_default_dtype()

-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -282,27 +331,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )

         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])

         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -317,10 +361,8 @@ def forward(
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)

         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)

         return hidden_states.view(num_tokens, hidden_size)
```
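
To make the new weight_loader easier to follow, this is a standalone toy version of its shard-placement arithmetic, using a bare tensor instead of a vLLM Parameter and assuming output_dim == 0; the sizes and helper name are made up for illustration.

```python
import torch

# Toy stand-in for CustomDeepseekV2MergedReplicatedLinear.weight_loader:
# checkpoint shards for gate_proj (id 0) and up_proj (id 1) are copied into
# disjoint row ranges of one full, replicated weight tensor.
output_sizes = [4, 4]          # e.g. [intermediate_size] * 2, made-up sizes
input_size = 8
output_dim = 0                 # rows are the output features in this toy
weight = torch.empty(sum(output_sizes), input_size)

def load_shard(loaded_weight: torch.Tensor, loaded_shard_id: int) -> None:
    shard_offset = sum(output_sizes[:loaded_shard_id])
    shard_size = output_sizes[loaded_shard_id]
    shard = weight.narrow(output_dim, shard_offset, shard_size)
    assert shard.size() == loaded_weight.size()
    shard.copy_(loaded_weight)

load_shard(torch.randn(4, 8), loaded_shard_id=0)   # gate_proj weights
load_shard(torch.randn(4, 8), loaded_shard_id=1)   # up_proj weights
```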

vllm_ascend/ops/fused_moe.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py

 import os
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional

 import torch
 import torch.distributed as dist
@@ -1099,8 +1099,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None,
-                **kwargs):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[Any] = None):
         assert self.quant_method is not None

         if top_k:
@@ -1147,14 +1147,13 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            **kwargs)
+            shared_experts=shared_experts,
+        )

-        shared_experts = kwargs.get("shared_experts", None)
-        shared_experts_input = kwargs.get("shared_experts_input", None)
         if shared_experts is not None:
             # Provide dummy implementation of "non-separated" shared experts.
             if not isinstance(e_hidden_states, tuple):
-                return e_hidden_states, shared_experts(shared_experts_input)
+                return e_hidden_states, shared_experts(hidden_states)
             else:
                 return e_hidden_states
```
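
The net effect on callers of AscendFusedMoE.forward is that the loosely typed **kwargs plumbing is replaced by a single explicit, optional shared_experts argument, and the layer returns a (routed, shared) tuple when one is passed. A hedged stand-in of that calling convention follows; the function and modules here are toys, not the real vLLM objects.

```python
from typing import Any, Optional, Tuple, Union
import torch

def experts_forward(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_prefill: bool,
    top_k: Optional[int] = None,
    shared_experts: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Stand-in for the routed-expert computation.
    routed_out = torch.zeros_like(hidden_states)
    if shared_experts is not None:
        # "Non-separated" path: also return the shared-expert output, computed
        # on the same hidden_states the routed experts consume.
        return routed_out, shared_experts(hidden_states)
    return routed_out

shared_mlp = torch.nn.Linear(16, 16, bias=False)   # toy shared experts
h = torch.randn(4, 16)
logits = torch.randn(4, 8)

out = experts_forward(h, logits, is_prefill=False, top_k=2,
                      shared_experts=shared_mlp)
routed, shared = out                                      # tuple in this case
out = experts_forward(h, logits, is_prefill=False, top_k=2)   # plain tensor
```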

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 19 additions & 19 deletions
```diff
@@ -105,19 +105,21 @@ def apply_mlp(hidden_states: torch.Tensor,
     return hidden_states


-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           w1_scale: torch.Tensor,
-                           w2_scale: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: str = "",
-                           log2phy: torch.Tensor = None,
-                           global_redundant_expert_num: int = 0,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: str = "",
+    log2phy: torch.Tensor = None,
+    global_redundant_expert_num: int = 0,
+    shared_experts: Optional[Any] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
     global_bs = 0
@@ -161,13 +163,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

-    shared_experts = kwargs.get("shared_experts", None)
-    shared_experts_input = kwargs.get("shared_experts_input", None)
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_experts_input, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(
-                shared_experts_input)
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
             npu_wait_tensor(shared_gate_up[0], expand_x)
             shared_act = shared_experts.act_fn(shared_gate_up)

@@ -618,6 +617,7 @@ def apply(
         enable_force_load_balance: bool = True,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -674,7 +674,7 @@ def apply(
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
```
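
For readers unfamiliar with the multistream path in fused_experts_with_mc2: the shared-expert projections are issued on a secondary stream so they can overlap with the routed experts' MC2 dispatch, with npu_wait_tensor expressing the cross-stream data dependencies. Those helpers are Ascend-specific, so the sketch below reproduces only the overlap pattern, using torch.cuda streams as an analogy and toy modules in place of the real experts.

```python
import torch
import torch.nn as nn

def overlapped_moe_step(hidden_states: torch.Tensor,
                        routed_experts: nn.Module,
                        shared_experts: nn.Module):
    # Analogy only: the commit uses vllm_ascend's npu_stream_switch /
    # npu_wait_tensor on NPU; here a CUDA side stream plays the same role.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Shared experts read the same (chunked) hidden_states as the routed
        # experts, which is what the replicated weights make possible.
        shared_out = shared_experts(hidden_states)
    routed_out = routed_experts(hidden_states)      # main-stream work
    torch.cuda.current_stream().wait_stream(side_stream)
    return routed_out, shared_out

if torch.cuda.is_available():
    h = torch.randn(4, 16, device="cuda")
    routed, shared = overlapped_moe_step(
        h, nn.Linear(16, 16).cuda(), nn.Linear(16, 16).cuda())
```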
