Skip to content

[V1] Partial prefill skip for layers reusing shared KV cache #19719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

sarckk
Copy link
Collaborator

@sarckk sarckk commented Jun 17, 2025

Motivation

KV cache techniques like SwiftKV reduce computation required during prefill. This is harder to implement in V1 where the scheduler groups tokens for prefill and decode in the same batch. This PR adds instrumentation to support prefill compute savings in V1 in KV cache sharing setups where KV sharing is used such that certain tokens can be skipped during prefill (as KV target layers have already populated the necessary key/value tensors required for decoding).

Example

Let's say we have a 24 layer model where first 12 layers allocate their own KV caches and next 12 layers re-use the shared KV cache of its corresponding KV target layer. Then given input prompt sequence of N tokens, we can skip prefill for N-1 tokens for the last 12 layers, because the key/value tensors used for decoding is already populated in the KV caches of the first 12 layers. Because vLLM v1 scheduler does not distinguish prefill/decode and employs continuous batching, we can instead perform forward on the last 12 layers with a reduced input size.

For example, if we have request 0 and request 1 with 4 prompt tokens each, then we might have tokens batched as such:

<----r0---> <----r1---->
[0, 1, 2, 3, 4, 5, 6, 7]

For the first 12 self-attention layers, we can do forward with the full input [0, 1, 2, 3, 4, 5, 6, 7], while for the last 12 cross-attention layers, we can do forward with the last token for each request [3,7], as these are the only positions where valid logits are required to sample output tokens from.

Frontend changes

This PR adds a new --kv-sharing-skip-prefill arg which is added to the CacheConfig. This causes FlashAttention backend to compute an extra set of metadata assuming prefill skip, but changes are still required on model side to take advantage of this.

Attention metadata

Attention metadata needs to be changed to account for the different query offsets and max lengths in the shared KV layers for which N-1 tokens are skipped during prefill.

Correctness Test

Unit test show outputs are roughly equivalent with and without this optimization (exact numerics will differ as batched mm op will yield slightly different results depending on batch size)

pytest tests/v1/e2e/test_kv_sharing_truncated_prefill.py::test_kv_sharing_truncated_prefill

Perf comparison

Set up: single batch and input length of 8192. Using compile+piecewise cuda graph

TestQwen2ForCausalLM model forward trace with optimization (enable_kv_sharing_truncated_prefill=True)

second layer group takes 9.7ms

Screenshot 2025-07-02 at 21 00 10

Trace without optimization (enable_kv_sharing_truncated_prefill=False)

second layer group takes 16.6ms

Screenshot 2025-07-02 at 20 58 20

Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 17, 2025
@houseroad houseroad requested a review from heheda12345 June 17, 2025 01:18
@sarckk sarckk requested a review from LucasWilkinson June 17, 2025 02:11
@heheda12345
Copy link
Collaborator

My concern is whether this optimization is too model specific. It works for models that the first k layers have kv cache. Does it work for models that every m layers share the same kv cache like Hunyuan?

@sarckk
Copy link
Collaborator Author

sarckk commented Jun 17, 2025

My concern is whether this optimization is too model specific. It works for models that the first k layers have kv cache. Does it work for models that every m layers share the same kv cache like Hunyuan?

It only works for the case where the first k layers have kv cache as you said. For general KV sharing cases, it should also apply for last N layers that reuse the KV cache (ie there are no other layers afterwards that have its own KV cache). So I agree it will not apply to a majority of models, but then I'm not sure if there is a better way to implement this kind of functionality.

Copy link
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a quick pass on this PR.

And I'm curious about your plan to support piecewise cuda graph. We need cuda graph for num_total_tokens in the first few layers, and num_decode_tokens in the following layers.

vllm/envs.py Outdated
@@ -128,6 +128,7 @@
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to add it as a cli arg.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -602,6 +620,11 @@ def forward(
# Profiling run.
return output

if (self.kv_sharing_target_layer_name is not None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch is not true for hunyuan-style kv sharing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added logic to detect which layers are 'eligible' for this prefill skip optimization

Copy link

mergify bot commented Jun 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added needs-rebase qwen Related to Qwen models labels Jun 18, 2025
Copy link
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would really like to try to keep build signature of the metadata builders as simple as possible so hopefully we can create some nice unit testing infrastructure in the future. Do we really need to add decode_only_common_attn_metadata to the build call signature? can we make the kv sharing layers a different KVSpec and have separate build calls at this level:

for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
builder,
)
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))

we should probably be doing this for local attention too but that was added before we had the hybrid-KV cache (which enabled different build calls for different layer groups). We should probably migrate local attention to a scheme like this too

