
Commit c73953c

Author: weijinqian
Commit message: handle clean code

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>

1 parent 0719b71 · commit c73953c

File tree

6 files changed (+37, -47 lines)


vllm_ascend/ascend_forward_context.py

Lines changed: 22 additions & 35 deletions
@@ -28,11 +28,8 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool):
         return FusedMoEState.AllGather
     elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
         # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
-        return (
-            FusedMoEState.All2AllSeq
-            if (ep_size < 16 or with_prefill)
-            else FusedMoEState.MC2
-        )
+        return (FusedMoEState.All2AllSeq if
+                (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
     elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
         return FusedMoEState.MC2_PREFILL
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
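Note: the reformatting in this hunk is behaviour-preserving. For readers skimming the diff, the selection logic reads roughly as the sketch below. This is only an illustration, not the module's code: the condition guarding the AllGather branch sits above this hunk, so it is stood in for by a hypothetical use_all_gather flag, and enable_chunk_mc2 is passed explicitly instead of being resolved elsewhere.

from enum import Enum


class FusedMoEState(Enum):
    AllGather = 0
    All2AllSeq = 1
    MC2 = 2
    MC2_PREFILL = 3


def sketch_get_fused_moe_state(ep_size: int, with_prefill: bool,
                               use_all_gather: bool, all2all_seq: bool,
                               enable_chunk_mc2: bool) -> FusedMoEState:
    if use_all_gather:  # hypothetical stand-in for the branch above this hunk
        return FusedMoEState.AllGather
    elif all2all_seq:  # VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ
        # MC2 dispatch/combine beats alltoall_seq in decode, so small EP
        # groups or prefill keep All2AllSeq, otherwise fall through to MC2.
        return (FusedMoEState.All2AllSeq if
                (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
    elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
        return FusedMoEState.MC2_PREFILL
    # mc2 needs ep_size >= 16 and all2all can't be used in the torchair graph.
    ...  # remaining branches are outside the hunk shown above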
@@ -58,19 +55,16 @@ def set_ascend_forward_context(
     We add some additional param into forward_context.
     """
     with set_forward_context(
-        attn_metadata,
-        vllm_config,
-        virtual_engine=virtual_engine,
-        num_tokens=num_tokens,
-        num_tokens_across_dp=num_tokens_across_dp,
+            attn_metadata,
+            vllm_config,
+            virtual_engine=virtual_engine,
+            num_tokens=num_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
     ):
         forward_context = get_forward_context()
         forward_context.with_prefill = with_prefill
-        ep_size = (
-            torch.distributed.get_world_size()
-            if vllm_config.parallel_config.enable_expert_parallel
-            else 1
-        )
+        ep_size = (torch.distributed.get_world_size() if
+                   vllm_config.parallel_config.enable_expert_parallel else 1)
 
         fused_moe_state = get_fused_moe_state(ep_size, with_prefill)
 
@@ -88,18 +82,16 @@ def set_ascend_forward_context(
             num_tokens = attn_metadata.num_actual_tokens
         else:
             # for v0 engine
-            num_tokens = (
-                attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
-            )
+            num_tokens = (attn_metadata.num_prefill_tokens +
+                          attn_metadata.num_decode_tokens)
 
         if num_actual_tokens is None:
             num_actual_tokens = num_tokens
 
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
             max_tokens_across_dp = (
-                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
-            )
+                forward_context.dp_metadata.max_tokens_across_dp_cpu.item())
         else:
             max_tokens_across_dp = num_tokens
 
@@ -110,31 +102,26 @@ def set_ascend_forward_context(
         world_size = torch.distributed.get_world_size()
         # NOTE: token num which need to pad to when mc2
         forward_context.padded_num_tokens = (
-            math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
-        )
+            math.ceil(max_tokens_across_dp / tp_world_size) *
+            tp_world_size)
         # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs.
         forward_context.global_bs = (
-            math.ceil(max_tokens_across_dp / tp_world_size) * world_size
-        )
+            math.ceil(max_tokens_across_dp / tp_world_size) * world_size)
 
         if fused_moe_state == FusedMoEState.MC2_PREFILL:
             chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
             forward_context.max_num_chunks = math.ceil(
-                math.ceil(max_tokens_across_dp / tp_world_size) / chunk_size
-            )
+                math.ceil(max_tokens_across_dp / tp_world_size) /
+                chunk_size)
 
-            forward_context.global_bs = (
-                math.ceil(
-                    math.ceil(max_tokens_across_dp / tp_world_size)
-                    / forward_context.max_num_chunks
-                )
-                * world_size
-            )
+            forward_context.global_bs = (math.ceil(
+                math.ceil(max_tokens_across_dp / tp_world_size) /
+                forward_context.max_num_chunks) * world_size)
 
             min_num_tokens = forward_context.max_num_chunks * tp_world_size
             forward_context.padded_num_tokens = (
-                math.ceil(max_tokens_across_dp / min_num_tokens) * min_num_tokens
-            )
+                math.ceil(max_tokens_across_dp / min_num_tokens) *
+                min_num_tokens)
 
         mc2_mask = torch.zeros(
             forward_context.padded_num_tokens,
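
Note: the padding arithmetic in the hunk above is easier to verify with concrete numbers. The values below (max_tokens_across_dp, tp_world_size, world_size, chunk_size) are made-up examples, not defaults:

import math

max_tokens_across_dp = 100  # assumed example value
tp_world_size = 8           # assumed example value
world_size = 32             # assumed example value
chunk_size = 4              # assumed VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE value

per_tp = math.ceil(max_tokens_across_dp / tp_world_size)   # 13
padded_num_tokens = per_tp * tp_world_size                 # 104
global_bs = per_tp * world_size                            # 416, deliberately >= the actual global batch

# MC2_PREFILL path: the per-TP token budget is further split into chunks.
max_num_chunks = math.ceil(per_tp / chunk_size)                       # 4
global_bs_chunked = math.ceil(per_tp / max_num_chunks) * world_size   # 128
min_num_tokens = max_num_chunks * tp_world_size                       # 32
padded_chunked = math.ceil(max_tokens_across_dp / min_num_tokens) * min_num_tokens  # 128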

vllm_ascend/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -58,6 +58,6 @@ def register_model():
     ModelRegistry.register_model(
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
-
+
     ModelRegistry.register_model(
         "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")

vllm_ascend/models/deepseek_dbo.py

Lines changed: 8 additions & 4 deletions
@@ -147,7 +147,8 @@ def __init__(
             intermediate_size=intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            reduce_results=not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ,  # shared experts tp comm is separated in alltoallv for better overlap.
+            reduce_results=not envs_ascend.
+            VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ,  # shared experts tp comm is separated in alltoallv for better overlap.
             prefix=f"{prefix}.shared_experts",
         )
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
@@ -232,7 +233,9 @@ def _forward_op_gating(
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      self.tp_size,
                                                      dim=0)
-            chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states]
+            chunked_hidden_states_sizes = [
+                x.shape[0] for x in chunk_hidden_states
+            ]
             local_hidden_states = chunk_hidden_states[self.tp_rank]
         else:
             local_hidden_states = hidden_states
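
Note: a minimal sketch of what the tensor_split chunking above produces; shapes and tp values are made-up examples:

import torch

tp_size, tp_rank = 4, 1              # assumed example values
hidden_states = torch.randn(10, 16)  # 10 tokens, hidden size 16 (illustrative)

chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states]
local_hidden_states = chunk_hidden_states[tp_rank]

print(chunked_hidden_states_sizes)  # [3, 3, 2, 2] -- uneven splits are allowed
print(local_hidden_states.shape)    # torch.Size([3, 16])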
@@ -245,7 +248,7 @@ def _forward_op_gating(
         if self.config.n_routed_experts == 256:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
-                k=self.config.num_experts_per_tok,
+                k=self.config.num_experts_per_tok,
                 bias=self.gate.e_score_correction_bias,
                 k_group=self.config.topk_group,  # fix: 4
                 group_count=self.config.n_group,  # fix 8
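
Note: for context, group-limited top-k gating of the flavour used above can be sketched in plain PyTorch as below. This is not the torch_npu.npu_moe_gating_top_k kernel (its scaling and normalisation options are omitted); it only illustrates the roles of k, bias, k_group and group_count.

import torch


def grouped_topk_sketch(router_logits: torch.Tensor,
                        e_score_correction_bias: torch.Tensor,
                        k: int, k_group: int, group_count: int):
    """Rough illustration of group-limited top-k routing (not the NPU kernel)."""
    scores = torch.sigmoid(router_logits)          # (num_tokens, n_experts)
    biased = scores + e_score_correction_bias      # bias only steers selection
    num_tokens, n_experts = biased.shape
    grouped = biased.view(num_tokens, group_count, n_experts // group_count)

    # Rank expert groups by the sum of their two best biased scores,
    # keep the best k_group groups per token.
    group_scores = grouped.topk(2, dim=-1).values.sum(-1)
    top_groups = group_scores.topk(k_group, dim=-1).indices

    # Mask experts outside the selected groups, then take the final top-k.
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(
        num_tokens, n_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(k, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)      # weights from unbiased scores
    return topk_weights, topk_ids


# Example with made-up sizes: 4 tokens, 256 experts in 8 groups, route to 8.
w, ids = grouped_topk_sketch(torch.randn(4, 256), torch.zeros(256),
                             k=8, k_group=4, group_count=8)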
@@ -273,7 +276,8 @@ def _forward_op_gating(
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, self.config.n_routed_experts)
+            topk_ids = torch.randint_like(topk_ids, 0,
+                                          self.config.n_routed_experts)
 
         return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes
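
Note: the force-load-balance branch above simply randomises the routed expert ids so that a profile run spreads tokens roughly evenly over all experts. A tiny illustration with made-up sizes:

import torch

n_routed_experts = 8                                # assumed example value
topk_ids = torch.tensor([[0, 1], [0, 1], [0, 1]])   # heavily skewed routing

# Same shape and dtype, but uniformly random expert ids, so no single
# expert (and hence no single rank) accumulates most of the tokens.
topk_ids = torch.randint_like(topk_ids, 0, n_routed_experts)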

vllm_ascend/models/qwen3_moe.py

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
             "gate_proj",
             "up_proj",
         ],
-        "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
     qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock

vllm_ascend/multistream/ms_split.py

Lines changed: 2 additions & 2 deletions
@@ -294,8 +294,8 @@ def model_input_split_v1_attn(
                                           token_index)
 
     is_only_prefill_pre = is_only_prefill_post = attn_metadata.is_only_prefill
-    has_prefill_pre, _ = torch.any(
-        query_lens_pre > 1).item(), torch.any(query_lens_post > 1).item()
+    has_prefill_pre, _ = torch.any(query_lens_pre > 1).item(), torch.any(
+        query_lens_post > 1).item()
 
     if not attn_metadata.is_only_prefill:
         is_only_prefill_post = torch.all(query_lens_post > 1).item()
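
Note: the split metadata above derives its prefill flags from the per-request query lengths: a request with query_len > 1 is still prefilling, while a decode request contributes exactly one token. A small example with made-up lengths:

import torch

query_lens_pre = torch.tensor([5, 1, 1])   # first micro-batch: one prefill, two decodes
query_lens_post = torch.tensor([7, 3])     # second micro-batch: prefill only

has_prefill_pre = torch.any(query_lens_pre > 1).item()        # True
is_only_prefill_post = torch.all(query_lens_post > 1).item()  # True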

vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 2 additions & 4 deletions
@@ -232,7 +232,6 @@ def preprocess(self,
 
         ep_size = self.ep_size
 
-
         # Dropless
         self.num_out_tokens = indices.numel()
         if self.ep_size > 1 or self.num_local_experts > 1:
@@ -408,7 +407,6 @@ def preprocess_and_permtute1(self,
             shared_output = shared_experts(shared_experts_input)
             self.cached_shared_expert_output = shared_output
 
-
         hidden_states, self.reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
             tokens=hidden_states,
             indices=self.top_indices,
@@ -542,8 +540,8 @@ def alltoall_token_unpermutation2(permutated_local_input_tokens):
 
             output = torch_npu.npu_moe_token_unpermute(
                 permuted_tokens=permutated_local_input_tokens,
-                sorted_indices=self.
-                reversed_local_input_permutation_mapping.to(torch.int32),
+                sorted_indices=self.reversed_local_input_permutation_mapping.
+                to(torch.int32),
                 probs=self.probs,
                 restore_shape=self.hidden_shape_before_permute)
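
Note: the change above only moves where the int32 cast is written; the permute/unpermute pair itself restores the original token order after expert computation. A rough pure-PyTorch sketch of that round trip (not the torch_npu kernels, and without the probability-weighted merge they support):

import torch

tokens = torch.randn(4, 8)            # 4 tokens, hidden size 8 (illustrative)
indices = torch.tensor([2, 0, 1, 0])  # top-1 expert id per token

# "Permute": group tokens by expert id, remembering how to undo it.
sort_order = torch.argsort(indices)
permuted_tokens = tokens[sort_order]
reversed_mapping = torch.argsort(sort_order)  # analogous to the cached mapping above

# "Unpermute": restore the original order. The kernel takes int32 indices;
# plain tensor indexing wants int64, hence the extra .long() here.
restored = permuted_tokens[reversed_mapping.to(torch.int32).long()]
assert torch.equal(restored, tokens)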
