Support embedding models in V1 #16188

Merged · 98 commits · Jun 19, 2025
Changes from 7 commits
98 commits
f36c4f9
Remove guardrails that prevent V1 from trying to run embedding models
maxdebayser Mar 24, 2025
acf4638
hack v1 flash_attn to support encoder_only
maxdebayser Apr 3, 2025
b13bbc0
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 3, 2025
8debea0
Revert changes to disable kv caching for encoder-only models
maxdebayser Apr 3, 2025
8d97b9c
Add pooling support in v1
maxdebayser Apr 5, 2025
d60b22b
First end-to-end working version of Bert embeddings in V1
maxdebayser Apr 7, 2025
6bebbb8
Support warmup for pooling models in V1
maxdebayser Apr 7, 2025
6dafd71
address review comments
maxdebayser Apr 7, 2025
e2724a2
address review comments
maxdebayser Apr 7, 2025
56ff6cd
remove debug prints
maxdebayser Apr 7, 2025
fc57edd
address review comments
maxdebayser Apr 7, 2025
64a0e62
Fix cross encoder models in V1 and enable tests for pooling models
maxdebayser Apr 8, 2025
4014d41
address review comments
maxdebayser Apr 8, 2025
87a95a8
Merge branch 'main' into v1_embeddings
maxdebayser Apr 8, 2025
902c129
address review comments
maxdebayser Apr 8, 2025
2c68855
re-enable large embedding models
maxdebayser Apr 8, 2025
8afd8f5
address review comments
maxdebayser Apr 8, 2025
7762976
Merge branch 'main' into v1_embeddings
maxdebayser Apr 8, 2025
d7537ae
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 8, 2025
a9e7747
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 9, 2025
17520bd
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 14, 2025
90c611a
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 15, 2025
dec2441
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 17, 2025
a5e83f4
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 23, 2025
187f69b
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 24, 2025
69a0332
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 29, 2025
a9f1721
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 29, 2025
4b066a3
fix merge problems
maxdebayser Apr 30, 2025
43a26dc
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 30, 2025
ca34513
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 30, 2025
bf3033d
Fix missing qwen embedding model param
maxdebayser Apr 30, 2025
67bf727
Make pooling params reach the pooling in V1
maxdebayser May 1, 2025
93b6361
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 1, 2025
d916b88
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 10, 2025
bad4211
fix merge problems
maxdebayser May 10, 2025
35d9bd9
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 11, 2025
dcc6100
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 12, 2025
a4f85b5
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 13, 2025
a5f328a
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 15, 2025
7c5be88
fix merge problem
maxdebayser May 15, 2025
29b75c9
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 4, 2025
6aa204c
backport changes from the other PR
maxdebayser Jun 4, 2025
e81470c
fix merge errors
maxdebayser Jun 4, 2025
20e7140
address review comments
maxdebayser Jun 4, 2025
6bc1e3d
address review comments
maxdebayser Jun 4, 2025
22825bd
simplify PR
maxdebayser Jun 4, 2025
c889b2e
fix mistake
maxdebayser Jun 4, 2025
24462e4
workaround qwen model test issue
maxdebayser Jun 6, 2025
b5f21f2
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 6, 2025
79d1b95
revert unecessary change
maxdebayser Jun 6, 2025
b3a0491
remove duplicated code
maxdebayser Jun 6, 2025
b4ab556
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 6, 2025
1a82e56
remove encoder model support to simplify PR
maxdebayser Jun 7, 2025
a66801b
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 9, 2025
660dd9c
fix several tests
maxdebayser Jun 9, 2025
808c996
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 9, 2025
cdd70c9
Fix test
maxdebayser Jun 9, 2025
0832115
disable bert test
maxdebayser Jun 9, 2025
10bbf74
fix tests
maxdebayser Jun 9, 2025
ee892aa
limit context length to fit test GPU
maxdebayser Jun 9, 2025
2e12eba
limit context length to fit test GPU
maxdebayser Jun 9, 2025
14fcf24
fix test
maxdebayser Jun 10, 2025
0624435
fix test
maxdebayser Jun 10, 2025
706fdb2
Merge branch 'main' into v1_embeddings
22quinn Jun 10, 2025
051f6d4
Fix _construct_cached_request_state
22quinn Jun 10, 2025
214cf06
Fix v1 tests
22quinn Jun 10, 2025
8193bd0
Merge pull request #1 from 22quinn/v1_embeddings
maxdebayser Jun 10, 2025
65b8377
fix test
maxdebayser Jun 10, 2025
33d7f74
Merge branch 'v1_embeddings' of github.com:maxdebayser/vllm into v1_e…
maxdebayser Jun 10, 2025
4ee822a
reduce max_model_len to fit in test gpu
maxdebayser Jun 10, 2025
7242731
fix test
maxdebayser Jun 10, 2025
a4f460b
fix test
maxdebayser Jun 10, 2025
35ca640
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 12, 2025
17f6177
fix test
maxdebayser Jun 12, 2025
3f0d42e
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 12, 2025
74d73cc
use torch.split
maxdebayser Jun 12, 2025
e6a66dc
enable cuda graphs
maxdebayser Jun 12, 2025
4cca774
fix unecessary config.py changes
maxdebayser Jun 12, 2025
8ef1982
fix error message
maxdebayser Jun 12, 2025
28d00d1
remove unused import
maxdebayser Jun 12, 2025
e634f60
fix docstring
maxdebayser Jun 12, 2025
053475c
revert unnecessary code changes
maxdebayser Jun 12, 2025
6228f64
remove debug prints
maxdebayser Jun 12, 2025
42c802a
fix refactoring bug
maxdebayser Jun 12, 2025
f771a19
fix refactoring bug
maxdebayser Jun 12, 2025
02c47ad
Fix default chunked prefill for pooling models
maxdebayser Jun 13, 2025
1fd252c
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 13, 2025
c5c0d97
Revert handling of case that can never happen
maxdebayser Jun 13, 2025
acfc9cc
fix small bug
maxdebayser Jun 13, 2025
225b808
fix small bugs
maxdebayser Jun 13, 2025
2b86c13
fix silly mistake
maxdebayser Jun 13, 2025
2983252
reduce memory usage for small ci gpus
maxdebayser Jun 13, 2025
58c556d
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 13, 2025
878d56a
enable chunked prefill by default for models that support it
maxdebayser Jun 14, 2025
2db273f
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 14, 2025
114af27
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 16, 2025
bc0219d
address review comments
maxdebayser Jun 16, 2025
221f013
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 19, 2025
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -492,6 +492,9 @@ def _init_pooler_config(
) -> Optional["PoolerConfig"]:

if self.runner_type == "pooling":
logger.warning("CUDA graph is not supported for pooling yet, "
"fallback to the eager mode.")
self.enforce_eager = True
Collaborator: Why is it not supported?

Contributor (author): Some models were returning incorrect output after compiling. I left this as follow-up work.
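
For context, a minimal usage sketch of what this fallback means in practice (the model name and environment setup below are illustrative, not taken from this PR): with the V1 engine a pooling model now runs, but it is silently forced into eager mode.

```python
# Illustrative sketch only: running an embedding model on the V1 engine.
# Because of the fallback above, the engine behaves as if enforce_eager=True
# had been passed for pooling models; CUDA graph support is left as follow-up.
import os

os.environ["VLLM_USE_V1"] = "1"  # typically exported in the shell instead

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # assumed example model
outputs = llm.embed(["vLLM V1 now supports pooling models."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```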

user_config = override_pooler_config or PoolerConfig()

base_config = get_pooling_config(self.model, self.revision)
6 changes: 0 additions & 6 deletions vllm/engine/arg_utils.py
@@ -1439,12 +1439,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
recommend_to_remove=False)
return False
Comment on lines -1352 to -1356
Member: Did you mean to enable all tasks here?

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward", "transcription"]

Contributor (author): I'll double-check the "transcription" task, but the others, yes. Is this causing a problem?

Member: Nope, it just caused a conflict in my branch where I had enabled "transcription", and I thought maybe it had enabled more than you intended. It's fine if you meant to!

I'm not sure about transcription, either. I know it wouldn't work with Whisper, but that will still get blocked because the model is marked as V0-only. Since all models should have the V0-only marker where needed, this check probably isn't necessary.
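
As context for the V0-only marker mentioned above, here is a minimal sketch of how a model opts out of V1 (the class name is hypothetical; SupportsV0Only is the same marker interface removed from bert.py later in this diff):

```python
# Hypothetical model class, shown only to illustrate the marker pattern:
# a model that must stay on the V0 engine inherits SupportsV0Only, and the
# V1 oracle falls back for it instead of relying on a per-task allowlist.
from torch import nn

from vllm.model_executor.models.interfaces import SupportsV0Only


class MyV0OnlyModel(nn.Module, SupportsV0Only):
    """Hypothetical V0-only model; dropping the marker opts it into V1."""
```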


# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
45 changes: 37 additions & 8 deletions vllm/model_executor/layers/pooler.py
@@ -10,11 +10,15 @@
from typing_extensions import assert_never

from vllm.config import PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]


class PoolingType(IntEnum):
@@ -78,6 +82,8 @@ def get_prompt_lens(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens

@@ -181,12 +187,27 @@ def __init__(
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids

def get_prompt_token_ids(
self,
pooling_metadata: PoolingMetadata,
) -> List[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
seq_data_i.prompt_token_ids
for seq_data_i in pooling_metadata.seq_data.values()
]

def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

returned_token_ids = self.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
@@ -196,12 +217,11 @@ def extract_states(

offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len, seq_data_i in zip(prompt_lens,
pooling_metadata.seq_data.values()):
for i, prompt_len in enumerate(prompt_lens):
pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
pooled_data_i = pooled_data_i[prompt_token_ids[i] ==
step_tag_id]

offset += prompt_len
pooled_data.append(pooled_data_i)
@@ -287,15 +307,24 @@ def __init__(
self.default_activation_function = \
get_cross_encoder_activation_function(config)

def get_prompt_lens(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens

def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""

prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

offset = 0
pooled_data_lst = []
35 changes: 22 additions & 13 deletions vllm/model_executor/models/bert.py
@@ -18,6 +18,7 @@
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -26,7 +27,7 @@
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)

from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import SupportsCrossEncoding
from .utils import WeightsMapper, maybe_prefix


@@ -115,7 +116,7 @@ def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
for layer in self.layer:
for i, layer in enumerate(self.layer):
hidden_states = layer(hidden_states)
return hidden_states

@@ -323,6 +324,7 @@ def __init__(self,
add_pooling_layer: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
@@ -340,12 +342,16 @@ def forward(
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
seq_lens = None
if attn_metadata is not None: # Can be None during warmup
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
attn_metadata.seq_lens)
assert seq_lens is not None
hidden_states = self.embeddings(input_ids=input_ids,
seq_lens=seq_lens,
position_ids=position_ids,
token_type_ids=token_type_ids)

return self.encoder(hidden_states)

def load_weights(self, weights: Iterable[Tuple[str,
Expand Down Expand Up @@ -385,7 +391,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params


class BertEmbeddingModel(nn.Module, SupportsV0Only):
class BertEmbeddingModel(nn.Module):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
@@ -403,6 +409,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)
# TODO: Remove test scaffolding after pooling is implemented
self.sampler = get_sampler()

def forward(
self,
@@ -411,10 +419,11 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
hidden_states = self.model(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
return hidden_states

def pooler(
self,
52 changes: 27 additions & 25 deletions vllm/model_executor/models/roberta.py
@@ -80,31 +80,33 @@ def forward(
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = []
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len

new_pos_list = []
for positions, tokens in zip(pos_list, token_list):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)
if seq_lens is not None: # Can be None during warmup
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = []
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len

new_pos_list = []
for positions, tokens in zip(pos_list, token_list):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens,
self.padding_idx))
position_ids = torch.cat(new_pos_list)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
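
For readers unfamiliar with the convention referenced in the comment above, here is a minimal self-contained sketch of the RoBERTa position-id scheme (mirroring Hugging Face's create_position_ids_from_input_ids; the example token ids and padding_idx=1 are illustrative):

```python
# Sketch: non-padding tokens get consecutive positions starting at
# padding_idx + 1, while padding tokens keep position padding_idx.
import torch


def roberta_position_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    mask = input_ids.ne(padding_idx).int()
    incremental = torch.cumsum(mask, dim=-1) * mask
    return (incremental + padding_idx).long()


tokens = torch.tensor([0, 31414, 232, 2, 1, 1])  # <s> ... </s> <pad> <pad>
print(roberta_position_ids(tokens, padding_idx=1))  # tensor([2, 3, 4, 5, 1, 1])
```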
9 changes: 7 additions & 2 deletions vllm/outputs.py
@@ -344,10 +344,11 @@ class PoolingRequestOutput(Generic[_O]):
finished (bool): A flag indicating whether the pooling is completed.
"""

def __init__(self, request_id: str, outputs: _O,
def __init__(self, request_id: str, outputs: _O, prompt: Optional[str],
prompt_token_ids: list[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
self.prompt = prompt
self.finished = finished
self.outputs = outputs

@@ -359,9 +360,10 @@ def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
data = pooled_data.to(dtype=torch.float32, device="cpu")
output = PoolingOutput(data)
prompt_token_ids = seq_group.prompt_token_ids
prompt = seq_group.prompt
finished = seq_group.is_finished()

return PoolingRequestOutput(seq_group.request_id, output,
return PoolingRequestOutput(seq_group.request_id, output, prompt,
prompt_token_ids, finished)

def __repr__(self):
@@ -426,6 +428,7 @@ def from_base(request_output: PoolingRequestOutput):
return EmbeddingRequestOutput(
request_id=request_output.request_id,
outputs=EmbeddingOutput.from_base(request_output.outputs),
prompt=request_output.prompt,
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@@ -464,6 +467,7 @@ def from_base(request_output: PoolingRequestOutput):
return ClassificationRequestOutput(
request_id=request_output.request_id,
outputs=ClassificationOutput.from_base(request_output.outputs),
prompt=request_output.prompt,
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@@ -503,6 +507,7 @@ def from_base(request_output: PoolingRequestOutput):
return ScoringRequestOutput(
request_id=request_output.request_id,
outputs=ScoringOutput.from_base(request_output.outputs),
prompt=request_output.prompt,
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
36 changes: 29 additions & 7 deletions vllm/v1/attention/backends/flash_attn.py
@@ -198,11 +198,13 @@ def __init__(
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
@@ -265,8 +267,7 @@ def forward(
layer._k_scale,
layer._v_scale,
)
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])

if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
@@ -280,6 +281,9 @@
# Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade:
# Regular attention (common case).

descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
@@ -290,7 +294,7 @@
seqused_k=attn_metadata.seq_lens,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
causal=_get_causal_option(self.attn_type),
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
@@ -483,3 +487,21 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)


def _get_causal_option(attn_type: str) -> bool:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.

Args:
attn_type (AttentionType): The type of attention being evaluated

Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
3 changes: 2 additions & 1 deletion vllm/v1/core/kv_cache_manager.py
@@ -126,7 +126,8 @@ def get_computed_blocks(
self.req_to_block_hashes[request.request_id] = block_hashes

self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None:
if request.sampling_params and \
request.sampling_params.prompt_logprobs is None:
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This