self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
decode_only_common_attn_metadata: Optional[
Copy link
Collaborator

@LucasWilkinson LucasWilkinson Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason we need to pass decode_only_common_attn_metadata as a separate arg; is there a reason we can't just use a different build call at the gpu model runner level? i.e. here-ish:

for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
builder,
)
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea I initially had a separate build() call at the model runner level, but I needed to set this as a property of attention metadata for all different backends, and they don't share a common schema. So I thought I could pass the info and let each backend decide what to do with it.

But I do agree that your approach is a better abstraction, will follow up on that

decode_only_common_attn_metadata = None
if envs.VLLM_V1_KV_SHARING_SKIP_PREFILL:
decode_only_common_attn_metadata = (
compute_decode_only_common_attn_metadata(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we move this logic into metadata builder?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved this logic to flash attn metadata builder

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I think I missed this so not sure what the code looked like at this point but I think ideally we would keep this common metadata manipulation outside of the metadata builders so we can naturally just support all the backends (assuming we can keep a clean build interface). This is important for blackwell where FlashInfer has the best perf. I actually want to do something similar for local-attention since that could also be done via pure CommonAttentionMetadata manipulation and would enable iRoPe for FlashInfer.

see: #19719 (comment)

@sarckk sarckk force-pushed the decode-only-attn branch 2 times, most recently from 541f2a5 to a9783c3 Compare July 2, 2025 23:28
@sarckk sarckk changed the title [V1] Perf optimization for layers reusing shared KV cache [V1] Perf optimization for early exit inference Jul 3, 2025
@sarckk sarckk changed the title [V1] Perf optimization for early exit inference [V1] Partial prefill skip for layers reusing shared KV cache Jul 3, 2025
@mergify mergify bot added the frontend label Jul 3, 2025
@sarckk sarckk force-pushed the decode-only-attn branch from 587c1d6 to 226edcf Compare July 3, 2025 17:13
@sarckk sarckk marked this pull request as ready for review July 3, 2025 17:13
@simon-mo
Copy link
Collaborator

simon-mo commented Jul 3, 2025

@LucasWilkinson @heheda12345 ready for a review!

@@ -196,16 +198,83 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None

def build_skip_prefill(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like this could be a common utility function that operates on CommonAttentionMetadata generically instead of inside FlashAttentionMetadataBuilder; so we could make it generic for all attention backends

num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len

max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_np = common_attn_metadata.query_start_loc_np
if query_start_loc_np is None:
Copy link
Collaborator

@LucasWilkinson LucasWilkinson Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we make this non-optional? not necessarily needed for this PR but id eventually like to break the dependency on the runner in the metadata builders

@@ -43,6 +44,12 @@ class CommonAttentionMetadata:
max_query_len: int
"""Longest query in batch"""

decode_indices: Optional[torch.Tensor] = None
"""indices used for decoding"""
Copy link
Collaborator

@LucasWilkinson LucasWilkinson Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Id prefer if we could keep CommonAttentionMetadata cleaner; i.e. not have anything that is overly specific to given attention scheme/model. It will make attention unit testing and benchmarking a bit easier to setup

If it's possible I think id prefer for these layers to be part of different KVCacheSpec group; then we might be able to handle the prefill filtering at that level. i.e. we could do something like

        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            ....
            builder = self.attn_metadata_builders[kv_cache_group_id]
            
            common_attn_metadata_ = common_attn_metadata
            if instance(kv_cache_group_spec, SharedKVSpec):
                    common_attn_metadata_ = filter_prefills(common_attn_metadata)

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata_,
            ))

Where filter_prefills is basically build_skip_prefill

Something like this (maybe the KVCacheSpec is the right thing here and we should have an additional way to group attention layers but I think its reasonable) would allow us to be backend agnostic which will be important for Blackwell were we want to eventually default to FlashInfer

@heheda12345 do you have opinions on this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Then does this kv cache grouping strategy make sense @sarckk @LucasWilkinson

  1. Only send layers that has its own kv cache to get_kv_cache_config so that kv_cache_manager doesn't need to be aware of any kv sharing logic (Achieved by [V1] Support cross-layer KV sharing #18212)
  2. On worker side, add layers with kv sharing but cannot enable this partial prefill skip optimization to the group that it shares kv with (Achieved by [V1] Support cross-layer KV sharing #18212) , as these layers should use the attention metadata without prefill skip.
  3. Create new kv cache groups for layers that enables this optimization as Lucas mentioned. Layers share kv with different kv cache groups should be put into different new groups.

In that case, this spec shouldn't be called SharedKVSpec but some thing more concrete for case 3.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes sense

Copy link
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great job! Added some comments.

Comment on lines 240 to 248
self.residual[:num_decodes].copy_(first_residual[decode_indices])
self.hidden_states[:num_decodes].copy_(
first_hidden_states[decode_indices])
positions[:num_decodes].copy_(positions[decode_indices])

second_hidden_states, second_residual = self.second_layer_group(
positions[:num_decodes],
self.hidden_states[:num_decodes],
self.residual[:num_decodes],
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that we should avoid reimplementing these logic for each model.
As #18212 lacks an example, I believe people will refer to this test to implement new kv-sharing models. Therefore, if you want to left it to a future PR, can you:

  1. Add another model that uses kv-sharing logic but don't enable the optimization in this PR, to serve as an example that you suggest people to add a new model at this moment. It should also be a useful test.
  2. Add notes for it is WIP and don't copy the code here to a new model.

# NOTE(sarckk): Due to cudagraph padding, decode_indices may have
# trailing repeated indices. Attention output is only valid at the
# last index in this case.
last_index_mask = decode_indices == decode_indices[-1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it valid at the last index? I think the output of the last batch is second_hidden_states[num_decode - 1].

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm doing this because with CUDA graph capture, I'm padding num_decode so they are aligned with the graph captures sizes here. For example, if the original decode_indices was [0,19,22] then we might pad it to [0,19,22,22] where the last index is repeated for padding.

During attention, the index [22] would originally have any attention applied without any padding, e.g. for query_len of 1 and kv_len of 3, the masking (True is the masked positions) might look like:

[False, False, False]

but with padded query [22,22], the causal mask would look like:

[[False, False, True],
 [False, False, False]]

so only the last position gives the correct output. I also have a simple example in https://gist.github.com/sarckk/ffd59338994ca6f3863f0119aa09784d

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks to piece-wise cuda graph, I think we only need to build attention metadata as if there is no padding as the attention part is executed in eager mode.

@@ -96,6 +96,9 @@ class ForwardContext:
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False

decode_indices: Optional[torch.Tensor] = None
"""indices used for decoding"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you find a better name? I think you want to include "the last prefill token" + "all decode tokens" in this tensor. And is it possible to hide it in the attention metadata for kv sharing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe generation_indices? although the value is equal to logits_indices right now, I wanted to differentiate it from logits_indices as generation_indices eventually would not contain the last token for partial prefill chunks while logits_indices does.

what do you mean by hiding it in attention metadata for kv sharing?

) -> FlashAttentionMetadata:
prefill_skipped_attn_metadata = None
if common_attn_metadata.decode_indices is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if common_attn_metadata.decode_indices is not None:
if self.cache_config.kv_sharing_skip_prefill:

Prefer this straight-forward condition.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need the original condition because common_attn_metadata.decode_indices may still be None even if kv_sharing_skip_prefill is set, if all requests are on decode

@@ -435,6 +505,8 @@ def __init__(
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")

self.kv_sharing_skip_prefill = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note on it may be updated to True by gpu model runner. And can we remove this after putting these layers to a different kv cache group?

@@ -43,6 +44,12 @@ class CommonAttentionMetadata:
max_query_len: int
"""Longest query in batch"""

decode_indices: Optional[torch.Tensor] = None
"""indices used for decoding"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Then does this kv cache grouping strategy make sense @sarckk @LucasWilkinson

  1. Only send layers that has its own kv cache to get_kv_cache_config so that kv_cache_manager doesn't need to be aware of any kv sharing logic (Achieved by [V1] Support cross-layer KV sharing #18212)
  2. On worker side, add layers with kv sharing but cannot enable this partial prefill skip optimization to the group that it shares kv with (Achieved by [V1] Support cross-layer KV sharing #18212) , as these layers should use the attention metadata without prefill skip.
  3. Create new kv cache groups for layers that enables this optimization as Lucas mentioned. Layers share kv with different kv cache groups should be put into different new groups.

In that case, this spec shouldn't be called SharedKVSpec but some thing more concrete for case 3.

Copy link

mergify bot commented Jul 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 5, 2025
@heheda12345
Copy link
Collaborator

May be unrelated to this PR. We also need an elegant way to skip preparing kv for layers that don't need them.

qkv, _ = self.qkv_proj(hidden_states)

q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)

@heheda12345
Copy link
Collaborator

@sarckk Here is a PR for v0 YOCO optimization. #20702 Though it is simplified due to ignoring chunked prefill and cuda graph, you can take a look and check whether there are anything you can learn.
The key logic in that PR:

def forward_generate_kv_cache(
self, query: torch.Tensor, key: Optional[torch.Tensor],
value: Optional[torch.Tensor], k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:
head_size = self.head_size
num_heads = self.num_heads // 2
num_kv_heads = self.num_kv_heads // 2
query = query.view(-1, num_heads, head_size)
if key is not None:
assert value is not None
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
else:
assert value is None
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[
0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch"
assert value.shape[
0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch"
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
if key is not None and value is not None:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
assert decode_query.shape[
0] == num_decode_tokens, "decode query shape mismatch"
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if k_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0:
# normal attention
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
assert prefill_output.shape == output[:
num_prefill_tokens].shape
output[:num_prefill_tokens] = prefill_output
else:
raise Exception("prefix caching not supported")
if decode_meta := attn_metadata.decode_metadata:
block_tables_arg = decode_meta.block_tables
try:
output[num_prefill_tokens:] = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=k_cache,
v_cache=v_cache,
block_table=block_tables_arg,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
except Exception as e:
logger.error("Error in PagedAttention.forward_decode: %s",
str(e))
raise e
# Reshape the output tensor.
return output.view(-1, num_heads, head_size)
def forward_with_kv_cache_only(
self,
query: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata,
):
if not attn_metadata.decode_metadata:
block_tables_arg = attn_metadata.cross_layer_shared_block_tables
else:
block_tables_arg = attn_metadata.block_tables
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=k_cache,
v_cache=v_cache,
block_table=block_tables_arg,
cache_seqlens=attn_metadata.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
return output
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
NOTE: It in-place updates the output tensor.
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]
lambda_q1 = self.differential_flash_attention_config["lambda_q1"]
lambda_k1 = self.differential_flash_attention_config["lambda_k1"]
lambda_q2 = self.differential_flash_attention_config["lambda_q2"]
lambda_k2 = self.differential_flash_attention_config["lambda_k2"]
lambda_1 = torch.exp(
torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q)
lambda_2 = torch.exp(
torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q)
self.lambda_full = lambda_1 - lambda_2 + self.lambda_init
if not self.used_shared_kv_cache: # need to generate kv-cache
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
q1, q2 = self.split_heads(q)
k1, k2 = self.split_heads(k)
v1, v2 = self.split_heads(v)
# kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501
# Split by half along the first dimension.
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"
if kv_cache1.numel() != 0:
self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata)
self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata)
key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
else:
key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
key_cache2, value_cache2 = torch.empty(0), torch.empty(0)
attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1,
value_cache1,
attn_metadata)
attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1,
value_cache2,
attn_metadata)
attn11 = attn11.view(q1.shape)
attn12 = attn12.view(q1.shape)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2,
value_cache1,
attn_metadata)
attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2,
value_cache2,
attn_metadata)
attn21 = attn21.view(q2.shape)
attn22 = attn22.view(q2.shape)
attn2 = torch.cat([attn21, attn22], dim=-1)
attn = attn1 - self.lambda_full * attn2
# attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# reshape back to 2 * num_head
attn_output = rearrange(attn,
"... H (two D) -> ... (H two) D",
two=2)
else: # re-use the kv cache, full attention
q = q.view(-1, self.num_heads, self.head_size)
q1, q2 = self.split_heads(q)
# kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
attn11 = self.forward_with_kv_cache_only(q1, key_cache1,
value_cache1,
attn_metadata)
attn12 = self.forward_with_kv_cache_only(q1, key_cache1,
value_cache2,
attn_metadata)
attn11 = attn11.view(q1.shape)
attn12 = attn12.view(q1.shape)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = self.forward_with_kv_cache_only(q2, key_cache2,
value_cache1,
attn_metadata)
attn22 = self.forward_with_kv_cache_only(q2, key_cache2,
value_cache2,
attn_metadata)
attn21 = attn21.view(q2.shape)
attn22 = attn22.view(q2.shape)
attn2 = torch.cat([attn21, attn22], dim=-1)
attn = attn1 - self.lambda_full * attn2
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# reshape back to 2 * num_head
attn_output = rearrange(attn,
"... H (two D) -> ... (H two) D",
two=2)
attn_output = attn_output.view(-1, self.num_heads * self.head_size)
return attn_output

# Starting from this layer, we do not need to calculate
# the kv cache since we reuse the kv cache from last layer.
# If in prefill phase, we can <s>prune></s> truncate
# the hidden state to save computation cost.
if attn_metadata.prefill_metadata:
selected_token_indices = torch.cumsum(
attn_metadata.seq_lens_tensor, dim=0) - 1
hidden_states = hidden_states.index_select(
0, selected_token_indices)
ssm_output = ssm_output.index_select(
0, selected_token_indices)

sarckk added 10 commits July 11, 2025 13:50
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
@sarckk sarckk force-pushed the decode-only-attn branch from e171dd5 to 3cd2474 Compare July 11, 2025 21:18
@mergify mergify bot removed the needs-rebase label Jul 11, 2025
sarckk added 3 commits July 11, 2025 14:52
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
@mergify mergify bot added the tpu Related to Google TPUs label Jul 14, 2025
Copy link

mergify bot commented Jul 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend needs-rebase qwen Related to Qwen models tpu Related to Google TPUs v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants