Commit b01cc94

use additional_config to enable cv parallel
Signed-off-by: David9857 <985700846@qq.com>
1 parent d118d63 commit b01cc94
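
For reference, a minimal usage sketch (not part of this commit): after this change, CV parallel is switched on through the engine's additional_config dict rather than the removed VLLM_ENABLE_CV_PARALLEL environment variable. The model name below is a placeholder, and the sketch assumes vLLM's LLM entry point accepts additional_config and forwards it into VllmConfig.additional_config.

    from vllm import LLM

    # Placeholder model; any checkpoint served through vllm-ascend would do here.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",
        # Read in CustomDeepseekV2MoE.__init__ and AscendFusedMoE.__init__ via
        # additional_config.get("enable_cv_parallel", False); see the diffs below.
        additional_config={"enable_cv_parallel": True},
    )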

File tree: 3 files changed, +14 -7 lines changed

vllm_ascend/envs.py

Lines changed: 0 additions & 2 deletions
@@ -36,8 +36,6 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
-    "VLLM_ENABLE_CV_PARALLEL":
-    lambda: bool(int(os.getenv("VLLM_ENABLE_CV_PARALLEL", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 2 deletions
@@ -71,7 +71,6 @@
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-VLLM_ENABLE_CV_PARALLEL: bool = envs_ascend.VLLM_ENABLE_CV_PARALLEL


 class CustomDeepseekV2MLP(nn.Module):
@@ -179,6 +178,12 @@ def __init__(
         else:
             self.gate.e_score_correction_bias = None

+        self.enable_cv_parallel = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_cv_parallel = additional_config.get(
+                "enable_cv_parallel", False)
+
         self.experts = AscendFusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -226,7 +231,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape

-        cv_parallel = VLLM_ENABLE_CV_PARALLEL and not is_prefill
+        cv_parallel = self.enable_cv_parallel and not is_prefill

         if self.n_shared_experts is not None:
             if not cv_parallel:
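
The DeepSeek block above reads the flag through vLLM's current-config helper. For context, a self-contained sketch of that lookup; the import comes from vLLM itself and is not shown in this diff:

    from vllm.config import get_current_vllm_config

    # Any module constructed inside vLLM's model-loading context can read the
    # same dict the hunks above consult; None means no additional_config was set.
    additional_config = get_current_vllm_config().additional_config
    enable_cv_parallel = bool(
        additional_config.get("enable_cv_parallel", False)) if additional_config else False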

vllm_ascend/ops/fused_moe.py

Lines changed: 7 additions & 3 deletions
@@ -36,7 +36,6 @@

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
-VLLM_ENABLE_CV_PARALLEL: bool = envs_ascend.VLLM_ENABLE_CV_PARALLEL


 def fused_experts_with_mc2(
@@ -811,6 +810,11 @@ def __init__(

         self.quant_method.create_weights(layer=self, **moe_quant_params)

+        self.enable_cv_parallel = False
+        if vllm_config.additional_config:
+            self.enable_cv_parallel = vllm_config.additional_config.get(
+                "enable_cv_parallel", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
@@ -847,7 +851,7 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             **kwargs)

-        if VLLM_ENABLE_CV_PARALLEL and not is_prefill:
+        if self.enable_cv_parallel and not is_prefill:
             final_hidden_states, shared_output = final_hidden_states

         if VLLM_ENABLE_MC2 and not is_prefill:
@@ -857,6 +861,6 @@ def forward(self,
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

-        if VLLM_ENABLE_CV_PARALLEL and not is_prefill:
+        if self.enable_cv_parallel and not is_prefill:
             return final_hidden_states, shared_output
         return final_hidden_states
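
One consequence of the last two hunks, shown as a toy sketch (not repo code): when enable_cv_parallel is set and the step is decode, AscendFusedMoE.forward returns a (routed, shared) pair instead of a single tensor, so callers such as CustomDeepseekV2MoE.forward have to unpack it under the same condition. The function and variable names below are made up for illustration.

    import torch

    def fake_moe_forward(x: torch.Tensor, enable_cv_parallel: bool, is_prefill: bool):
        # Stand-ins for the routed-expert and shared-expert outputs.
        routed, shared = x * 2.0, x * 0.5
        if enable_cv_parallel and not is_prefill:
            return routed, shared
        return routed

    out = fake_moe_forward(torch.ones(4), enable_cv_parallel=True, is_prefill=False)
    routed, shared = out  # the CV-parallel decode path must be unpacked by the caller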
