Commit c33bfd9

Fix wrong prefill skip attn metadata
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent fb5d610 commit c33bfd9

File tree

7 files changed: +45 -24 lines

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 16 additions & 12 deletions
@@ -194,6 +194,7 @@ def forward(
         if decode_indices is None:
             decode_indices = torch.arange(positions.size(0),
                                           device=positions.device)
+
         num_decodes = decode_indices.shape[0]
         assert num_decodes >= 1
         assert first_residual is not None
@@ -270,12 +271,14 @@ def load_weights(self, weights: Iterable[tuple[str,


 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [False, True])
-def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
-    prompt = "What is the capital of France?"
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_kv_sharing_skip_prefill(
+    monkeypatch: pytest.MonkeyPatch,
+    enforce_eager: bool,
+):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-    single_prompt = [prompt]
+    prompts = ["What is the capital of France?"]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
         if not enforce_eager else CompilationLevel.NO_COMPILATION,
@@ -284,21 +287,22 @@ def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-                  enforce_eager=enforce_eager,
-                  compilation_config=compilation_config)
-        responses = llm.generate(single_prompt, sampling_params)
+        llm = LLM(
+            model="Qwen/Qwen2-1.5B-Instruct",
+            enforce_eager=enforce_eager,
+            compilation_config=compilation_config,
+        )
+        responses = llm.generate(prompts, sampling_params)
         ref_output = responses[0].outputs[0].text

         del llm
         gc.collect()
         torch.cuda.empty_cache()

-        m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1")
-
         llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                   enforce_eager=enforce_eager,
-                  compilation_config=compilation_config)
-        responses = llm.generate(single_prompt, sampling_params)
+                  compilation_config=compilation_config,
+                  kv_sharing_skip_prefill=True)
+        responses = llm.generate(prompts, sampling_params)
         output = responses[0].outputs[0].text
         assert output == ref_output

vllm/config.py

Lines changed: 8 additions & 1 deletion
@@ -1564,6 +1564,10 @@ class CacheConfig:
     checkpoint if available. Otherwise, the scales will default to 1.0."""
     cpu_kvcache_space_bytes: Optional[int] = None
     """(CPU backend only) CPU key-value cache space."""
+    kv_sharing_skip_prefill: bool = False
+    """Skip prefill for tokens where applicable in KV cache sharing
+    scenarios where required key/value tensors have been populated
+    in earlier KV sharing target layers."""

     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
@@ -4115,7 +4119,10 @@ class CompilationConfig:
     - None (default): capture sizes are inferred from vllm config.
     - list[int]: capture sizes are specified as given."""
     cudagraph_share_memory_pool: bool = True
-    """Whether to share a single global memory pool for each graph capture"""
+    """Whether to share a single global memory pool for each graph capture
+    When CUDA graphs are not replayed in the same order they are captured,
+    e.g. when compiling multiple modules in a model and modules take different
+    input shapes, it is unsafe to share memory across graph captures."""
     cudagraph_copy_inputs: bool = False
     """Whether to copy input tensors for
     cudagraph. If the caller can guarantee that the same input buffers
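
The expanded `cudagraph_share_memory_pool` docstring states the constraint directly: sharing one memory pool across captures is only safe when graphs are replayed in the same order they were captured. Below is a minimal, hypothetical illustration of that constraint using plain PyTorch CUDA graph APIs; it is not vLLM's capture path, and the tensor shapes are arbitrary.

```python
# Sketch: two graphs captured into one shared pool (assumes a CUDA device).
import torch

static_a = torch.zeros(8, device="cuda")
static_b = torch.zeros(16, device="cuda")

pool = torch.cuda.graph_pool_handle()  # single pool reused by both captures

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    out1 = static_a * 2

g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=pool):
    out2 = static_b * 3

# Safe: replay in capture order.
g1.replay()
g2.replay()
# If modules are compiled separately and take different input shapes, replay
# order can diverge from capture order; the shared pool may then hand one
# graph blocks the other graph's tensors still occupy, which is why the flag
# must be disabled in that situation.
```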

vllm/engine/arg_utils.py

Lines changed: 4 additions & 0 deletions
@@ -472,6 +472,7 @@ class EngineArgs:
     override_attention_dtype: str = ModelConfig.override_attention_dtype

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+    kv_sharing_skip_prefill: bool = CacheConfig.kv_sharing_skip_prefill

     additional_config: dict[str, Any] = \
         get_field(VllmConfig, "additional_config")
@@ -748,6 +749,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **cache_kwargs["cpu_offload_gb"])
         cache_group.add_argument("--calculate-kv-scales",
                                  **cache_kwargs["calculate_kv_scales"])
+        cache_group.add_argument("--kv-sharing-skip-prefill",
+                                 **cache_kwargs["kv_sharing_skip_prefill"])

         # Tokenizer arguments
         tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
@@ -1158,6 +1161,7 @@ def create_engine_config(
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
+            kv_sharing_skip_prefill=self.kv_sharing_skip_prefill,
         )

         # Get the current placement group if Ray is initialized and
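
For reference, a small sketch of how the new argument threads from `EngineArgs` into `CacheConfig`; the model name is an arbitrary example and it assumes `create_engine_config()` can be called with its defaults.

```python
# Sketch under the assumptions above, not a prescribed workflow.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2-1.5B-Instruct",  # example model only
    kv_sharing_skip_prefill=True,      # new field, also exposed as --kv-sharing-skip-prefill
)
vllm_config = engine_args.create_engine_config()
assert vllm_config.cache_config.kv_sharing_skip_prefill
```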

vllm/entrypoints/llm.py

Lines changed: 2 additions & 0 deletions
@@ -194,6 +194,7 @@ def __init__(
         override_pooler_config: Optional[PoolerConfig] = None,
         compilation_config: Optional[Union[int, dict[str, Any],
                                            CompilationConfig]] = None,
+        kv_sharing_skip_prefill: bool = False,
         **kwargs,
     ) -> None:
         """LLM constructor."""
@@ -267,6 +268,7 @@ def __init__(
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
             compilation_config=compilation_config_instance,
+            kv_sharing_skip_prefill=kv_sharing_skip_prefill,
             **kwargs,
         )
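
Together with the env-var removal in vllm/envs.py below, this turns the feature into a constructor argument instead of `VLLM_V1_KV_SHARING_SKIP_PREFILL`. A usage sketch mirroring the updated test above (same example model and prompt as the test):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    kv_sharing_skip_prefill=True,  # replaces VLLM_V1_KV_SHARING_SKIP_PREFILL=1
)
outputs = llm.generate(["What is the capital of France?"],
                       SamplingParams(temperature=0.0, max_tokens=100))
print(outputs[0].outputs[0].text)
```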

vllm/envs.py

Lines changed: 0 additions & 3 deletions
@@ -139,7 +139,6 @@
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
-    VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False


 def get_default_cache_root():
@@ -966,8 +965,6 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),

-    "VLLM_V1_KV_SHARING_SKIP_PREFILL":
-    lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1",
 }

 # --8<-- [end:env-vars-definition]

vllm/v1/attention/backends/flash_attn.py

Lines changed: 10 additions & 6 deletions
@@ -267,6 +267,16 @@ def build(
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> FlashAttentionMetadata:
+        prefill_skipped_attn_metadata = None
+        if common_attn_metadata.decode_indices is not None:
+            # NOTE(sarckk): attention metadata for partial prefill skip case
+            # needs to be built first, otherwise the line below
+            # block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+            # will override the correct slot mapping
+            prefill_skipped_attn_metadata = self.build_skip_prefill(
+                common_prefix_len=0,  # disable cascade attention
+                common_attn_metadata=common_attn_metadata)
+
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -415,12 +425,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             # we only set num_splits when using cuda graphs.
             max_num_splits = self.max_num_splits

-        prefill_skipped_attn_metadata = None
-        if common_attn_metadata.decode_indices is not None:
-            prefill_skipped_attn_metadata = self.build_skip_prefill(
-                common_prefix_len=0,  # disable cascade attention
-                common_attn_metadata=common_attn_metadata)
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
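
The NOTE(sarckk) comment is the heart of the fix: metadata for the skip-prefill path must be derived from `slot_mapping` before the padded tail is filled in place with -1. A self-contained illustration with made-up tensors (not the real `BlockTable` layout):

```python
import torch

# Made-up slot mapping: 3 real tokens followed by 2 padded entries.
slot_mapping = torch.tensor([10, 11, 12, 13, 14])
num_actual_tokens = 3

# The commit's ordering: snapshot what the skip-prefill metadata needs first...
skip_prefill_slots = slot_mapping.clone()
# ...then pad in place for the regular metadata, as build() does afterwards.
slot_mapping[num_actual_tokens:].fill_(-1)
print(skip_prefill_slots)  # tensor([10, 11, 12, 13, 14]) -- still correct

# With the old ordering, anything derived from slot_mapping after the fill
# would see [-1, -1] in place of the values it needed.
```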

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -583,6 +583,8 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
         """
         Pads logits_indices to align with CUDA graph capture sizes
         """
+        if not self.cache_config.kv_sharing_skip_prefill:
+            return None
         num_decodes = logits_indices.shape[0]
         # TODO(sarckk): With chunked prefills, logits_indices contains
         # indices for partial requests though we do not sample any token
@@ -602,8 +604,9 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], bool, torch.Tensor,
-               Optional[SpecDecodeMetadata], np.ndarray, torch.Tensor]:
+    ) -> tuple[dict[str,
+                    Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
+               np.ndarray, Optional[torch.Tensor]]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
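
The early `return None` keeps the feature opt-in: when `kv_sharing_skip_prefill` is disabled, `decode_indices` stays `None` and consumers fall back to every position, exactly as the test model's `forward` above does. A minimal sketch of that contract (the helper name is made up; it is not the runner's actual code path):

```python
import torch

def resolve_decode_indices(decode_indices, positions):
    # Mirrors the fallback in the test model: None means "use all positions".
    if decode_indices is None:
        decode_indices = torch.arange(positions.size(0),
                                      device=positions.device)
    return decode_indices

positions = torch.arange(6)
print(resolve_decode_indices(None, positions))                  # all six positions
print(resolve_decode_indices(torch.tensor([4, 5]), positions))  # decode tokens only
```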
