
Commit d112b85

Fix wrong prefill skip attn metadata

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

Parent: 1cbd312

7 files changed, +45 -24 lines

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 16 additions & 12 deletions
@@ -194,6 +194,7 @@ def forward(
         if decode_indices is None:
             decode_indices = torch.arange(positions.size(0),
                                           device=positions.device)
+
         num_decodes = decode_indices.shape[0]
         assert num_decodes >= 1
         assert first_residual is not None
@@ -270,12 +271,14 @@ def load_weights(self, weights: Iterable[tuple[str,


 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [False, True])
-def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
-    prompt = "What is the capital of France?"
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_kv_sharing_skip_prefill(
+    monkeypatch: pytest.MonkeyPatch,
+    enforce_eager: bool,
+):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-    single_prompt = [prompt]
+    prompts = ["What is the capital of France?"]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
         if not enforce_eager else CompilationLevel.NO_COMPILATION,
@@ -284,21 +287,22 @@ def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-                  enforce_eager=enforce_eager,
-                  compilation_config=compilation_config)
-        responses = llm.generate(single_prompt, sampling_params)
+        llm = LLM(
+            model="Qwen/Qwen2-1.5B-Instruct",
+            enforce_eager=enforce_eager,
+            compilation_config=compilation_config,
+        )
+        responses = llm.generate(prompts, sampling_params)
         ref_output = responses[0].outputs[0].text

         del llm
         gc.collect()
         torch.cuda.empty_cache()

-        m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1")
-
         llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                   enforce_eager=enforce_eager,
-                  compilation_config=compilation_config)
-        responses = llm.generate(single_prompt, sampling_params)
+                  compilation_config=compilation_config,
+                  kv_sharing_skip_prefill=True)
+        responses = llm.generate(prompts, sampling_params)
         output = responses[0].outputs[0].text
         assert output == ref_output
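
The rewritten test enables the feature through the LLM constructor
(kv_sharing_skip_prefill=True) rather than the removed
VLLM_V1_KV_SHARING_SKIP_PREFILL environment variable, and checks that greedy
outputs match a reference run. Below is a condensed sketch of that A/B
pattern; the greedy_text helper is illustrative and omits the custom
TestQwen2ForCausalLM registration and the compilation config used by the real
test.

import gc

import torch
from vllm import LLM, SamplingParams


def greedy_text(prompts: list[str], **llm_kwargs) -> str:
    # Build an engine, generate deterministically, then free GPU memory.
    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
              enforce_eager=True,
              **llm_kwargs)
    outputs = llm.generate(prompts, SamplingParams(temperature=0.0,
                                                   max_tokens=100))
    text = outputs[0].outputs[0].text
    del llm
    gc.collect()
    torch.cuda.empty_cache()
    return text


prompts = ["What is the capital of France?"]
ref_output = greedy_text(prompts)                            # baseline run
output = greedy_text(prompts, kv_sharing_skip_prefill=True)  # prefill skip on
assert output == ref_output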

vllm/config.py

Lines changed: 8 additions & 1 deletion
@@ -1528,6 +1528,10 @@ class CacheConfig:
     checkpoint if available. Otherwise, the scales will default to 1.0."""
     cpu_kvcache_space_bytes: Optional[int] = None
     """(CPU backend only) CPU key-value cache space."""
+    kv_sharing_skip_prefill: bool = False
+    """Skip prefill for tokens where applicable in KV cache sharing
+    scenarios where required key/value tensors have been populated
+    in earlier KV sharing target layers."""

     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
@@ -4066,7 +4070,10 @@ class CompilationConfig:
     - None (default): capture sizes are inferred from vllm config.
     - list[int]: capture sizes are specified as given."""
     cudagraph_share_memory_pool: bool = True
-    """Whether to share a single global memory pool for each graph capture"""
+    """Whether to share a single global memory pool for each graph capture
+    When CUDA graphs are not replayed in the same order they are captured,
+    e.g. when compiling multiple modules in a model and modules take different
+    input shapes, it is unsafe to share memory across graph captures."""
     cudagraph_copy_inputs: bool = False
     """Whether to copy input tensors for
     cudagraph. If the caller can guarantee that the same input buffers
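
For context on the expanded cudagraph_share_memory_pool docstring, here is a
minimal sketch, in plain PyTorch rather than vLLM internals, of two CUDA
graphs captured into one shared memory pool via torch.cuda.graph_pool_handle().
The docstring's caveat is that this sharing is only safe when the graphs are
replayed in the same order they were captured.

import torch

if torch.cuda.is_available():
    pool = torch.cuda.graph_pool_handle()   # one pool shared by both captures
    x1 = torch.zeros(8, device="cuda")
    x2 = torch.zeros(16, device="cuda")

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):
        y1 = x1 * 2
    with torch.cuda.graph(g2, pool=pool):
        y2 = x2 + 1

    # Replaying in capture order is the safe pattern; replaying out of order
    # with differing input shapes is the situation the docstring warns about.
    g1.replay()
    g2.replay()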

vllm/engine/arg_utils.py

Lines changed: 4 additions & 0 deletions
@@ -459,6 +459,7 @@ class EngineArgs:
     override_attention_dtype: str = ModelConfig.override_attention_dtype

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+    kv_sharing_skip_prefill: bool = CacheConfig.kv_sharing_skip_prefill

     additional_config: dict[str, Any] = \
         get_field(VllmConfig, "additional_config")
@@ -735,6 +736,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **cache_kwargs["cpu_offload_gb"])
         cache_group.add_argument("--calculate-kv-scales",
                                  **cache_kwargs["calculate_kv_scales"])
+        cache_group.add_argument("--kv-sharing-skip-prefill",
+                                 **cache_kwargs["kv_sharing_skip_prefill"])

         # Tokenizer arguments
         tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
@@ -1120,6 +1123,7 @@ def create_engine_config(
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
+            kv_sharing_skip_prefill=self.kv_sharing_skip_prefill,
         )

         # Get the current placement group if Ray is initialized and
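
With the CacheConfig field, the EngineArgs attribute, the
--kv-sharing-skip-prefill CLI argument, and the create_engine_config()
plumbing above, the flag flows from the user-facing layers down to the cache
config. A hedged sketch of that flow; the exact create_engine_config()
signature may differ between vLLM versions, so treat this as illustrative
only.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="Qwen/Qwen2-1.5B-Instruct",
                         kv_sharing_skip_prefill=True)
vllm_config = engine_args.create_engine_config()
# The value lands on the cache config, where the model runner reads it
# (see the gpu_model_runner.py change below).
assert vllm_config.cache_config.kv_sharing_skip_prefill is True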

vllm/entrypoints/llm.py

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,7 @@ def __init__(
         override_pooler_config: Optional[PoolerConfig] = None,
         compilation_config: Optional[Union[int, dict[str, Any],
                                            CompilationConfig]] = None,
+        kv_sharing_skip_prefill: bool = False,
         **kwargs,
     ) -> None:
         """LLM constructor."""
@@ -264,6 +265,7 @@ def __init__(
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
             compilation_config=compilation_config_instance,
+            kv_sharing_skip_prefill=kv_sharing_skip_prefill,
             **kwargs,
         )

vllm/envs.py

Lines changed: 0 additions & 3 deletions
@@ -138,7 +138,6 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
-    VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False


 def get_default_cache_root():
@@ -955,8 +954,6 @@ def get_vllm_port() -> Optional[int]:
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
     lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
-    "VLLM_V1_KV_SHARING_SKIP_PREFILL":
-    lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1",
 }

 # --8<-- [end:env-vars-definition]

vllm/v1/attention/backends/flash_attn.py

Lines changed: 10 additions & 6 deletions
@@ -256,6 +256,16 @@ def build(
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> FlashAttentionMetadata:
+        prefill_skipped_attn_metadata = None
+        if common_attn_metadata.decode_indices is not None:
+            # NOTE(sarckk): attention metadata for partial prefill skip case
+            # needs to be built first, otherwise the line below
+            # block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+            # will override the correct slot mapping
+            prefill_skipped_attn_metadata = self.build_skip_prefill(
+                common_prefix_len=0,  # disable cascade attention
+                common_attn_metadata=common_attn_metadata)
+
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -404,12 +414,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             # we only set num_splits when using cuda graphs.
             max_num_splits = self.max_num_splits

-        prefill_skipped_attn_metadata = None
-        if common_attn_metadata.decode_indices is not None:
-            prefill_skipped_attn_metadata = self.build_skip_prefill(
-                common_prefix_len=0,  # disable cascade attention
-                common_attn_metadata=common_attn_metadata)
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
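
The NOTE(sarckk) comment is the heart of this fix: the skip-prefill metadata
must be derived from block_table.slot_mapping before the main build path pads
the tail in place with slot_mapping[num_actual_tokens:].fill_(-1). Below is a
toy reproduction of the general ordering hazard with plain tensors; the names
are illustrative, not vLLM's actual buffers.

import torch

slot_mapping = torch.arange(8)      # stand-in for block_table.slot_mapping
num_actual_tokens = 5

# Correct order: capture whatever the skip-prefill metadata needs first...
snapshot = slot_mapping.clone()
# ...then pad the unused tail in place, as the main build path does.
slot_mapping[num_actual_tokens:].fill_(-1)

assert torch.equal(snapshot, torch.arange(8))                     # intact
assert torch.equal(slot_mapping[5:], torch.tensor([-1, -1, -1]))  # padded
# Taking the snapshot only after the fill_ would observe the padded values,
# which is the "override the correct slot mapping" problem the NOTE describes.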

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -580,6 +580,8 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
         """
         Pads logits_indices to align with CUDA graph capture sizes
         """
+        if not self.cache_config.kv_sharing_skip_prefill:
+            return None
         num_decodes = logits_indices.shape[0]
         # TODO(sarckk): With chunked prefills, logits_indices contains
         # indices for partial requests though we do not sample any token
@@ -599,8 +601,9 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], bool, torch.Tensor,
-               Optional[SpecDecodeMetadata], np.ndarray, torch.Tensor]:
+    ) -> tuple[dict[str,
+                    Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
+               np.ndarray, Optional[torch.Tensor]]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
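
_calc_decode_indices() now returns None unless
cache_config.kv_sharing_skip_prefill is set, which is why the last element of
_prepare_inputs()'s return type becomes Optional[torch.Tensor]. Below is a
hypothetical standalone rendering of that gate plus the padding behavior the
docstring names; the repeat-the-last-index padding strategy and the
capture_sizes parameter are assumptions for illustration, not vLLM's actual
implementation.

from typing import Optional

import torch


def calc_decode_indices(logits_indices: torch.Tensor,
                        kv_sharing_skip_prefill: bool,
                        capture_sizes: list[int]) -> Optional[torch.Tensor]:
    if not kv_sharing_skip_prefill:
        # Feature disabled: callers treat None as "no prefill skip".
        return None
    num_decodes = logits_indices.shape[0]
    # Pad up to the nearest CUDA graph capture size, if one is large enough.
    padded = next((s for s in sorted(capture_sizes) if s >= num_decodes),
                  num_decodes)
    pad = logits_indices[-1:].repeat(padded - num_decodes)
    return torch.cat([logits_indices, pad])


indices = torch.tensor([3, 7, 12])
print(calc_decode_indices(indices, False, [4, 8]))   # None
print(calc_decode_indices(indices, True, [4, 8]))    # tensor([ 3,  7, 12, 12])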
