
Commit 847d52d

Merge branch 'v0.9.1-dev' into v0.9.1-dev
2 parents d4ad734 + da2d5ac commit 847d52d

11 files changed: +348 / -99 lines

vllm_ascend/ascend_forward_context.py

Lines changed: 57 additions & 13 deletions
@@ -1,11 +1,15 @@
+import math
 from contextlib import contextmanager
 from enum import Enum
 from typing import Any, Optional
 
 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_tp_group
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.platforms import current_platform
+
+import vllm_ascend.envs as envs
 
 import vllm_ascend.envs as envs_ascend
 
@@ -14,17 +18,21 @@ class FusedMoEState(Enum):
     AllGather = 0
     All2All = 1
     MC2 = 2
-    All2AllSeq = 3
+    MC2_PREFILL = 3
+    All2AllSeq = 4
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
+    enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
     if ep_size == 1:
         return FusedMoEState.AllGather
     elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
         # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
         return FusedMoEState.All2AllSeq if (
             ep_size < 16 or with_prefill) else FusedMoEState.MC2
+    elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
+        return FusedMoEState.MC2_PREFILL
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
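
As a quick reference outside the diff, the branch order above can be exercised standalone. This is only an illustrative sketch (the helper name pick_state is made up, and the All2AllSeq branch is omitted for brevity):

# Illustrative only: mirrors the branch order of get_fused_moe_state above.
from enum import Enum

class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    MC2_PREFILL = 3
    All2AllSeq = 4

def pick_state(ep_size: int, with_prefill: bool, enable_chunk_mc2: bool) -> FusedMoEState:
    if ep_size == 1:
        return FusedMoEState.AllGather
    if ep_size >= 16 and with_prefill and enable_chunk_mc2:
        return FusedMoEState.MC2_PREFILL
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    return FusedMoEState.MC2

# Large-EP prefill with chunked MC2 enabled picks the new state; decode keeps MC2.
assert pick_state(16, True, True) is FusedMoEState.MC2_PREFILL
assert pick_state(16, False, False) is FusedMoEState.MC2
assert pick_state(8, False, False) is FusedMoEState.All2All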
@@ -40,7 +48,8 @@ def set_ascend_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
-        in_profile_run: bool = False):
+        in_profile_run: bool = False,
+        num_actual_tokens: Optional[int] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -52,7 +61,6 @@ def set_ascend_forward_context(
                              num_tokens_across_dp=num_tokens_across_dp):
         forward_context = get_forward_context()
         forward_context.with_prefill = with_prefill
-
         ep_size = torch.distributed.get_world_size(
         ) if vllm_config.parallel_config.enable_expert_parallel else 1
 
@@ -66,19 +74,55 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
 
+        if num_tokens is None and attn_metadata is not None:
+            if hasattr(attn_metadata, 'num_actual_tokens'):
+                # for v1 engine
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                # for v0 engine
+                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
+
+        if num_actual_tokens is None:
+            num_actual_tokens = num_tokens
+
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
+            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
             )
-        elif num_tokens is not None:
-            forward_context.max_tokens_across_dp = num_tokens
-        elif attn_metadata is not None:
-            if hasattr(attn_metadata, 'num_actual_tokens'):
-                forward_context.max_tokens_across_dp = attn_metadata.num_actual_tokens
-            else:
-                forward_context.max_tokens_across_dp = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
         else:
-            forward_context.max_tokens_across_dp = None
+            max_tokens_across_dp = num_tokens
+
+        forward_context.max_tokens_across_dp = max_tokens_across_dp
+
+        if num_tokens is not None:
+            tp_world_size = get_tp_group().world_size
+            world_size = torch.distributed.get_world_size()
+            # NOTE: token num which need to pad to when mc2
+            forward_context.padded_num_tokens = math.ceil(
+                max_tokens_across_dp / tp_world_size) * tp_world_size
+            # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs.
+            forward_context.global_bs = math.ceil(
+                max_tokens_across_dp / tp_world_size) * world_size
+
+            if fused_moe_state == FusedMoEState.MC2_PREFILL:
+                chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
+                forward_context.max_num_chunks = math.ceil(
+                    math.ceil(max_tokens_across_dp / tp_world_size) /
                    chunk_size)
+
+                forward_context.global_bs = math.ceil(
+                    math.ceil(max_tokens_across_dp / tp_world_size) /
+                    forward_context.max_num_chunks) * world_size
+
+                min_num_tokens = forward_context.max_num_chunks * tp_world_size
+                forward_context.padded_num_tokens = math.ceil(
+                    max_tokens_across_dp / min_num_tokens) * min_num_tokens
+
+            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
+                                   dtype=torch.bool,
+                                   device=current_platform.device_type)
+            mc2_mask[:num_actual_tokens] = True
+            forward_context.mc2_mask = mc2_mask
 
         try:
             yield
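
A worked example may help with the padding arithmetic above. The numbers below are hypothetical, and max_tokens_across_dp, tp_world_size, world_size, and chunk_size stand in for the values read inside set_ascend_forward_context:

import math

# Hypothetical values: TP=4, 8 ranks total, busiest DP rank sees 1001 tokens.
max_tokens_across_dp = 1001
tp_world_size = 4
world_size = 8

# Plain MC2: pad the per-rank token count up to a multiple of tp_world_size.
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == 1004

# global_bs is intentionally an over-estimate of the real global batch size.
global_bs = math.ceil(max_tokens_across_dp / tp_world_size) * world_size
assert global_bs == 2008

# MC2_PREFILL: additionally split the per-rank tokens into chunks of at most chunk_size.
chunk_size = 128  # default of VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
max_num_chunks = math.ceil(math.ceil(max_tokens_across_dp / tp_world_size) / chunk_size)
assert max_num_chunks == 2  # ceil(251 / 128)

global_bs = math.ceil(math.ceil(max_tokens_across_dp / tp_world_size) / max_num_chunks) * world_size
assert global_bs == 1008  # ceil(251 / 2) * 8

min_num_tokens = max_num_chunks * tp_world_size
padded_num_tokens = math.ceil(max_tokens_across_dp / min_num_tokens) * min_num_tokens
assert padded_num_tokens == 1008  # ceil(1001 / 8) * 8

The mc2_mask built from padded_num_tokens then marks the first num_actual_tokens entries as real tokens and the trailing padding as inactive.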

vllm_ascend/attention/mla_v1.py

Lines changed: 12 additions & 17 deletions
@@ -11,7 +11,6 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 from vllm_ascend import envs
@@ -71,6 +70,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: list[int]
@@ -99,7 +99,6 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
-    mc2_mask: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -215,13 +214,6 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None
 
-    def generate_activate_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size,
-                               dtype=torch.bool,
-                               device=current_platform.device_type)
-        mc2_mask[:actual_seqs_num].fill_(True)
-        return mc2_mask
-
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -364,7 +356,6 @@ def build_torchair_graph_dummy(
                           self.rope_dim,
                           dtype=self.runner.dtype,
                           device=device)
-        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -374,8 +365,7 @@ def build_torchair_graph_dummy(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
             sin=sin,
-            cos=cos,
-            mc2_mask=mc2_mask)
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -481,6 +471,7 @@ def build(
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -547,15 +538,18 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                # mtp torchair + PD scenario, last element of actual_seq_q_lens must equal to num_reqs_pad_size
+                num_padded_token_size = slot_mapping.size(0)
+                if actual_seq_q_lens[-1] != num_padded_token_size:
+                    actual_seq_q_lens.append(num_padded_token_size)
+                    seq_lens_list.append(0)
             else:
                 seq_lens_list = seq_lens.tolist()
 
            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
-            mc2_mask = self.generate_activate_mask(
-                num_actual_tokens, num_reqs + num_reqs_pad_size)
 
            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions,
@@ -566,8 +560,7 @@ def build(
                attn_mask=self.runner.spec_attn_mask,
                actual_seq_q_lens=actual_seq_q_lens,
                sin=sin,
-                cos=cos,
-                mc2_mask=mc2_mask)
+                cos=cos)
 
        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
@@ -749,6 +742,8 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len2_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -765,7 +760,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                seq_len2_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
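
The chunk_seq_lens_npu change above follows a common pattern: move the host tensor to the device once when the metadata is built, instead of issuing one host-to-device copy per chunk inside the loop. A minimal sketch of the idea, with made-up tensor contents and a "cpu" stand-in device (on Ascend this would be .npu() via torch_npu):

import torch

device = "cpu"  # stand-in; "npu" on Ascend, "cuda" elsewhere

chunk_seq_lens = torch.tensor([[4, 4, 2], [3, 3, 3]])  # built on the host

# Before: seq_len2.to(query.device) inside the per-chunk loop meant one
# transfer per iteration. After: copy the whole tensor across once...
chunk_seq_lens_dev = chunk_seq_lens.to(device)

# ...then only index the device copy inside the loop.
for i in range(chunk_seq_lens_dev.size(0)):
    seq_len2_dev = chunk_seq_lens_dev[i]  # no per-chunk host-to-device copy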
vllm_ascend/distributed/parallel_state.py (new file)

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from typing import Optional
+
+import torch
+from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+                                             init_model_parallel_group)
+
+# Currently, mc2 op need their own group coordinator.
+_MC2: Optional[GroupCoordinator] = None
+
+
+def get_mc2_group() -> GroupCoordinator:
+    assert _MC2 is not None, ("mc2 group is not initialized")
+    return _MC2
+
+
+def model_parallel_initialized():
+    return (_MC2 is not None)
+
+
+def init_ascend_model_parallel(
+    expert_parallel_size: int = 1,
+    world_size: Optional[int] = None,
+    backend: Optional[str] = None,
+):
+    if model_parallel_initialized():
+        return
+    assert torch.distributed.is_initialized()
+    world_size = world_size or torch.distributed.get_world_size()
+    backend = backend or torch.distributed.get_backend(
+        get_world_group().device_group)
+    num_expert_parallel_groups = world_size // expert_parallel_size
+
+    global _MC2
+    group_ranks = []
+    for i in range(num_expert_parallel_groups):
+        ranks = list(range(i, world_size, num_expert_parallel_groups))
+        group_ranks.append(ranks)
+
+    _MC2 = init_model_parallel_group(group_ranks,
+                                     get_world_group().local_rank,
+                                     backend,
+                                     group_name="mc2")
+
+
+def destroy_ascend_model_parallel():
+    global _MC2
+    if _MC2:
+        _MC2.destroy()
+    _MC2 = None
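
To see how init_ascend_model_parallel partitions ranks into MC2 groups, the same rank-grouping loop can be run standalone for a hypothetical world_size of 8 and expert_parallel_size of 4 (no process group is actually created here):

world_size = 8
expert_parallel_size = 4
num_expert_parallel_groups = world_size // expert_parallel_size  # 2

group_ranks = []
for i in range(num_expert_parallel_groups):
    ranks = list(range(i, world_size, num_expert_parallel_groups))
    group_ranks.append(ranks)

# Two MC2 groups of 4 ranks each, strided across the world:
assert group_ranks == [[0, 2, 4, 6], [1, 3, 5, 7]]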

vllm_ascend/envs.py

Lines changed: 6 additions & 0 deletions
@@ -142,6 +142,12 @@
     # 1: enable moe all2all seq.
     "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
     lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
+    # ENABLE chunk mc2
+    "VLLM_ASCEND_ENABLE_CHUNK_MC2":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CHUNK_MC2", "0"))),
+    # Batch MC2 in prefill: The number of tokens in each batch
+    "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
 }
 
 # end-env-vars-definition
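
A sketch of how these switches might be set before launching, shown here with Python for consistency; "128" is the documented default chunk size and the only non-default value is the enable flag:

import os

# Enable chunked MC2 for prefill and keep the default chunk size of 128 tokens.
os.environ["VLLM_ASCEND_ENABLE_CHUNK_MC2"] = "1"
os.environ["VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE"] = "128"

# The lambdas above read the environment at access time, so later lookups of
# vllm_ascend.envs.VLLM_ASCEND_ENABLE_CHUNK_MC2 would evaluate to True.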

vllm_ascend/ops/fused_moe.py

Lines changed: 25 additions & 25 deletions
@@ -39,6 +39,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
+from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig)
@@ -127,7 +128,7 @@ def fused_experts_with_mc2(
     mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
-    ep_group = get_ep_group()
+    ep_group = get_mc2_group()
     ep_rank_id = ep_group.rank_in_group
     ep_world_size = ep_group.world_size
     tp_world_size = get_tp_group().world_size
@@ -889,7 +890,7 @@ def __init__(self, moe: MoEConfig = None):
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         try:
-            device_group = get_ep_group().device_group
+            device_group = get_mc2_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))
@@ -1191,6 +1192,7 @@ def forward(self,
 
         num_tokens, hidden_size = hidden_states.shape
 
+        forward_context = get_forward_context()
         fused_moe_state = get_forward_context().fused_moe_state
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
@@ -1210,32 +1212,30 @@ def forward(self,
         if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
             shared_hidden_states = shared_experts(hidden_states)
 
-        attn_metadata = get_forward_context().attn_metadata
-        mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and getattr(
-            attn_metadata, "decode", None) is not None else None
 
+        mc2_mask = forward_context.mc2_mask
         tp_size = get_tensor_model_parallel_world_size()
-        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
-            if num_tokens < tp_size:
+        if fused_moe_state != FusedMoEState.AllGather:
+            if num_tokens < forward_context.padded_num_tokens:
                 hidden_states = nn.functional.pad(
-                    hidden_states, (0, 0, 0, tp_size - num_tokens))
+                    hidden_states,
+                    (0, 0, 0, forward_context.padded_num_tokens - num_tokens))
                 router_logits = nn.functional.pad(
-                    router_logits, (0, 0, 0, tp_size - num_tokens))
-            if mc2_mask is not None:
-                mc2_mask = nn.functional.pad(mc2_mask,
-                                             (0, tp_size - num_tokens))
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     tp_size,
-                                                     dim=0)
-            chunk_router_logits = torch.tensor_split(router_logits,
-                                                     tp_size,
-                                                     dim=0)
-            tp_rank = get_tensor_model_parallel_rank()
-            hidden_states = chunk_hidden_states[tp_rank]
-            router_logits = chunk_router_logits[tp_rank]
-
-            if mc2_mask is not None:
-                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
+                    router_logits,
+                    (0, 0, 0, forward_context.padded_num_tokens - num_tokens))
+            if tp_size > 1:
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         tp_size,
+                                                         dim=0)
+                chunk_router_logits = torch.tensor_split(router_logits,
+                                                         tp_size,
+                                                         dim=0)
+                chunk_mc2_mask = torch.tensor_split(forward_context.mc2_mask,
+                                                    tp_size,
+                                                    dim=0)
+                tp_rank = get_tensor_model_parallel_rank()
+                hidden_states = chunk_hidden_states[tp_rank]
+                router_logits = chunk_router_logits[tp_rank]
                 mc2_mask = chunk_mc2_mask[tp_rank]
 
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
@@ -1287,7 +1287,7 @@ def forward(self,
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < tp_size:
+            if num_tokens < forward_context.padded_num_tokens:
                 final_hidden_states = final_hidden_states[:num_tokens]
             dispose_tensor(e_hidden_states)
         elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
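
A small standalone sketch of the pad-and-split step in the forward path above, with made-up shapes (5 tokens, padded_num_tokens of 8, TP of 4); it only mimics how each TP rank ends up with an equal slice of the tokens plus the matching piece of mc2_mask:

import torch
import torch.nn.functional as F

num_tokens, hidden_size, tp_size = 5, 16, 4
padded_num_tokens = 8  # multiple of tp_size chosen by the forward context

hidden_states = torch.randn(num_tokens, hidden_size)
mc2_mask = torch.zeros(padded_num_tokens, dtype=torch.bool)
mc2_mask[:num_tokens] = True  # only the real tokens are active

# Pad the token dimension up to padded_num_tokens, then split evenly across TP ranks.
hidden_states = F.pad(hidden_states, (0, 0, 0, padded_num_tokens - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)

tp_rank = 2  # whichever rank this process happens to be
local_hidden = chunk_hidden_states[tp_rank]
local_mask = chunk_mc2_mask[tp_rank]
assert local_hidden.shape == (2, hidden_size) and local_mask.shape == (2,)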
