
Commit 8b369df

Author: weijinqian_v1
Commit message: handle code conflict
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
2 parents: deb4319 + 8e42f71

15 files changed: +370 additions, -110 deletions

tests/singlecard/test_aclgraph.py

Lines changed: 13 additions & 0 deletions
@@ -103,3 +103,16 @@ def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
                         max_model_len=1024,
                         enforce_eager=False)
         assert "ACL Graph does not support deepseek" in str(excinfo.value)
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="aclgraph only support on v1")
+@pytest.mark.parametrize("model", MODELS)
+def test_ray_backend_sets_no_compilation(
+        model: str, monkeypatch: pytest.MonkeyPatch) -> None:
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        runner = VllmRunner(model,
+                            enforce_eager=False,
+                            distributed_executor_backend="ray")
+    assert runner.model.llm_engine.vllm_config.compilation_config.level == 0

vllm_ascend/ascend_forward_context.py

Lines changed: 54 additions & 12 deletions
@@ -1,11 +1,13 @@
+import math
 from contextlib import contextmanager
 from enum import Enum
 from typing import Any, Optional

 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_tp_group
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.platforms import current_platform

 import vllm_ascend.envs as envs_ascend

@@ -15,16 +17,20 @@ class FusedMoEState(Enum):
     All2All = 1
     MC2 = 2
     All2AllSeq = 3
+    MC2_PREFILL = 4


 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
+    enable_chunk_mc2 = envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2
     if ep_size == 1:
         return FusedMoEState.AllGather
     elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
         # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
         return FusedMoEState.All2AllSeq if (
             ep_size < 16 or with_prefill) else FusedMoEState.MC2
+    elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
+        return FusedMoEState.MC2_PREFILL
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
@@ -40,7 +46,8 @@ def set_ascend_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
-        in_profile_run: bool = False):
+        in_profile_run: bool = False,
+        num_actual_tokens: Optional[int] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -52,7 +59,6 @@ def set_ascend_forward_context(
             num_tokens_across_dp=num_tokens_across_dp):
         forward_context = get_forward_context()
         forward_context.with_prefill = with_prefill
-
         ep_size = torch.distributed.get_world_size(
         ) if vllm_config.parallel_config.enable_expert_parallel else 1

@@ -66,19 +72,55 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False

+        if num_tokens is None and attn_metadata is not None:
+            if hasattr(attn_metadata, 'num_actual_tokens'):
+                # for v1 engine
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                # for v0 engine
+                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
+
+        if num_actual_tokens is None:
+            num_actual_tokens = num_tokens
+
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
+            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
             )
-        elif num_tokens is not None:
-            forward_context.max_tokens_across_dp = num_tokens
-        elif attn_metadata is not None:
-            if hasattr(attn_metadata, 'num_actual_tokens'):
-                forward_context.max_tokens_across_dp = attn_metadata.num_actual_tokens
-            else:
-                forward_context.max_tokens_across_dp = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
         else:
-            forward_context.max_tokens_across_dp = None
+            max_tokens_across_dp = num_tokens
+
+        forward_context.max_tokens_across_dp = max_tokens_across_dp
+
+        if num_tokens is not None:
+            tp_world_size = get_tp_group().world_size
+            world_size = torch.distributed.get_world_size()
+            # NOTE: token num which need to pad to when mc2
+            forward_context.padded_num_tokens = math.ceil(
+                max_tokens_across_dp / tp_world_size) * tp_world_size
+            # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs.
+            forward_context.global_bs = math.ceil(
+                max_tokens_across_dp / tp_world_size) * world_size
+
+            if fused_moe_state == FusedMoEState.MC2_PREFILL:
+                chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
+                forward_context.max_num_chunks = math.ceil(
+                    math.ceil(max_tokens_across_dp / tp_world_size) /
+                    chunk_size)

+                forward_context.global_bs = math.ceil(
+                    math.ceil(max_tokens_across_dp / tp_world_size) /
+                    forward_context.max_num_chunks) * world_size
+
+                min_num_tokens = forward_context.max_num_chunks * tp_world_size
+                forward_context.padded_num_tokens = math.ceil(
+                    max_tokens_across_dp / min_num_tokens) * min_num_tokens
+
+            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
+                                   dtype=torch.bool,
+                                   device=current_platform.device_type)
+            mc2_mask[:num_actual_tokens] = True
+            forward_context.mc2_mask = mc2_mask

         try:
             yield
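
The padding and chunking arithmetic added above is easiest to check with concrete numbers. The following standalone sketch (not part of the commit) re-derives padded_num_tokens, global_bs and max_num_chunks from the same formulas; all input values are made up for illustration.

import math

# Illustrative inputs only; none of these values come from the commit.
max_tokens_across_dp = 3001  # largest token count across DP ranks
tp_world_size = 8            # tensor-parallel group size
world_size = 32              # total number of ranks
chunk_size = 128             # default VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE

# Plain MC2 path: pad the token count to a multiple of tp_world_size, and
# over-estimate global_bs by multiplying the per-TP share by world_size.
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
global_bs = math.ceil(max_tokens_across_dp / tp_world_size) * world_size
print(padded_num_tokens, global_bs)  # 3008 12032

# MC2_PREFILL path: additionally split the per-TP share into chunks of at most
# chunk_size tokens and re-derive the padding from the chunk count.
max_num_chunks = math.ceil(
    math.ceil(max_tokens_across_dp / tp_world_size) / chunk_size)
global_bs_chunked = math.ceil(
    math.ceil(max_tokens_across_dp / tp_world_size) / max_num_chunks) * world_size
min_num_tokens = max_num_chunks * tp_world_size
padded_num_tokens_chunked = math.ceil(
    max_tokens_across_dp / min_num_tokens) * min_num_tokens
print(max_num_chunks, global_bs_chunked, padded_num_tokens_chunked)  # 3 4032 3024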

vllm_ascend/attention/mla_v1.py

Lines changed: 12 additions & 17 deletions
@@ -11,7 +11,6 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down

 from vllm_ascend import envs
@@ -71,6 +70,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor

     attn_mask: torch.Tensor
     query_lens: list[int]
@@ -99,7 +99,6 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
-    mc2_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -215,13 +214,6 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None

-    def generate_activate_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size,
-                               dtype=torch.bool,
-                               device=current_platform.device_type)
-        mc2_mask[:actual_seqs_num].fill_(True)
-        return mc2_mask
-
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -364,7 +356,6 @@ def build_torchair_graph_dummy(
             self.rope_dim,
             dtype=self.runner.dtype,
             device=device)
-        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -374,8 +365,7 @@ def build_torchair_graph_dummy(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
             sin=sin,
-            cos=cos,
-            mc2_mask=mc2_mask)
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -481,6 +471,7 @@ def build(
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -547,15 +538,18 @@ def build(
             actual_seq_q_lens = query_start_loc[1:].tolist(
             ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                               num_reqs_pad_size]
+            # mtp torchair + PD scenario, last element of actual_seq_q_lens must equal to num_reqs_pad_size
+            num_padded_token_size = slot_mapping.size(0)
+            if actual_seq_q_lens[-1] != num_padded_token_size:
+                actual_seq_q_lens.append(num_padded_token_size)
+                seq_lens_list.append(0)
         else:
             seq_lens_list = seq_lens.tolist()

         cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
             1).unsqueeze(2)
         sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
             1).unsqueeze(2)
-        mc2_mask = self.generate_activate_mask(
-            num_actual_tokens, num_reqs + num_reqs_pad_size)

         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
@@ -566,8 +560,7 @@ def build(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=actual_seq_q_lens,
             sin=sin,
-            cos=cos,
-            mc2_mask=mc2_mask)
+            cos=cos)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -749,6 +742,8 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len2_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -765,7 +760,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                seq_len2_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
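
The new chunk_seq_lens_npu field moves the host-to-device copy of the chunk lengths to metadata-build time, so the per-chunk loop in _compute_prefill_context indexes an already-resident tensor instead of calling seq_len2.to(query.device) on every iteration. A minimal CPU-only sketch of that pattern (toy values; torch.Tensor.clone() stands in for the .npu() copy so it runs anywhere):

import torch

# Toy chunk-length table: rows are requests, columns are context chunks.
chunk_seq_lens = torch.tensor([[128, 128, 64],
                               [128, 32, 0]])

# Old pattern: one host-to-device transfer per chunk iteration.
#   seq_len2 = chunk_seq_lens[i]
#   fused_kernel(..., seq_len2.to(device), ...)

# New pattern: transfer once when the metadata is built ...
chunk_seq_lens_device = chunk_seq_lens.clone()  # chunk_seq_lens.npu() in the commit

# ... and index the resident copy inside the loop.
for i in range(chunk_seq_lens.size(0)):
    seq_len2_device = chunk_seq_lens_device[i]
    # fused_kernel(..., seq_len2_device, ...)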
New file (name not shown in this view) — Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from typing import Optional
+
+import torch
+from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+                                             init_model_parallel_group)
+
+# Currently, mc2 op need their own group coordinator.
+_MC2: Optional[GroupCoordinator] = None
+
+
+def get_mc2_group() -> GroupCoordinator:
+    assert _MC2 is not None, ("mc2 group is not initialized")
+    return _MC2
+
+
+def model_parallel_initialized():
+    return (_MC2 is not None)
+
+
+def init_ascend_model_parallel(
+    expert_parallel_size: int = 1,
+    world_size: Optional[int] = None,
+    backend: Optional[str] = None,
+):
+    if model_parallel_initialized():
+        return
+    assert torch.distributed.is_initialized()
+    world_size = world_size or torch.distributed.get_world_size()
+    backend = backend or torch.distributed.get_backend(
+        get_world_group().device_group)
+    num_expert_parallel_groups = world_size // expert_parallel_size
+
+    global _MC2
+    group_ranks = []
+    for i in range(num_expert_parallel_groups):
+        ranks = list(range(i, world_size, num_expert_parallel_groups))
+        group_ranks.append(ranks)
+
+    _MC2 = init_model_parallel_group(group_ranks,
+                                     get_world_group().local_rank,
+                                     backend,
+                                     group_name="mc2")
+
+
+def destroy_ascend_model_parallel():
+    global _MC2
+    if _MC2:
+        _MC2.destroy()
+        _MC2 = None
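
For orientation, here is a hedged sketch of how the new helpers fit together. The import path is a placeholder because the new file's name is not shown in this view, and the call site (worker initialization) is an assumption rather than something this diff confirms.

import torch.distributed as dist

from mc2_parallel_state import (destroy_ascend_model_parallel,  # hypothetical path
                                get_mc2_group, init_ascend_model_parallel)


def setup_mc2(expert_parallel_size: int) -> None:
    # The helper asserts that the default process group already exists.
    assert dist.is_initialized()
    # With world_size=16 and expert_parallel_size=8 this builds two groups,
    # ranks [0, 2, ..., 14] and [1, 3, ..., 15].
    init_ascend_model_parallel(expert_parallel_size=expert_parallel_size)

    mc2_group = get_mc2_group()  # dedicated GroupCoordinator for MC2 ops
    print(mc2_group.world_size, mc2_group.rank_in_group)


def teardown_mc2() -> None:
    destroy_ascend_model_parallel()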

vllm_ascend/envs.py

Lines changed: 6 additions & 0 deletions
@@ -142,6 +142,12 @@
     # 1: enable moe all2all seq.
     "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
     lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
+    # ENABLE chunk mc2
+    "VLLM_ASCEND_ENABLE_CHUNK_MC2":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CHUNK_MC2", "0"))),
+    # Batch MC2 in prefill: The number of tokens in each batch
+    "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
 }

 # end-env-vars-definition
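
A short usage sketch for the two new variables (illustrative values only); the parsing lines simply mirror the lambdas added above:

import os

# Illustrative only: enable chunked MC2 for large-EP prefill and shrink the
# per-chunk token budget from the default 128 to 64 before starting the engine.
os.environ["VLLM_ASCEND_ENABLE_CHUNK_MC2"] = "1"
os.environ["VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE"] = "64"

# The dictionary entries added above parse the variables like this:
enable_chunk_mc2 = bool(int(os.getenv("VLLM_ASCEND_ENABLE_CHUNK_MC2", "0")))
chunk_size = int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128"))
assert enable_chunk_mc2 is True and chunk_size == 64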

vllm_ascend/models/qwen3.py

Lines changed: 5 additions & 5 deletions
@@ -18,7 +18,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from vllm_ascend.ops.layernorm import AddRMSNormQuant
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


 class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
@@ -43,15 +43,15 @@ def __init__(
         assert isinstance(quant_config, AscendQuantConfig), \
             "Expected quant_config to be an instance of AscendQuantConfig"

-        if isinstance(self.self_attn.qkv_proj.quant_method,
+        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
                       AscendW8A8LinearMethod):
-            self.input_layernorm = AddRMSNormQuant(
+            self.input_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
                 layer=self.self_attn.qkv_proj,
                 eps=config.rms_norm_eps)
-        if isinstance(self.mlp.gate_up_proj.quant_method,
+        if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
                       AscendW8A8LinearMethod):
-            self.post_attention_layernorm = AddRMSNormQuant(
+            self.post_attention_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
                 layer=self.mlp.gate_up_proj,
                 eps=config.rms_norm_eps)
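
The extra .quant_method hop suggests the concrete AscendW8A8LinearMethod now sits one level deeper, behind a per-layer quant-method wrapper; that reading is an inference from this diff alone. A toy sketch of the attribute chain the isinstance check walks (all classes below are simplified stand-ins; only the AscendW8A8LinearMethod name mirrors the real one):

class AscendW8A8LinearMethod:
    """Toy stand-in for the real W8A8 scheme class of the same name."""


class LayerQuantMethod:
    """Toy per-layer wrapper that keeps the concrete scheme at .quant_method."""

    def __init__(self, scheme):
        self.quant_method = scheme


class QKVProj:
    """Toy stand-in for a quantized linear layer such as qkv_proj."""

    def __init__(self):
        self.quant_method = LayerQuantMethod(AscendW8A8LinearMethod())


qkv_proj = QKVProj()
# The old check inspected the wrapper itself, which no longer matches:
assert not isinstance(qkv_proj.quant_method, AscendW8A8LinearMethod)
# The commit unwraps one level, matching the concrete scheme again:
assert isinstance(qkv_proj.quant_method.quant_method, AscendW8A8LinearMethod)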
