
[Core][Model] PrithviMAE Enablement on vLLM v1 engine (with zero kv_cache_groups) #20577


Open · christian-pinto wants to merge 11 commits into main from prithvi_v1_embeddings_zero_kv_cache_group

Conversation

@christian-pinto (Contributor) commented on Jul 7, 2025

This PR is intended for comparison with #20072, where support for attention-free models was implemented by overloading the KVCacheManager.

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
@christian-pinto changed the title to [Core][Model] PrithviMAE Enablement on vLLM v1 engine (with zero kv_cache_groups) on Jul 7, 2025
The mergify bot added the documentation, multi-modality (#4194), and v1 labels on Jul 7, 2025.
@gemini-code-assist (bot) left a comment
Summary of Changes

Hello @christian-pinto, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances vLLM's capability to support a broader range of models, specifically focusing on attention-free architectures like PrithviMAE. It introduces a streamlined approach for handling models that do not require a KV cache and enables direct processing of raw multimodal inputs, moving away from previous workarounds and integrating these functionalities more natively into the vLLM v1 engine.

Highlights

  • Attention-Free Model Support: Introduced comprehensive support for attention-free models, such as PrithviMAE, by allowing the vLLM engine to operate with zero KV cache groups. This involves modifications across the KV cache management, scheduler, and engine initialization to bypass KV cache allocation and related operations for these specific models.
  • Raw Multimodal Input Handling: Implemented a new mechanism for multimodal models to directly process raw input data (e.g., pixel values) without requiring tokenization or embedding lookups. This includes new interfaces (SupportsMultiModalWithRawInput), configuration flags, and adjustments to the model runner to pass raw data directly to the model's forward pass (a minimal illustrative sketch follows this list).
  • PrithviMAE Integration: Integrated the PrithviMAE model as a concrete example of an attention-free, raw multimodal input model. This involved updating its class inheritance, adjusting its multimodal data processing, and adding a dedicated test case to validate its functionality within the vLLM framework.
  • Conditional Tokenizer Initialization: Modified the engine and processor to conditionally initialize and use the tokenizer, allowing models that do not rely on text tokenization (like PrithviMAE) to skip this step, improving efficiency and flexibility.
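For intuition on the raw-input path mentioned above, here is a minimal sketch of what a marker mixin and the corresponding model-runner check could look like. This is illustrative only: the attribute name, the helper function, and the mixin body are assumptions, not the PR's actual definition of SupportsMultiModalWithRawInput.

```python
from typing import ClassVar


class SupportsMultiModalWithRawInput:
    """Marker mixin for models that consume raw multimodal tensors (e.g. pixel
    values) directly in forward(), skipping tokenization and embedding lookup."""
    supports_multimodal_raw_input: ClassVar[bool] = True


def wants_raw_mm_input(model: object) -> bool:
    # The model runner could branch on the marker roughly like this.
    return getattr(model, "supports_multimodal_raw_input", False)
```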

mergify bot commented on Jul 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @christian-pinto.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist (bot) left a comment

Code Review

The pull request introduces changes to enable PrithviMAE on the vLLM v1 engine with zero kv_cache_groups. The changes include modifications to several files, including adding a new test file, updating the configuration, modifying the model executor, and updating the scheduler. The code changes appear to be well-structured and address the objective of the pull request. There are a few suggestions to improve the code's robustness and readability.

Comment on lines +146 to +147
dtype="float16",
enforce_eager=True,
Severity: medium

Consider using torch.float16 instead of the string literal for specifying the dtype. This approach is more type-safe and avoids potential errors due to typos or inconsistencies.

Suggested change:
-        dtype="float16",
-        enforce_eager=True,
+        dtype=torch.float16,
+        enforce_eager=True,

@@ -1120,10 +1125,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
return self.get_hf_config_sliding_window()

def get_vocab_size(self) -> int:
Severity: medium

Consider providing a default value when using getattr to avoid potential AttributeError if hf_text_config does not have the vocab_size attribute. This makes the code more robust.

Suggested change:
    def get_vocab_size(self) -> int:
        return getattr(self.hf_text_config, "vocab_size", 0)

Comment on lines 1130 to +1131
     def get_hidden_size(self) -> int:
-        return self.hf_text_config.hidden_size
+        return getattr(self.hf_text_config, "hidden_size", 0)
Severity: medium

Consider providing a default value when using getattr to avoid potential AttributeError if hf_text_config does not have the hidden_size attribute. This makes the code more robust.

Suggested change:
    def get_hidden_size(self) -> int:
        return getattr(self.hf_text_config, "hidden_size", 0)

@@ -169,7 +184,7 @@ def _parse_and_validate_multimodal_data(
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+        # pixel_values = torch.unbind(pixel_values, dim=0)[0]
Severity: medium

This line is commented out. If it's no longer needed, consider removing it to reduce code clutter and improve readability. If it's temporarily disabled for debugging, add a comment explaining why and when it should be re-enabled.

Comment on lines +1059 to 1067 (the PR wraps the existing sanity check and append loop in an if curr_group_outputs: guard):

        if curr_group_outputs:
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            for output in curr_group_outputs:
                encoder_outputs.append(output)

Severity: medium

The code checks whether curr_group_outputs is non-empty before running the sanity checks and appending the outputs. However, there is no else branch, which might indicate a missing handling case. If curr_group_outputs is empty, it might be necessary to handle this scenario or log a warning to indicate that no outputs were generated.

Suggested change:
        if curr_group_outputs:
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )
            for output in curr_group_outputs:
                encoder_outputs.append(output)
        else:
            logger.warning("No encoder outputs generated for this group.")

github-actions bot commented on Jul 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@heheda12345 (Collaborator) left a comment

I checked the code in block_pool.py, kv_cache_coordinator.py, kv_cache_manager.py, kv_cache_utils.py, core.py, and gpu_input_batch.py, and I feel that it is possible to implement the 0 kv cache config cleanly. I'm not sure whether the changes in scheduler.py are required.

It's quite hard for me to find all the code related to the kv cache manager. Can you put it into a separate PR if you want to continue with the 0 kv cache group option?

# needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
if not self.vllm_config.model_config.is_attention_free:
Collaborator:

Why don't we need these lines in #20072 ?

Author:

Just a wild indentation!

zip(kv_cache_specs, available_gpu_memory)
]
#TODO: CP start from here
if vllm_config.model_config.is_attention_free:
Collaborator:

I prefer to handle attention-free models here, in get_kv_cache_config(). It would be great if this can be achieved by another branch in addition to is_kv_cache_type_uniform and is_kv_cache_page_size_uniform; please tell me if further modifications are needed.
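For illustration, a rough sketch of the kind of extra branch being suggested. KVCacheConfig below is a stand-in dataclass, not vLLM's real class, and where exactly this branch would sit inside kv_cache_utils.get_kv_cache_config() is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class KVCacheConfig:
    # Stand-in for vLLM's KV cache config; field names are illustrative.
    num_blocks: int
    kv_cache_groups: list = field(default_factory=list)


def get_kv_cache_config(kv_cache_spec: dict) -> KVCacheConfig:
    if not kv_cache_spec:
        # Attention-free model: no layer registered a KV cache spec, so keep a
        # single placeholder block and zero kv_cache_groups. The scheduler then
        # has nothing to allocate and prefix caching is effectively disabled.
        return KVCacheConfig(num_blocks=1, kv_cache_groups=[])
    raise NotImplementedError(
        "uniform-type / uniform-page-size branches elided in this sketch")
```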

@@ -295,7 +295,9 @@ def add_request(
self.num_tokens_no_spec[req_index] = request.num_tokens

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
Collaborator:

I think this line does nothing if there are 0 kv cache groups.

@@ -36,7 +36,8 @@ def __init__(
enable_caching: bool,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# num_gpu_blocks can be 0 for attention free models
assert isinstance(num_gpu_blocks, int)
Collaborator:

I think we can always have at least 1 gpu block so that we don't need to handle the null block.
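A minimal sketch of that idea; the helper name is illustrative and where the clamp would actually live in vLLM is an assumption.

```python
def clamp_num_gpu_blocks(num_gpu_blocks: int) -> int:
    # Attention-free models would otherwise compute 0 blocks; keeping one
    # placeholder block means the free-block list and null-block handling
    # need no special cases.
    return max(num_gpu_blocks, 1)
```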

blocks[i].prev_free_block = blocks[i - 1]
if i < self.num_free_blocks - 1:
blocks[i].next_free_block = blocks[i + 1]
# This is 0 in attention free models
Collaborator:

These modifications are not needed if there is at least 1 gpu block.

assert num_gpu_blocks is not None and num_gpu_blocks > 0

# num_gpu_blocks can be zero for attention free models
assert num_gpu_blocks is not None
Collaborator:

Not needed if there is at least 1 gpu block.

@@ -246,6 +248,12 @@ def schedule(self) -> SchedulerOutput:
request.num_tokens, 0)

while True:
# This model is attention free and we do not need to allocate KVCache blocks
# for serving requests.
if self.vllm_config.model_config.is_attention_free:
Collaborator:

Why do you need this change? I think allocate_slots should always succeed.

self.verify_and_split_kv_cache_groups()
# attention free models are initialized with 0 kv_cache_groups
if len(self.kv_cache_config.kv_cache_groups) > 0:
self.verify_and_split_kv_cache_groups()
Collaborator:

I'm comfortable with adding another coordinator for 0 kv cache groups and re-implementing find_longest_cache_hit for it.

Collaborator:

As I've observed more and more cases that the current find_longest_cache_hit can't handle, I'm suggesting a new KVCacheCoordinatorNoPrefixCache to be used when prefix caching is disabled. Can you sync with the author of #20661 to avoid duplicating work?
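A rough sketch of the suggested coordinator; the method signature and return convention are assumptions and this is not the actual code of this PR or #20661.

```python
class KVCacheCoordinatorNoPrefixCache:
    """Coordinator for when prefix caching is disabled, including the
    zero-kv_cache_groups case: every lookup is a cache miss."""

    def find_longest_cache_hit(self, block_hashes: list, max_length: int):
        # No prefix cache, hence no cached blocks and zero computed tokens.
        return [], 0
```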

Contributor:

@christian-pinto I have introduced a KVCacheCoordinatorNoPrefixCache in this PR (#20661 ). I think it should handle your case as well. Could you give it a try?

Author:

Hey @nopperl thanks for that. Your approach solves my issue too.

@@ -327,6 +329,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
Args:
scheduler_output: The scheduler output.
"""

# nothing to be reordered when the model is attention free
if self.model_config.is_attention_free:
Collaborator:

I'm OK with this change temporarily. I'll refactor this function soon to handle both the attention-free case and many other unsupported cases.

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
- Improved formatting around
- made is_pooling_model a @property in ModelConfig

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
- Remove unused functions
- merged functions not called anywhere else

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
… manager.

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
@christian-pinto force-pushed the prithvi_v1_embeddings_zero_kv_cache_group branch from e0dd56a to 645d061 on July 9, 2025 at 11:08
@christian-pinto (Author) commented on Jul 9, 2025

@heheda12345 I have followed your suggestions and instantiated the kv_cache config with 1 block, and most of the changes I initially made are no longer needed. Many thanks!

Also, I have implemented the handling of attention-free models in kv_cache_utils.get_kv_cache_config() as you suggested.

Please have a look at the last commit (645d061) to see all the relevant changes.

If zero kv_cache_groups is the preferred approach compared to overloading the KVCacheManager, please let me know and I will move it to a separate branch and open a PR. I am keeping it here for the time being as it is easier for me to test.

The mergify bot added the new-model (Requests to new models) label on Jul 10, 2025.
@heheda12345 (Collaborator) left a comment

Yeah, I like the current implementation of 0 kv cache groups. Can you make a new PR for that?


output = self.collective_rpc("determine_available_memory")
return output

def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
if self.vllm_config.model_config.is_attention_free:
return [{"attention_free": KVCacheSpec(block_size=0)}]
Collaborator:

Can you just return an empty dict?
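For concreteness, a standalone sketch of the reviewer's suggestion (the signature is simplified here, and whether downstream callers tolerate an empty spec dict is an assumption):

```python
def get_kv_cache_specs(is_attention_free: bool) -> list[dict]:
    # Sketch of the suggestion: an attention-free model reports an empty spec
    # dict instead of a placeholder "attention_free" entry.
    if is_attention_free:
        return [{}]
    raise NotImplementedError("per-worker spec collection elided in this sketch")
```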

Labels: documentation, multi-modality (#4194), needs-rebase, new-model, v1
Projects: None yet
3 participants