Skip to content

Commit e89c59d

Browse files
authored
[0.9.1][bugfix] fix ascend_forward_context (#1554)
### What this PR does / why we need it? 1. Fix v0_model_runner, pooling_model_runner, and draft_model_runner, which were not adapted to `ascend_forward_context`. 2. Fix a bug in moe_distributed_combine's `global_bs` parameter. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent adf436b commit e89c59d

File tree

5 files changed

+15
-11
lines changed

5 files changed

+15
-11
lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def fused_experts_with_mc2(
214214
"expert_shard_type": 0,
215215
"shared_expert_rank_num": 0,
216216
"moe_expert_num": moe_expert_num,
217-
"global_bs": 0,
217+
"global_bs": global_bs,
218218
}
219219
tp_recv_counts = output[5]
220220
stage3_kwargs = {

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def fused_experts_with_mc2(
218218
"expert_shard_type": 0,
219219
"shared_expert_rank_num": 0,
220220
"moe_expert_num": moe_expert_num,
221-
"global_bs": 0,
221+
"global_bs": global_bs,
222222
}
223223
tp_recv_counts = torch.empty(1,
224224
dtype=torch.int32,

vllm_ascend/worker/draft_model_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from typing import List, Optional
1919

2020
import torch
21-
from vllm.forward_context import set_forward_context
2221
from vllm.logger import logger
2322
from vllm.model_executor.layers.sampler import SamplerOutput
2423
from vllm.multimodal import MultiModalKwargs
@@ -27,6 +26,7 @@
2726
ModelRunnerInputBase,
2827
ModelRunnerWrapperBase)
2928

29+
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
3030
from vllm_ascend.attention.attention import AscendMetadata
3131

3232
# A flag to enable debug prints for the updated input tensors
@@ -261,8 +261,8 @@ def execute_model(
261261
spec_step_idx = kwargs.get("spec_step_idx", step)
262262
model_execute_kwargs["spec_step_idx"] = spec_step_idx
263263
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
264-
with set_forward_context(model_input.attn_metadata,
265-
self.vllm_config):
264+
with set_ascend_forward_context(model_input.attn_metadata,
265+
self.vllm_config):
266266

267267
if model_input.attn_metadata is not None:
268268
model_input.attn_metadata.input_positions = model_input.input_positions

vllm_ascend/worker/model_runner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from vllm.core.scheduler import SchedulerOutputs
3636
from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group
3737
from vllm.distributed.kv_transfer import get_kv_transfer_group
38-
from vllm.forward_context import set_forward_context
3938
from vllm.inputs import INPUT_REGISTRY, InputRegistry
4039
from vllm.logger import logger
4140
from vllm.lora.layers import LoRAMapping
@@ -66,6 +65,7 @@
6665
_init_sampling_metadata_from_tensor_dict)
6766

6867
from vllm_ascend.ascend_config import get_ascend_config
68+
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
6969

7070
if TYPE_CHECKING:
7171
from vllm.attention.backends.abstract import AttentionBackend
@@ -1431,8 +1431,12 @@ def execute_model(
14311431
model_forward_start.record()
14321432

14331433
if not bypass_model_exec:
1434-
with set_forward_context(model_input.attn_metadata,
1435-
self.vllm_config, virtual_engine):
1434+
with set_ascend_forward_context(
1435+
model_input.attn_metadata,
1436+
self.vllm_config,
1437+
virtual_engine,
1438+
with_prefill=prefill_meta is not None,
1439+
in_profile_run=self.in_profile_run):
14361440
if model_input.attn_metadata is not None:
14371441
model_input.attn_metadata.input_positions = model_input.input_positions
14381442
if self.torchair_graph_enabled:

vllm_ascend/worker/pooling_model_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121

2222
import torch
2323
from vllm.distributed import get_pp_group
24-
from vllm.forward_context import set_forward_context
2524
from vllm.model_executor.pooling_metadata import PoolingMetadata
2625
from vllm.multimodal import MultiModalKwargs
2726
from vllm.pooling_params import PoolingParams
2827
from vllm.sequence import (IntermediateTensors, SequenceData,
2928
SequenceGroupMetadata)
3029

30+
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
3131
from vllm_ascend.worker.model_runner import (ModelInputForNPU,
3232
ModelInputForNPUBuilder,
3333
NPUModelRunnerBase)
@@ -142,8 +142,8 @@ def execute_model(
142142
if model_input.token_types is not None:
143143
cross_enc_kwargs["token_type_ids"] = model_input.token_types
144144

145-
with set_forward_context(model_input.attn_metadata, self.vllm_config,
146-
virtual_engine):
145+
with set_ascend_forward_context(model_input.attn_metadata,
146+
self.vllm_config, virtual_engine):
147147
hidden_or_intermediate_states = model_executable(
148148
input_ids=model_input.input_tokens,
149149
positions=model_input.input_positions,

0 commit comments

Comments (0)