
[Model][3/N] Automatic conversion of CrossEncoding model #20168


Merged: 11 commits, Jul 4, 2025

Conversation


@noooop noooop commented Jun 27, 2025

1. We cannot run as_seq_cls_model implicitly; doing so causes a circular reference through is_cross_encoder_model.

e.g., try loading jason9693/Qwen2.5-1.5B-apeach, whose architecture is Qwen2ForSequenceClassification.

  • Previously:

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_classification_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)

  • We hope to use as_seq_cls_model implicitly after this PR:

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_seq_cls_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)
  • But _ModelRegistry.is_cross_encoder_model does not take implicit conversion into account:
    @property
    def is_cross_encoder(self) -> bool:
        return self.registry.is_cross_encoder_model(self.architectures)  <- here
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)  <- here
        return model_cls.supports_cross_encoding
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))  <- here 

    def load_model_cls(self) -> type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)
    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),  <- here # this model is expected to be as_seq_cls_model(Qwen2ForCausalLM), but is actually Qwen2ForCausalLM
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
  • When we tried to add a task parameter to registry.is_cross_encoder_model
    def _get_preferred_task(
        self,
        architectures: list[str],
        supported_tasks: set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
        if self.registry.is_cross_encoder_model(architectures):  <- here
            return "classify"
        if self.registry.is_transcription_model(architectures):
            return "transcription"

But _get_preferred_task itself needs to call is_cross_encoder_model.

  • A circular reference occurred

Modifying inspect_model_cls and _get_preferred_task is extremely complex; let's try not to touch them.
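The cycle can be shown in miniature (all names below are simplified stand-ins for vLLM's registry code, not the real implementation): the preferred task is deduced from is_cross_encoder_model, which only sees the registered class, while the implicit as_seq_cls_model conversion would only run after the task is known.

```python
# Miniature sketch of the circular reference (hypothetical, simplified
# stand-ins for vLLM's registry code).

def as_seq_cls_model(cls):
    # Implicit adapter: derives a *ForSequenceClassification class.
    name = cls.__name__.replace("ForCausalLM", "ForSequenceClassification")
    return type(name, (cls,), {"supports_cross_encoding": True})

class Qwen2ForCausalLM:
    supports_cross_encoding = False

# With implicit conversion, the registry maps the classification
# architecture to the bare *ForCausalLM class:
REGISTRY = {"Qwen2ForSequenceClassification": Qwen2ForCausalLM}

def is_cross_encoder_model(architectures):
    # Inspects only the registered class; it cannot see a conversion
    # that would happen later, once the task is resolved.
    return any(REGISTRY[a].supports_cross_encoding for a in architectures)

def get_preferred_task(architectures):
    # Needs is_cross_encoder_model() *before* the task is known, but the
    # implicit conversion that would make it True needs the task first.
    return "classify" if is_cross_encoder_model(architectures) else None

print(get_preferred_task(["Qwen2ForSequenceClassification"]))  # None: wrong

# Explicit conversion at registration time breaks the cycle:
REGISTRY["Qwen2ForSequenceClassification"] = as_seq_cls_model(Qwen2ForCausalLM)
print(get_preferred_task(["Qwen2ForSequenceClassification"]))  # classify
```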

2. What is the actual purpose of is_cross_encoder_model?

Pooling now divides into three tasks:

"pooling": ["embed", "classify", "reward"],

Among them, "reward" is easy to distinguish from "embed" and "classify", so we can exclude it first.

after #19978

When task_option == "embed", or the architecture is *ForSequenceClassification with num_labels == 1, users are allowed to use the score API.

  • For "embed", the score is the cosine distance between the two embeddings.
  • For "classify" with num_labels == 1, the score comes from the classification head.

So the purpose of is_cross_encoder_model is to ensure the correct scoring method is used.
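The two calculation paths can be sketched as follows (illustrative only; vLLM's real _embedding_score and _cross_encoding_score operate on pooler outputs, not plain lists):

```python
import math

def embedding_score(emb1, emb2):
    # "embed" path: each text is pooled separately and the score is the
    # cosine similarity between the two embeddings.
    dot = sum(a * b for a, b in zip(emb1, emb2))
    n1 = math.sqrt(sum(a * a for a in emb1))
    n2 = math.sqrt(sum(b * b for b in emb2))
    return dot / (n1 * n2)

def cross_encoding_score(logit):
    # "classify" path with num_labels == 1: the pair is encoded jointly and
    # the score is the sigmoid-activated single classification logit.
    return 1.0 / (1.0 + math.exp(-logit))

print(embedding_score([1.0, 0.0], [0.0, 1.0]))  # 0.0: orthogonal embeddings
print(cross_encoding_score(0.0))                # 0.5: neutral logit
```

Picking the wrong path for a given checkpoint yields numerically plausible but meaningless scores, which is exactly why the dispatch must be reliable.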

Redirecting *ForSequenceClassification to *ForCausalLM makes things complicated.

e.g.

"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),

At this point, *ForCausalLM might be used for either classify or embed, so the task parameter must be passed to distinguish them. If resolution fails, the wrong calculation method will be used and wrong results will be obtained.

Explicit (rather than implicit and automatic) conversion of *ForCausalLM to *ForSequenceClassification makes things easier.

e.g.

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),

All problems are solved effortlessly.

3. Make the explicit conversion code look good

  • Solution 1
Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
  • Solution 2
class Qwen2ForSequenceClassification(as_seq_cls_model(Qwen2ForCausalLM)):
    pass
  • Solution 3
    use a more hacky way, e.g.
"Qwen2ForSequenceClassification": ("qwen2", "as_seq_cls_model+Qwen2ForCausalLM"),

This syntax sugar would trigger the explicit conversion automatically.

  • Solution 4
    I am not very satisfied with the methods above, and I look forward to better ones.
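Solution 3's syntax sugar could be resolved at registry-lookup time roughly like this (hypothetical sketch; vLLM's registry has no such parser, and the adapter below is a simplified stand-in):

```python
# Hypothetical resolver for "adapter+ClassName" registry entries
# (Solution 3); not actual vLLM code.

def as_seq_cls_model(cls):
    # Simplified stand-in adapter.
    name = cls.__name__.replace("ForCausalLM", "ForSequenceClassification")
    return type(name, (cls,), {"supports_cross_encoding": True})

ADAPTERS = {"as_seq_cls_model": as_seq_cls_model}

def resolve_entry(entry, module_classes):
    # "as_seq_cls_model+Qwen2ForCausalLM" -> apply the adapter on lookup.
    if "+" in entry:
        adapter_name, class_name = entry.split("+", 1)
        return ADAPTERS[adapter_name](module_classes[class_name])
    return module_classes[entry]

class Qwen2ForCausalLM:  # stand-in for the real backbone class
    pass

cls = resolve_entry("as_seq_cls_model+Qwen2ForCausalLM",
                    {"Qwen2ForCausalLM": Qwen2ForCausalLM})
print(cls.__name__)  # Qwen2ForSequenceClassification
```

The upside is that the registry stays a table of strings; the downside is that the string format becomes a mini-language that every registry consumer must understand.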

4. This PR effectively allows every ForCausalLM to support a corresponding ForSequenceClassification.

Do we really need to list all auto-converted architectures?

Shall we consider the hacky way above?

If every ForCausalLM automatically gained a corresponding ForSequenceClassification, the registration tests would be unhappy, because there are no corresponding HF repositories for load testing.

5. Should we retire the classify task, given that its naming is slightly inconsistent with actual usage?

  • is_cross_encoder_model is used to automatically distinguish between "embed" and "classify" tasks, in _get_preferred_task.
    def _get_preferred_task(
        self,
        architectures: list[str],
        supported_tasks: set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
        if self.registry.is_cross_encoder_model(architectures):  <- here
            return "classify"
  • is_cross_encoder_model is used to distinguish between _cross_encoding_score and _embedding_score in LLM.score.

Now, task == "classify", the architecture name being *ForSequenceClassification, and is_cross_encoder being True are basically equivalent.

So the task == "classify" naming is slightly inconsistent with actual usage.

vllm/entrypoints/llm.py, lines 1310 to 1324 at 3c545c0:

if self.llm_engine.model_config.is_cross_encoder:
    return self._cross_encoding_score(tokenizer, input_text_1,
                                      input_text_2,
                                      truncate_prompt_tokens, use_tqdm,
                                      lora_request,
                                      prompt_adapter_request)
else:
    return self._embedding_score(
        tokenizer,
        input_text_1,  # type: ignore[arg-type]
        input_text_2,  # type: ignore[arg-type]
        truncate_prompt_tokens,
        use_tqdm,
        lora_request,
        prompt_adapter_request)

In this situation, is_cross_encoder_model and task == "classify" amount to the same thing.

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Online conversion of ForCausalLM into ForSequenceClassification models.

  • from_2_way_softmax:
    • Qwen3ForCausalLM
      • Qwen3-Reranker
    • Qwen2ForCausalLM
      • mxbai-rerank-v2
  • hitchhiked (drive-by) fixes:
    • fix max_model_len in tests/models/language/pooling/test_embedding.py
    • fix NotImplementedError in tests/models/language/pooling/test_gte.py
    • Keep set served_model_name before maybe_model_redirect(self.model)

FIX (partial) #20461
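The from_2_way_softmax conversion rests on a small identity: a softmax over two logits equals a sigmoid of their difference, so a "no"/"yes" LM-head reranker can be folded into a single-label classifier whose weight vector is w_yes - w_no. A numerical sketch (the vector values are made up for illustration):

```python
import math

def softmax_yes(logit_no, logit_yes):
    # P("yes") from a softmax over the two LM-head logits.
    m = max(logit_no, logit_yes)
    e_no = math.exp(logit_no - m)
    e_yes = math.exp(logit_yes - m)
    return e_yes / (e_no + e_yes)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Made-up hidden state and LM-head rows for the "yes"/"no" tokens:
h = [0.3, -1.2, 0.7]
w_yes = [0.5, 0.1, -0.2]
w_no = [-0.4, 0.3, 0.6]

p_softmax = softmax_yes(dot(h, w_no), dot(h, w_yes))
w_cls = [y - n for y, n in zip(w_yes, w_no)]   # folded classifier weights
p_classifier = sigmoid(dot(h, w_cls))

print(abs(p_softmax - p_classifier) < 1e-9)  # True: the two paths agree
```

This is why the converted *ForSequenceClassification model can score with num_labels == 1 while reproducing the original reranker's probabilities.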

Test Plan

pytest tests/models/language/pooling/test_qwen3_reranker.py
pytest tests/models/language/pooling/test_mxbai_rerank.py

Test Result

pass

(Optional) Documentation Update


@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @noooop, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request, part of a series, focuses on enabling automatic conversion of ForCausalLM models to support ForSequenceClassification tasks. Specifically, it adds the necessary code to allow the Gemma model to be automatically adapted for sequence classification, expanding its utility within the framework.

Highlights

  • Model Conversion: Implemented automatic conversion for the Gemma model, allowing GemmaForCausalLM to function as GemmaForSequenceClassification by leveraging the as_seq_cls_model adapter.
  • Model Registration: Registered the newly created GemmaForSequenceClassification class within the _MODELS registry, making it discoverable and usable by the system for sequence classification tasks.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces automatic conversion of CrossEncoding models and adds GemmaForSequenceClassification to the model registry. The changes involve modifying gemma.py and registry.py to support this new functionality.


noooop commented Jun 27, 2025

@DarkLight1337 @maxdebayser @22quinn

Points 1-2 add some background information on how the problem arose.

  1. We cannot run as_seq_cls_model implicitly; doing so causes a circular reference through is_cross_encoder_model.
  2. What is the actual purpose of is_cross_encoder_model?

Points 3-5 are issues that need to be discussed to find the best solution.

  3. Make the explicit conversion code look good.
  4. This PR effectively allows every ForCausalLM to support a corresponding ForSequenceClassification.
     Do we really need to list all auto-converted architectures?
  5. Should we retire the classify task, given that its naming is slightly inconsistent with actual usage?

@DarkLight1337 (Member)

From a quick search, it seems that ModelConfig.is_cross_encoder is only used to switch between using the pooling results directly to get the score (assuming that pooled outputs are classification scores) vs. applying cosine similarity on the pooling results (assuming that the pooled outputs are embeddings).

@maxdebayser (Contributor)

Do we even need is_cross_encoder anymore? If "score" is just an API and no longer a task, then when "score" is called:

  • if the model task is embed, the prompts will be passed individually to the model and the results will be computed with cosine similarity
  • if the model task is classify, the prompts will be passed pairwise to the tokenizer and the result will be the score returned by the model on the tokenizer output.


noooop commented Jun 28, 2025

Do we even need is_cross_encoder anymore? If "score" is just an API and no longer a task, then when "score" is called:

  • is_cross_encoder_model is used to automatically distinguish between "embed" and "classify" tasks, in _get_preferred_task.
    def _get_preferred_task(
        self,
        architectures: list[str],
        supported_tasks: set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
        if self.registry.is_cross_encoder_model(architectures):  <- here
            return "classify"

Modifying inspect_model_cls and _get_preferred_task is extremely complex; let's try not to touch them.

This makes it impossible for us to completely remove is_cross_encoder_model.


noooop commented Jun 30, 2025

Let's first discuss two fundamental issues.

  1. Is it better to use explicit (rather than implicit and automatic) conversion from *ForCausalLM to *ForSequenceClassification, to avoid modifying inspect_model_cls and _get_preferred_task?
  2. Should we rename the current "classify" task, given that its naming is slightly inconsistent with actual usage?


DarkLight1337 commented Jun 30, 2025

Since any model can be converted into a classification model via as_seq_cls_model, if you let the adapter support cross encoding, then I think we should consider any model to support cross-encoding without having to check is_cross_encoder. During inference, we can use is_cross_encoder to check the architecture that is finally used in order to switch between the different modes of Score API.


noooop commented Jun 30, 2025

Although all models can be converted into a classification model via as_seq_cls_model, the corresponding weights must actually be from a *ForSequenceClassification model; otherwise, an error will occur during the loading phase.

as_seq_cls_model actually avoids duplicated code by not requiring each ForCausalLM to implement its own ForSequenceClassification.
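The adapter pattern behind this can be sketched as a class factory (simplified sketch only; vLLM's real as_seq_cls_model also wires up the score head, pooler, and weight loading):

```python
def as_seq_cls_model(causal_lm_cls):
    """Derive a *ForSequenceClassification class from a *ForCausalLM class
    so the backbone implementation is written only once (sketch)."""
    name = causal_lm_cls.__name__.replace("ForCausalLM",
                                          "ForSequenceClassification")

    class SeqClsModel(causal_lm_cls):
        supports_cross_encoding = True  # enables the cross-encoder score path

        def __init__(self, num_labels=1):
            super().__init__()
            # A real adapter would replace the LM head with a score head here.
            self.num_labels = num_labels

    SeqClsModel.__name__ = SeqClsModel.__qualname__ = name
    return SeqClsModel

class Qwen2ForCausalLM:  # stand-in for the real backbone class
    def __init__(self):
        self.backbone = "qwen2"

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
```

Because the derived class inherits the whole backbone, only checkpoints that actually contain *ForSequenceClassification weights will load into it, as noted above.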

@DarkLight1337 (Member)

the corresponding weights must actually be from a *ForSequenceClassification model; otherwise, an error will occur during the loading phase.

Sorry I mean that if the architecture name is *ForSequenceClassification, we can assume it supports cross encoding


noooop commented Jun 30, 2025

Sorry I mean that if the architecture name is *ForSequenceClassification, we can assume it supports cross encoding

Now, task == "classify", the architecture name being *ForSequenceClassification, and is_cross_encoder being True are basically equivalent.

So the task == "classify" naming is slightly inconsistent with actual usage.

@DarkLight1337 (Member)

Yes. I think scoring can also be considered as a type of classification task. We can adjust the naming in another PR


noooop commented Jun 30, 2025

Let's first discuss two fundamental issues.

  1. Is it better to use explicit (rather than implicit and automatic) conversion from *ForCausalLM to *ForSequenceClassification, to avoid modifying inspect_model_cls and _get_preferred_task?
  2. Should we rename the current "classify" task, given that its naming is slightly inconsistent with actual usage?

@maxdebayser

I look forward to hearing your thoughts.

@maxdebayser (Contributor)

Let me see if I got things right:

  1. Automatic conversion is required so that the user can load a model and pass --task classify, right?
  2. So if we require explicit conversion, that wouldn't work, right?

Modifying inspect_model_cls and _get_preferred_task is extremely complex; let's try not to touch them.

I'm not sure why this is the case. If we can treat cross encoding as a special case of classification, we could drop the if self.registry.is_cross_encoder_model check and just rely on the ForSequenceClassification pattern to determine that the preferred task is classify. Or not? Perhaps I'm not seeing all the ramifications here.


noooop commented Jul 1, 2025

@maxdebayser

  1. Automatic conversion is required so that the user can load a model and pass --task classify, right?

Theoretically, this (series of) PR could allow all *ForCausalLM models to automatically have a *ForSequenceClassification implementation.

  2. So if we require explicit conversion, that wouldn't work, right?

Using implicit as_seq_cls_model would cause a circular reference.

The explicit conversion *ForCausalLM to *ForSequenceClassification will make things easier.


If we want to load Qwen2ForSequenceClassification, and using implicit as_seq_cls_model.

The routing is like this.

"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),

The code runs to _get_preferred_task

    def _get_preferred_task(
        self,
        architectures: list[str],
        supported_tasks: set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
        if self.registry.is_cross_encoder_model(architectures):  <- here
            return "classify"
        if self.registry.is_transcription_model(architectures):
            return "transcription"

At this point we are still trying to infer the task, so we don't know yet what it is.

Since we don't yet know that the task is classify, we cannot implicitly run as_seq_cls_model(Qwen2ForCausalLM) to make is_cross_encoder_model(architectures) true.

Using implicit as_seq_cls_model would cause a circular reference.

Explicit conversion of *ForCausalLM to *ForSequenceClassification makes things easier.

e.g.

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),

  3. I think we shouldn't delete is_cross_encoder_model, in case we need to support cross-encoder models whose architecture name isn't *ForSequenceClassification in the future.


maxdebayser commented Jul 1, 2025

Oh, I see, the explicit conversion would be done at the moment where the converted class is added to the registry. Thanks for the clarification.

I was missing the context from these two PRs: #19260 and #19675, I think now I understand better what you're trying to do.

Let me see if I got it right:

  1. The Qwen PR allows the user to do
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
  2. You realized that this could be done automatically instead of having to write a class like Qwen3ForSequenceClassification for each architecture.

  3. But then the question is: how can _get_preferred_task deduce that the correct task is classify if is_cross_encoder doesn't consider the adapted class?

Is that an accurate description of the dilemma?

If that is correct, I would propose the following: the preferred task should never be deduced from the adapted class. Instead the user should always set the task explicitly. So then the command line would be:

vllm serve Qwen/Qwen3-Reranker-0.6B --task classify --hf_overrides '{"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'

Internally a Qwen3ForSequenceClassification class will be created by as_seq_cls_model but it wouldn't need to be added explicitly to the registry.


noooop commented Jul 2, 2025

  • You realized that this could be done automatically instead of having to write a class like Qwen3ForSequenceClassification for each architecture.

My problem is simpler. Ordinary classification models also encounter this issue, for example the repository commonly used for testing classification: jason9693/Qwen2.5-1.5B-apeach

In the main code,

model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
    model_cls = as_embedding_model(model_cls)
elif model_config.task == "classify":
    model_cls = as_classification_model(model_cls)
elif model_config.task == "reward":
    model_cls = as_reward_model(model_cls)

  • using model_cls = as_classification_model(model_cls) implicitly converts Qwen2ForCausalLM into Qwen2ForSequenceClassification.

  • in vllm/model_executor/models/registry.py: "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),

But simply replacing as_classification_model with as_seq_cls_model in this PR doesn't work; it causes the cyclic dependencies mentioned above.

Explicit conversion can avoid the problem of circular dependencies.

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),


noooop commented Jul 3, 2025

Finally, let's discuss two more implementation details.

  1. Make the explicit conversion code look good
  • Solution 1
Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
  • Solution 2
class Qwen2ForSequenceClassification(as_seq_cls_model(Qwen2ForCausalLM)):
    pass
  • Solution 3
    use a more hacky way, e.g.
"Qwen2ForSequenceClassification": ("qwen2", "as_seq_cls_model+Qwen2ForCausalLM"),

This syntax sugar would trigger the explicit conversion automatically.

  2. This PR effectively allows every ForCausalLM to support a corresponding ForSequenceClassification.

Do we really need to list all auto-converted architectures?

Shall we consider the hacky way above?

If every ForCausalLM automatically gained a corresponding ForSequenceClassification, the registration tests would be unhappy, because there are no corresponding HF repositories for load testing.


Perhaps the simplest way is already sufficient:

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),


mergify bot commented Jul 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 3, 2025
@noooop noooop closed this Jul 4, 2025
@noooop noooop force-pushed the as_seq_cls_model branch from 0298ff8 to a7bab0c Compare July 4, 2025 05:05
noooop added 2 commits July 4, 2025 13:51
@noooop noooop requested review from houseroad and hmellor as code owners July 4, 2025 08:07

noooop commented Jul 4, 2025

@DarkLight1337

ready to review

Online convert ForCausalLM into ForSequenceClassification model.

  • from_2_way_softmax:
    • Qwen3ForCausalLM
      • Qwen3-Reranker
    • Qwen2ForCausalLM
      • mxbai-rerank-v2
  • hitchhike
    • fix max_model_len in tests/models/language/pooling/test_embedding.py
    • fix NotImplementedError in tests/models/language/pooling/test_gte.py
    • Keep set served_model_name before maybe_model_redirect(self.model)

@mergify mergify bot added the documentation Improvements or additions to documentation label Jul 4, 2025
noooop added 2 commits July 4, 2025 16:47

noooop commented Jul 4, 2025

FAILED models/language/pooling/test_scoring.py::test_cross_encoder_1_to_1[cross-encoder/ms-marco-MiniLM-L-6-v2] - assert
FAILED models/language/pooling/test_scoring.py::test_cross_encoder_1_to_N[cross-encoder/ms-marco-MiniLM-L-6-v2] - assert
FAILED models/language/pooling/test_scoring.py::test_cross_encoder_N_to_N[cross-encoder/ms-marco-MiniLM-L-6-v2] - assert

These tests cannot be easily fixed, @maxdebayser.

@vllm-bot vllm-bot merged commit 2e26f91 into vllm-project:main Jul 4, 2025
15 of 17 checks passed
@DarkLight1337 DarkLight1337 left a comment

Merged to unblock CI

sfeng33 pushed a commit to sfeng33/vllm that referenced this pull request Jul 6, 2025
@DarkLight1337 DarkLight1337 added this to the v0.9.2 milestone Jul 6, 2025
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
Labels
documentation Improvements or additions to documentation qwen Related to Qwen models