
Commit 3b99491

NNUCJ and zzzzwwjj authored
add chunk mc2 for prefill (#1703)
### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

```shell
export HCCL_IF_IP=xxxxxx
export GLOO_SOCKET_IFNAME=enp48s3u1u1
export TP_SOCKET_IFNAME=enp48s3u1u1
export HCCL_SOCKET_IFNAME=enp48s3u1u1
# export HCCL_BUFFSIZE=2048
export VLLM_USE_V1=1
export VLLM_ASCEND_ENABLE_CHUNK_MC2=1
export VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE=256
export HCCL_BUFFSIZE=2048
# export HCCL_BUFFSIZE=1024
export ASCEND_LAUNCH_BLOCKING=0
export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
rm -rf ./.torchair_cache/
rm -rf ./dynamo_*
model_path="model_path"
python data_parallel.py \
  --model=${model_path} \
  --dp-size=2 \
  --tp-size=8 \
  --enforce-eager \
  --trust-remote-code \
  --node-size=1 \
  --node-rank=0 \
```

---------

Signed-off-by: NNUCJ <616151263@qq.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
1 parent 2351977 commit 3b99491

File tree

6 files changed: +251 −88 lines changed


vllm_ascend/ascend_forward_context.py

Lines changed: 56 additions & 12 deletions

```diff
@@ -1,23 +1,31 @@
+import math
 from contextlib import contextmanager
 from enum import Enum
 from typing import Any, Optional
 
 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_tp_group
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.platforms import current_platform
+
+import vllm_ascend.envs as envs
 
 
 class FusedMoEState(Enum):
     AllGather = 0
     All2All = 1
     MC2 = 2
+    MC2_PREFILL = 3
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
+    enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
     if ep_size == 1:
         return FusedMoEState.AllGather
+    elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
+        return FusedMoEState.MC2_PREFILL
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
@@ -33,7 +41,8 @@ def set_ascend_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
-        in_profile_run: bool = False):
+        in_profile_run: bool = False,
+        num_actual_tokens: Optional[int] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -45,7 +54,6 @@ def set_ascend_forward_context(
             num_tokens_across_dp=num_tokens_across_dp):
         forward_context = get_forward_context()
         forward_context.with_prefill = with_prefill
-
         ep_size = torch.distributed.get_world_size(
         ) if vllm_config.parallel_config.enable_expert_parallel else 1
 
@@ -59,19 +67,55 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
 
+        if num_tokens is None and attn_metadata is not None:
+            if hasattr(attn_metadata, 'num_actual_tokens'):
+                # for v1 engine
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                # for v0 engine
+                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
+
+        if num_actual_tokens is None:
+            num_actual_tokens = num_tokens
+
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
+            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
             )
-        elif num_tokens is not None:
-            forward_context.max_tokens_across_dp = num_tokens
-        elif attn_metadata is not None:
-            if hasattr(attn_metadata, 'num_actual_tokens'):
-                forward_context.max_tokens_across_dp = attn_metadata.num_actual_tokens
-            else:
-                forward_context.max_tokens_across_dp = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
         else:
-            forward_context.max_tokens_across_dp = None
+            max_tokens_across_dp = num_tokens
+
+        forward_context.max_tokens_across_dp = max_tokens_across_dp
+
+        if num_tokens is not None:
+            tp_world_size = get_tp_group().world_size
+            world_size = torch.distributed.get_world_size()
+            # NOTE: token num which need to pad to when mc2
+            forward_context.padded_num_tokens = math.ceil(
+                max_tokens_across_dp / tp_world_size) * tp_world_size
+            # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs.
+            forward_context.global_bs = math.ceil(
+                max_tokens_across_dp / tp_world_size) * world_size
+
+            if fused_moe_state == FusedMoEState.MC2_PREFILL:
+                chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
+                forward_context.max_num_chunks = math.ceil(
+                    math.ceil(max_tokens_across_dp / tp_world_size) /
+                    chunk_size)
+
+                forward_context.global_bs = math.ceil(
+                    math.ceil(max_tokens_across_dp / tp_world_size) /
+                    forward_context.max_num_chunks) * world_size
+
+                min_num_tokens = forward_context.max_num_chunks * tp_world_size
+                forward_context.padded_num_tokens = math.ceil(
+                    max_tokens_across_dp / min_num_tokens) * min_num_tokens
+
+            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
+                                   dtype=torch.bool,
+                                   device=current_platform.device_type)
+            mc2_mask[:num_actual_tokens] = True
+            forward_context.mc2_mask = mc2_mask
 
         try:
             yield
```
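
For reference, the chunk and padding arithmetic introduced above can be reproduced standalone. The sketch below is illustrative only: the token count, world sizes, and chunk size are made-up values, not taken from this PR.

```python
import math

# Illustrative values (not from the PR): max tokens over DP ranks, TP/world
# sizes, and the VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE setting.
max_tokens_across_dp = 1000
tp_world_size = 8
world_size = 16          # e.g. dp-size=2 * tp-size=8
chunk_size = 256

per_rank_tokens = math.ceil(max_tokens_across_dp / tp_world_size)       # 125
# MC2_PREFILL path: number of chunks each rank processes.
max_num_chunks = math.ceil(per_rank_tokens / chunk_size)                # 1
# `global_bs` is kept strictly larger than the actual global batch size.
global_bs = math.ceil(per_rank_tokens / max_num_chunks) * world_size    # 2000
# Tokens are padded to a multiple of max_num_chunks * tp_world_size.
min_num_tokens = max_num_chunks * tp_world_size
padded_num_tokens = math.ceil(
    max_tokens_across_dp / min_num_tokens) * min_num_tokens             # 1000

print(max_num_chunks, global_bs, padded_num_tokens)
```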

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 17 deletions

```diff
@@ -11,7 +11,6 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 from vllm_ascend import envs
@@ -71,6 +70,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: list[int]
@@ -99,7 +99,6 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
-    mc2_mask: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -215,13 +214,6 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None
 
-    def generate_activate_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size,
-                               dtype=torch.bool,
-                               device=current_platform.device_type)
-        mc2_mask[:actual_seqs_num].fill_(True)
-        return mc2_mask
-
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -364,7 +356,6 @@ def build_torchair_graph_dummy(
                           self.rope_dim,
                           dtype=self.runner.dtype,
                           device=device)
-        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -374,8 +365,7 @@ def build_torchair_graph_dummy(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
             sin=sin,
-            cos=cos,
-            mc2_mask=mc2_mask)
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -481,6 +471,7 @@ def build(
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -554,8 +545,6 @@ def build(
                 1).unsqueeze(2)
             sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
                 1).unsqueeze(2)
-            mc2_mask = self.generate_activate_mask(
-                num_actual_tokens, num_reqs + num_reqs_pad_size)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -566,8 +555,7 @@ def build(
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
                 sin=sin,
-                cos=cos,
-                mc2_mask=mc2_mask)
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -749,6 +737,8 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len2_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -765,7 +755,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                seq_len2_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
```
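
The `chunk_seq_lens_npu` field added above lets the chunked-prefill loop index a tensor that already lives on the device, instead of calling `.to(query.device)` on every iteration. A minimal sketch of that pattern, with hypothetical tensors and a generic device fallback (not the PR's code):

```python
import torch

chunk_seq_lens = torch.tensor([[16, 16, 8], [32, 0, 0]])  # built on CPU

# One up-front copy at metadata-build time (`.npu()` in the PR; a generic
# `.to(device)` is used here so the sketch also runs without torch_npu).
device = "npu" if getattr(torch, "npu", None) and torch.npu.is_available() else "cpu"
chunk_seq_lens_dev = chunk_seq_lens.to(device)

for i in range(chunk_seq_lens_dev.shape[0]):
    seq_len2_dev = chunk_seq_lens_dev[i]  # no host-to-device copy in the hot loop
```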

vllm_ascend/envs.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -136,7 +136,13 @@
     # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
     # and the mla_pa will be the default path of deepseek decode path.
     "VLLM_ASCEND_MLA_PA":
-    lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
+    lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
+    # ENABLE chunk mc2
+    "VLLM_ASCEND_ENABLE_CHUNK_MC2":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CHUNK_MC2", "0"))),
+    # Batch MC2 in prefill: The number of tokens in each batch
+    "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
 }
 
 # end-env-vars-definition
```
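
The two new flags follow the existing `os.getenv` pattern above: chunked MC2 is off by default, and the per-chunk token count defaults to 128. A minimal standalone sketch of how they are read:

```python
import os

# Mirrors the lambdas added in envs.py (defaults: chunk MC2 off, chunk size 128).
enable_chunk_mc2 = bool(int(os.getenv("VLLM_ASCEND_ENABLE_CHUNK_MC2", "0")))
mc2_chunk_size = int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128"))

print(f"chunk mc2 enabled: {enable_chunk_mc2}, chunk size: {mc2_chunk_size}")
```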

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 23 deletions

```diff
@@ -1151,6 +1151,7 @@ def forward(self,
 
         num_tokens, hidden_size = hidden_states.shape
 
+        forward_context = get_forward_context()
         fused_moe_state = get_forward_context().fused_moe_state
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
@@ -1170,31 +1171,29 @@ def forward(self,
         if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
             shared_hidden_states = shared_experts(hidden_states)
 
-        attn_metadata = get_forward_context().attn_metadata
-        mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None
-
+        mc2_mask = forward_context.mc2_mask
         tp_size = get_tensor_model_parallel_world_size()
-        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
-            if num_tokens < tp_size:
+        if fused_moe_state != FusedMoEState.AllGather:
+            if num_tokens < forward_context.padded_num_tokens:
                 hidden_states = nn.functional.pad(
-                    hidden_states, (0, 0, 0, tp_size - num_tokens))
+                    hidden_states,
+                    (0, 0, 0, forward_context.padded_num_tokens - num_tokens))
                 router_logits = nn.functional.pad(
-                    router_logits, (0, 0, 0, tp_size - num_tokens))
-            if mc2_mask is not None:
-                mc2_mask = nn.functional.pad(mc2_mask,
-                                             (0, tp_size - num_tokens))
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     tp_size,
-                                                     dim=0)
-            chunk_router_logits = torch.tensor_split(router_logits,
-                                                     tp_size,
-                                                     dim=0)
-            tp_rank = get_tensor_model_parallel_rank()
-            hidden_states = chunk_hidden_states[tp_rank]
-            router_logits = chunk_router_logits[tp_rank]
-
-            if mc2_mask is not None:
-                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
+                    router_logits,
+                    (0, 0, 0, forward_context.padded_num_tokens - num_tokens))
+            if tp_size > 1:
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         tp_size,
+                                                         dim=0)
+                chunk_router_logits = torch.tensor_split(router_logits,
+                                                         tp_size,
+                                                         dim=0)
+                chunk_mc2_mask = torch.tensor_split(forward_context.mc2_mask,
+                                                    tp_size,
+                                                    dim=0)
+                tp_rank = get_tensor_model_parallel_rank()
+                hidden_states = chunk_hidden_states[tp_rank]
+                router_logits = chunk_router_logits[tp_rank]
                 mc2_mask = chunk_mc2_mask[tp_rank]
 
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
@@ -1246,7 +1245,7 @@ def forward(self,
                 dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                                 self.tp_group)
                 final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < tp_size:
+            if num_tokens < forward_context.padded_num_tokens:
                 final_hidden_states = final_hidden_states[:num_tokens]
             dispose_tensor(e_hidden_states)
         elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
```
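
The forward path above now pads the local batch to `forward_context.padded_num_tokens` (rather than to `tp_size`) before splitting `hidden_states`, `router_logits`, and the `mc2_mask` across TP ranks. A self-contained sketch of that pad-then-split step, with made-up shapes and a fixed rank (not the PR's code):

```python
import torch
import torch.nn.functional as F

# Illustrative values: 5 real tokens padded up to 8, split over 2 TP ranks.
num_tokens, hidden_size, num_experts = 5, 16, 4
tp_size, tp_rank = 2, 0
padded_num_tokens = 8                      # from forward_context in the PR

hidden_states = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)
mc2_mask = torch.zeros(padded_num_tokens, dtype=torch.bool)
mc2_mask[:num_tokens] = True               # real tokens vs. padding

pad = padded_num_tokens - num_tokens
hidden_states = F.pad(hidden_states, (0, 0, 0, pad))   # pad token dimension
router_logits = F.pad(router_logits, (0, 0, 0, pad))

# Each TP rank keeps only its slice of the padded batch and of the mask.
hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)[tp_rank]
router_logits = torch.tensor_split(router_logits, tp_size, dim=0)[tp_rank]
mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)[tp_rank]

print(hidden_states.shape, mc2_mask.sum())  # torch.Size([4, 16]), 4 real tokens
```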
