Commit 6856f9d

[v0.9.1][BugFix] Fix DBO bug after attn_metadata_refactor (#1445)
### What this PR does / why we need it?
Fix inference bug after `set_forward_context` refactor.

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
Test case in `tests/multicard/test_torchair_graph_mode.py`.

---------

Signed-off-by: shikang-hangzhou <459956190@qq.com>
1 parent 263af3b commit 6856f9d

1 file changed: +6 -2 lines changed


vllm_ascend/models/deepseek_dbo.py

Lines changed: 6 additions & 2 deletions
@@ -34,7 +34,8 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
+                         get_current_vllm_config)
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
@@ -196,7 +197,10 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
-
+        self.kv_consumer = None
+        transfer_config = get_current_vllm_config().kv_transfer_config
+        if transfer_config is not None:
+            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
         self.params_dtype = torch.get_default_dtype()
 
         ascend_config = get_ascend_config()
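
Note on the added `kv_consumer` line: as written it uses a single `=` on both sides, which Python parses as a chained assignment rather than a comparison. The sketch below shows how the two forms differ. `FakeKVTransferConfig` is a hypothetical stand-in, not the vLLM class, and the reading that a boolean comparison was intended is an assumption based only on the flag's name.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeKVTransferConfig:
    """Hypothetical stand-in for a KV transfer config (not the vLLM class)."""
    kv_role: Optional[str] = None


def consumer_flag_chained(cfg: FakeKVTransferConfig):
    # Same shape as the added line: chained assignment.
    # Both `flag` and `cfg.kv_role` are bound to the string "kv_consumer".
    flag = cfg.kv_role = "kv_consumer"
    return flag


def consumer_flag_compared(cfg: FakeKVTransferConfig):
    # Comparison form: `flag` is a bool and `cfg.kv_role` is left untouched.
    flag = cfg.kv_role == "kv_consumer"
    return flag


# Start from a producer role in both cases.
cfg_a = FakeKVTransferConfig(kv_role="kv_producer")
chained = consumer_flag_chained(cfg_a)

cfg_b = FakeKVTransferConfig(kv_role="kv_producer")
compared = consumer_flag_compared(cfg_b)

print(repr(chained), cfg_a.kv_role)   # flag is a truthy string; role overwritten
print(repr(compared), cfg_b.kv_role)  # flag is False; role unchanged
```

With the chained form, the flag is a non-empty string (truthy) for any non-`None` config and the config's role is overwritten as a side effect; with the comparison form, the flag is `True` only for a consumer and the config is not modified.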
