
Commit fa116eb

[0.7.3][V1] Support the feature of prefix cache in v1 (#559)
### What this PR does / why we need it?
Support the feature of prefix cache in v1.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Comprehensive unit tests for ops accuracy have been performed and will be included in another PR.

Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent d922fb9 commit fa116eb
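As a usage note (not part of the commit): with this change in place, prefix caching can be exercised on the V1 engine roughly as follows. This is a hedged sketch assuming a working vllm + vllm-ascend installation on an NPU; the model name is purely illustrative, while `VLLM_USE_V1` and `enable_prefix_caching` come from upstream vLLM's environment variables and engine arguments.

```python
# Hedged usage sketch (not from the diff): exercising V1 prefix caching
# on an NPU with the vllm-ascend plugin installed.
import os

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vllm

from vllm import LLM, SamplingParams

# enable_prefix_caching is an upstream vLLM engine argument; the model name
# here is illustrative only.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", enable_prefix_caching=True)

# Two prompts sharing a prefix: the second request should reuse the cached
# KV blocks produced by the first.
shared_prefix = "You are a helpful assistant. Answer concisely.\n\n"
prompts = [
    shared_prefix + "What is paged attention?",
    shared_prefix + "What is prefix caching?",
]
outputs = llm.generate(prompts, SamplingParams(max_tokens=64))
for out in outputs:
    print(out.outputs[0].text)
```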

3 files changed: +37 -14 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 22 additions & 6 deletions
@@ -85,9 +85,10 @@ def copy_blocks(
 
 
 class AscendAttentionState(Enum):
-    PrefillOnly = 0
-    DecodeOnly = 1
-    ChunkedPrefill = 2
+    PrefillNoCache = 0
+    PrefillCacheHit = 1
+    DecodeOnly = 2
+    ChunkedPrefill = 3
 
 
 @dataclass
@@ -214,7 +215,7 @@ def forward(
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
             pass
         # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
@@ -227,16 +228,31 @@ def forward(
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 out=output)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+            assert attn_metadata is not None
+            assert attn_metadata.attn_mask is not None
+            compress_mask = attn_metadata.attn_mask
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=attn_metadata.block_tables,
+                mask=compress_mask,
+                seq_len=attn_metadata.seq_lens,
+                context_lens=attn_metadata.context_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            block_tables = attn_metadata.block_tables
             torch_npu._npu_paged_attention(
                 query=query,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=block_tables,
+                block_table=attn_metadata.block_tables,
                 context_lens=attn_metadata.context_lens,
                 out=output)
         # Normal V1 situation.
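For readers skimming the hunk above: the enum now distinguishes four scheduling situations, and the `forward` branches dispatch each one to a different NPU kernel. The sketch below restates the enum with explanatory comments; the comments are an interpretation of this diff (the `_npu_flash_attention_qlens` and `_npu_paged_attention` mappings are taken from the branches above), not text from the source file.

```python
from enum import Enum

# Annotated restatement of the new enum; comments interpret the diff above.
class AscendAttentionState(Enum):
    # Prefill with nothing in the KV cache: dense attention over the prompt
    # using the full causal mask built in model_runner_v1.py.
    PrefillNoCache = 0
    # Prefill where a prefix is already cached: the new branch calls
    # torch_npu._npu_flash_attention_qlens with a compressed mask and the
    # request's block tables.
    PrefillCacheHit = 1
    # Pure decode (one new token per request): served by
    # torch_npu._npu_paged_attention over the paged KV cache.
    DecodeOnly = 2
    # Mixed prefill/decode batches produced when chunked prefill is enabled.
    ChunkedPrefill = 3
```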

vllm_ascend/platform.py

Lines changed: 2 additions & 2 deletions
@@ -138,9 +138,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if envs.VLLM_USE_V1 and cache_config and cache_config.enable_prefix_caching:
             logger.warning(
-                "Prefix caching is not supported for V1 now, disable prefix caching"
+                "Prefix caching is now supported for V1 on NPU, "
+                "but it is still experimental and there may be issues with accuracy."
             )
-            cache_config.enable_prefix_caching = False
 
         if envs.VLLM_USE_V1:
             # Activate custom ops for v1.
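The practical effect of this hunk is that `check_and_update_config` no longer forces `enable_prefix_caching` back to `False` under V1; it only logs a warning. A minimal, self-contained sketch of that branch, with a stdlib logger and plain booleans standing in for `envs.VLLM_USE_V1` and the cache config:

```python
import logging

logger = logging.getLogger("vllm_ascend.platform")

def check_prefix_caching(use_v1: bool, enable_prefix_caching: bool) -> bool:
    # Mirrors the patched branch: before this commit the flag was forced to
    # False here; now the user's setting is preserved and only a warning is
    # logged.
    if use_v1 and enable_prefix_caching:
        logger.warning(
            "Prefix caching is now supported for V1 on NPU, "
            "but it is still experimental and there may be issues with accuracy.")
    return enable_prefix_caching

# The setting survives the check instead of being silently disabled.
assert check_prefix_caching(use_v1=True, enable_prefix_caching=True) is True
```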

vllm_ascend/worker/model_runner_v1.py

Lines changed: 13 additions & 6 deletions
@@ -71,6 +71,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.speculative_config = vllm_config.speculative_config
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
+        self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
 
         model_config = self.model_config
         cache_config = self.cache_config
@@ -419,11 +420,15 @@ def make_attention_mask(self, seq_lens, query_lens, position,
         if attn_state == AscendAttentionState.ChunkedPrefill:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
                 seq_lens, query_lens, position, self.dtype, self.device)
-        # Prefill-only situation.
-        elif attn_state == AscendAttentionState.PrefillOnly:
+        # Prefill without cache situation.
+        elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
             return self.attn_mask_builder.get_attn_mask(
                 max_seq_len, self.dtype, self.device)
+        # Prefill with cache hit.
+        elif attn_state == AscendAttentionState.PrefillCacheHit:
+            return self.attn_mask_builder.get_attn_mask(
+                128, self.dtype, self.device)
         # Decode-only situation.
         else:
             return None
@@ -492,13 +497,15 @@ def _process_reqs(
         slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
             self.device, non_blocking=True)
 
-        attn_state = AscendAttentionState.ChunkedPrefill
-        if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
-            attn_state = AscendAttentionState.PrefillOnly
+        if self.chunked_prefill_enabled:
+            attn_state = AscendAttentionState.ChunkedPrefill
+        elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
+            attn_state = AscendAttentionState.PrefillNoCache
+        # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
         else:
-            attn_state = AscendAttentionState.ChunkedPrefill
+            attn_state = AscendAttentionState.PrefillCacheHit
 
         attn_mask = self.make_attention_mask(seq_lens=seq_lens,
                                              query_lens=num_scheduled_tokens,
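The key addition in `_process_reqs` is the state-selection chain. The sketch below reproduces that logic as a free function over NumPy arrays so it can be read (and run) in isolation; `seq_lens` stands in for `self.seq_lens_np[:num_reqs]`, and state names are returned as strings rather than the real enum.

```python
import numpy as np

# Hedged, self-contained restatement of the selection logic added above.
# seq_lens: total length (cached context + new tokens) per request.
# num_scheduled_tokens: tokens scheduled for each request in this step.
def select_attn_state(chunked_prefill_enabled: bool,
                      seq_lens: np.ndarray,
                      num_scheduled_tokens: np.ndarray) -> str:
    if chunked_prefill_enabled:
        # The chunked-prefill scheduler mixes prefill and decode tokens.
        return "ChunkedPrefill"
    if np.array_equal(seq_lens, num_scheduled_tokens):
        # Every scheduled token is new: prefill with an empty cache.
        return "PrefillNoCache"
    if np.all(num_scheduled_tokens == 1):
        # One token per request: pure decode over the paged KV cache.
        return "DecodeOnly"
    # Otherwise part of each sequence was already found in the prefix cache.
    return "PrefillCacheHit"

# Example: 2 requests, 10 cached + 6 new tokens each -> prefix-cache hit.
print(select_attn_state(False, np.array([16, 16]), np.array([6, 6])))
```

Note the design choice this encodes: when chunked prefill is disabled, any batch that schedules more than one token per request while part of each sequence is already cached now falls into the new `PrefillCacheHit` branch instead of `ChunkedPrefill`.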
