[Fix] Support cls pooling in ModernBertPooler #20067

Merged
4 commits merged into vllm-project:main on Jun 25, 2025

Conversation

@lsz05 (Contributor) commented Jun 25, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

#16648 implemented ModernBERT. However, it supported only mean pooling, so models that use cls pooling did not work properly.

This PR fixes the issue by loading the pooling method from the model config and applying it in ModernBertPooler.

Note that only cls and mean pooling are implemented, as the Hugging Face Transformers implementation supports no choices other than cls and mean. (https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/modernbert/configuration_modernbert.py#L93)
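
For illustration, here is a minimal sketch of the config-driven pooling described above, assuming a (seq_len, hidden_size) hidden-state layout and a GELU-activated dense head; the class name and surrounding details are placeholders, not the literal vLLM diff:

import torch
import torch.nn as nn


class PoolerSketch(nn.Module):
    """Sketch of ModernBertPooler-style pooling; not the literal PR code."""

    def __init__(self, hidden_size: int, pooling_type: str):
        super().__init__()
        # In the PR, pooling_type is read from config.classifier_pooling.
        self.pooling_type = pooling_type
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()  # assumed activation, for illustration only

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (seq_len, hidden_size) for a single sequence
        if self.pooling_type == "cls":
            pooled = hidden_states[0, :]        # first-token (CLS) hidden state
        elif self.pooling_type == "mean":
            pooled = hidden_states.mean(dim=0)  # average over all tokens
        else:
            raise ValueError(f"Pooling type should be either `cls` or `mean`, "
                             f"but got {self.pooling_type}")
        return self.act(self.dense(pooled))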

Test Plan

Test cases are designed with reference to #16648 (comment).

test code
import pytest
import torch
from sentence_transformers import CrossEncoder

from vllm import LLM


def test_modernbert_reranker(model_name):
    # Use float32 everywhere so the comparison is not dominated by dtype noise.
    st_model = CrossEncoder(model_name, model_kwargs={"torch_dtype": torch.float32})
    vllm_model = LLM(model_name, task="score", dtype="float32")

    sentences = [
        ("ping", "pong"),
        ("ping", "pong" * 16),
        ("ping", "pong" * 24),
        ("ping", "pong" * 32),
        ("ping", "pong" * 48),
        ("ping", "pong" * 64),
        ("ping", "pong" * 128),
    ]

    st_scores = st_model.predict(sentences)

    texts_1 = [x[0] for x in sentences]
    texts_2 = [x[1] for x in sentences]
    outputs = vllm_model.score(texts_1, texts_2)
    vllm_scores = [output.outputs.score for output in outputs]

    def test_close(s1, s2):
        # True when the two scores agree within 1% relative tolerance.
        return float(s1) == pytest.approx(float(s2), rel=0.01)

    print(
        model_name,
        ":\t",
        [test_close(st_scores[i], vllm_scores[i]) for i in range(len(st_scores))],
    )


if __name__ == "__main__":
    model_names = [
        "Alibaba-NLP/gte-reranker-modernbert-base",  # mean pooling
        "cl-nagoya/ruri-v3-reranker-310m",  # cls pooling
    ]
    for model_name in model_names:
        test_modernbert_reranker(model_name)

Test Result

Before this PR:

Alibaba-NLP/gte-reranker-modernbert-base :       [True, True, True, True, False, False, False]
cl-nagoya/ruri-v3-reranker-310m :        [False, False, False, False, False, False, False]

After this PR:

Alibaba-NLP/gte-reranker-modernbert-base :       [True, True, True, True, False, False, False]
cl-nagoya/ruri-v3-reranker-310m :        [True, True, True, True, True, False, False]

(Optional) Documentation Update

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of it by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot, Contributor) left a comment

Summary of Changes

Hello @lsz05, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue in the ModernBertPooler by introducing support for cls pooling, alongside the existing mean pooling. It enables the pooling method to be configured dynamically based on the model's settings, ensuring compatibility with a wider range of ModernBERT models that rely on cls pooling for correct operation.

Highlights

  • ModernBertPooler Enhancement: The ModernBertPooler now dynamically determines the pooling method (mean or cls) by reading the classifier_pooling attribute from the model's configuration. Previously, it was hardcoded to use only mean pooling.
  • CLS Pooling Support: Implemented support for cls pooling, where the pooled output is taken from the first token's hidden state (pooled_output[0, :]). This addresses an issue where models requiring cls pooling were not functioning correctly.
  • Pooling Type Validation: Added validation to ensure that only cls or mean pooling types are accepted, raising a ValueError for any other unsupported types, aligning with Hugging Face Transformers' implementation.

@gemini-code-assist (bot, Contributor) left a comment

Code Review

The pull request introduces support for 'cls' pooling in ModernBertPooler, addressing the issue where only mean pooling was supported. The changes involve loading the pooling method from the model config and applying it in the forward pass. The implementation includes handling for 'cls' and 'mean' pooling types, with a ValueError raised for unsupported types. The test plan includes test cases for both 'mean' and 'cls' pooling.

Comment on lines 273 to 274

        else:
            raise ValueError(f"Pooling type should be either `cls` or `mean`, but got {self.pooling_type}")

medium

Consider adding a more descriptive error message that includes the available pooling types to guide users on the expected values.

Suggested change
         else:
-            raise ValueError(f"Pooling type should be either `cls` or `mean`, but got {self.pooling_type}")
+            raise ValueError(f"Pooling type should be either `cls` or `mean`, but got {self.pooling_type}. Available pooling types are 'cls' and 'mean'.")

@@ -258,14 +258,20 @@ def __init__(self, config: ModernBertConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                                config.classifier_bias)
+        self.pooling_type = config.classifier_pooling

medium

It's good to see that the pooling type is being loaded from the model config. However, ensure that config.classifier_pooling is validated elsewhere to prevent unexpected values from being used. If not, consider adding a validation step here or in the config definition to ensure it's one of the supported types (cls or mean).
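
For what it's worth, an init-time check of the kind suggested here could look like the following sketch (the helper name and placement are assumptions, not part of this PR):

SUPPORTED_POOLING_TYPES = ("cls", "mean")


def validate_pooling_type(pooling_type: str) -> str:
    # Fail fast on an unexpected config value instead of at forward time.
    if pooling_type not in SUPPORTED_POOLING_TYPES:
        raise ValueError(f"Unsupported classifier_pooling {pooling_type!r}; "
                         f"expected one of {SUPPORTED_POOLING_TYPES}")
    return pooling_type

# In __init__, one might then write:
# self.pooling_type = validate_pooling_type(config.classifier_pooling)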

@noooop (Contributor) commented Jun 25, 2025

Thanks for fixing!

cc @DarkLight1337

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) June 25, 2025 09:38
lsz05 added 2 commits June 25, 2025 18:38
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
auto-merge was automatically disabled June 25, 2025 09:39

Head branch was pushed to by a user without write access

@lsz05 lsz05 force-pushed the fix/modernbert_pooling branch from f23e871 to bf343f9 on June 25, 2025 09:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 25, 2025
@noooop (Contributor) commented Jun 25, 2025

But the implementation of ModernBERT still shows a slight difference from sentence_transformers when lengths > 32.

Please help find out the reason.

Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
@lsz05 (Contributor, Author) commented Jun 25, 2025

But the implementation of ModernBERT still shows a slight difference from sentence_transformers when lengths > 32.

Please help find out the reason.

I tried printing out the outputs of each layer, transformers vs vllm.
Results are below; they are slightly different.

`transformers` vs `vllm`
Model type: modernbert
Hidden size: 768
Num hidden layers: 25

Input shape: torch.Size([1, 30])
Input tokens: ['<s>', 'What', '▁is', '▁machine', '▁', 'learning', '?', '</s>', '<s>', 'Machine', '▁', 'learning', '▁is', '▁a', '▁subset']...
Base model type: ModernBertModel
Head layer found: ModernBertPredictionHead
Classifier: Linear(in_features=768, out_features=1, bias=True)

LAYER-BY-LAYER OUTPUTS:
================================================================================

Embeddings Output:
  CLS token L2 norm: 13.144588
  CLS token (first 5 dims): tensor([ 0.1957, -0.5083,  0.0802, -0.1042, -0.7673])

Layer 0 Output:
  CLS token L2 norm: 17.758429
  CLS token (first 5 dims): tensor([ 0.2088, -0.4176,  0.2003, -0.0341, -0.6268])

Layer 1 Output:
  CLS token L2 norm: 18.395269
  CLS token (first 5 dims): tensor([ 0.3858, -0.2665,  0.1996,  0.1435, -0.6992])

Layer 2 Output:
  CLS token L2 norm: 18.193880
  CLS token (first 5 dims): tensor([ 0.3238,  0.1079,  0.3410,  0.2015, -0.4926])

Layer 3 Output:
  CLS token L2 norm: 17.828152
  CLS token (first 5 dims): tensor([ 0.2019,  0.2908,  0.2132,  0.0368, -0.4293])

Layer 4 Output:
  CLS token L2 norm: 16.872932
  CLS token (first 5 dims): tensor([ 0.1666,  0.2989,  0.1783, -0.0995, -0.2977])

Layer 5 Output:
  CLS token L2 norm: 18.311413
  CLS token (first 5 dims): tensor([ 0.1853,  0.5024,  0.0683, -0.2753, -0.0633])

Layer 6 Output:
  CLS token L2 norm: 19.102503
  CLS token (first 5 dims): tensor([-0.2429,  0.4856, -0.0042, -0.2018, -0.1589])

Layer 7 Output:
  CLS token L2 norm: 19.715675
  CLS token (first 5 dims): tensor([-0.3828,  0.0979, -0.0400, -0.1928, -0.4273])

Layer 8 Output:
  CLS token L2 norm: 21.385695
  CLS token (first 5 dims): tensor([-0.3315, -0.1413,  0.0094, -0.1185, -0.3142])

Layer 9 Output:
  CLS token L2 norm: 20.368444
  CLS token (first 5 dims): tensor([-0.6293,  0.0436, -0.2129, -0.1568, -0.1966])

Layer 10 Output:
  CLS token L2 norm: 20.373852
  CLS token (first 5 dims): tensor([-0.2535,  0.0486,  0.0441, -0.2005, -0.2305])

Layer 11 Output:
  CLS token L2 norm: 22.784698
  CLS token (first 5 dims): tensor([ 0.0795,  0.0065,  0.0634, -0.4684, -0.2879])

Layer 12 Output:
  CLS token L2 norm: 22.311489
  CLS token (first 5 dims): tensor([ 0.2256,  0.6219, -0.2107, -0.0330, -0.1838])

Layer 13 Output:
  CLS token L2 norm: 21.946304
  CLS token (first 5 dims): tensor([ 0.5228,  1.4983, -0.1146,  0.0608,  0.2115])

Layer 14 Output:
  CLS token L2 norm: 23.971558
  CLS token (first 5 dims): tensor([0.8149, 1.5493, 0.0458, 0.2478, 0.2798])

Layer 15 Output:
  CLS token L2 norm: 21.835360
  CLS token (first 5 dims): tensor([0.5714, 1.8345, 0.2864, 0.1760, 0.0625])

Layer 16 Output:
  CLS token L2 norm: 23.511017
  CLS token (first 5 dims): tensor([ 0.9932,  2.5996, -0.1660,  0.4263, -0.3848])

Layer 17 Output:
  CLS token L2 norm: 29.513973
  CLS token (first 5 dims): tensor([ 0.5083,  2.9969,  0.4608,  0.1687, -0.7423])

Layer 18 Output:
  CLS token L2 norm: 34.840321
  CLS token (first 5 dims): tensor([ 0.5436,  3.7001, -0.2022,  0.1846, -0.4819])

Layer 19 Output:
  CLS token L2 norm: 41.959427
  CLS token (first 5 dims): tensor([ 0.8146,  4.2985,  1.2059,  0.0423, -1.3452])

Layer 20 Output:
  CLS token L2 norm: 48.433048
  CLS token (first 5 dims): tensor([ 0.4837,  4.6051,  1.2632,  1.2908, -1.9856])

Layer 21 Output:
  CLS token L2 norm: 58.428329
  CLS token (first 5 dims): tensor([-0.1005,  5.4263,  0.3762,  1.1651, -2.2596])

Layer 22 Output:
  CLS token L2 norm: 76.530060
  CLS token (first 5 dims): tensor([-0.0858,  6.6375, -0.0907,  0.4788, -3.5127])

Layer 23 Output:
  CLS token L2 norm: 87.558037
  CLS token (first 5 dims): tensor([-0.3373,  7.2413,  0.3764,  1.1714, -4.2528])

Layer 24 Output:
  CLS token L2 norm: 79.475433
  CLS token (last layer, first 10 dims): tensor([-0.4380,  8.2350,  0.2989,  1.1986, -2.9499,  1.3670, -0.0452, -1.0086,
         0.9013,  0.3143])

Final Norm Output:
  CLS token L2 norm: 33.906776
  CLS token (first 5 dims): tensor([-0.1496,  4.0176,  0.1019,  0.4079, -1.3001])

Head Layer Output:
  CLS token L2 norm: 27.641144
  CLS token (first 5 dims): tensor([ 1.5874, -0.3208, -0.9166, -1.1039, -0.1872])

Classifier Output:
  Logit value: 6.614357

FINAL LOGIT: 6.614357
SIGMOID: 0.998661

================================================================================
vLLM: Layer-by-Layer Trace with Hooks
================================================================================
...

vLLM final score: 0.998661
Calculated logit: 6.614434

vLLM LAYER-BY-LAYER OUTPUTS:
================================================================================

vLLM Embeddings Output:
  CLS token L2 norm: 13.144588
  CLS token (first 5 dims): tensor([ 0.1957, -0.5083,  0.0802, -0.1042, -0.7673])

vLLM Layer 0 Output:
  CLS token L2 norm: 17.758410
  CLS token (first 5 dims): tensor([ 0.2088, -0.4176,  0.2003, -0.0341, -0.6268])

vLLM Layer 1 Output:
  CLS token L2 norm: 18.395294
  CLS token (first 5 dims): tensor([ 0.3858, -0.2665,  0.1996,  0.1436, -0.6992])

vLLM Layer 2 Output:
  CLS token L2 norm: 18.193869
  CLS token (first 5 dims): tensor([ 0.3238,  0.1079,  0.3410,  0.2015, -0.4925])

vLLM Layer 3 Output:
  CLS token L2 norm: 17.828121
  CLS token (first 5 dims): tensor([ 0.2019,  0.2908,  0.2132,  0.0368, -0.4293])

vLLM Layer 4 Output:
  CLS token L2 norm: 16.872902
  CLS token (first 5 dims): tensor([ 0.1666,  0.2990,  0.1783, -0.0995, -0.2977])

vLLM Layer 5 Output:
  CLS token L2 norm: 18.311504
  CLS token (first 5 dims): tensor([ 0.1854,  0.5024,  0.0683, -0.2753, -0.0632])

vLLM Layer 6 Output:
  CLS token L2 norm: 19.102507
  CLS token (first 5 dims): tensor([-0.2428,  0.4856, -0.0042, -0.2018, -0.1589])

vLLM Layer 7 Output:
  CLS token L2 norm: 19.715733
  CLS token (first 5 dims): tensor([-0.3827,  0.0979, -0.0401, -0.1928, -0.4273])

vLLM Layer 8 Output:
  CLS token L2 norm: 21.385784
  CLS token (first 5 dims): tensor([-0.3314, -0.1413,  0.0094, -0.1185, -0.3142])

vLLM Layer 9 Output:
  CLS token L2 norm: 20.368574
  CLS token (first 5 dims): tensor([-0.6293,  0.0436, -0.2129, -0.1569, -0.1966])

vLLM Layer 10 Output:
  CLS token L2 norm: 20.374107
  CLS token (first 5 dims): tensor([-0.2534,  0.0486,  0.0442, -0.2005, -0.2306])

vLLM Layer 11 Output:
  CLS token L2 norm: 22.785034
  CLS token (first 5 dims): tensor([ 0.0795,  0.0064,  0.0635, -0.4683, -0.2880])

vLLM Layer 12 Output:
  CLS token L2 norm: 22.311720
  CLS token (first 5 dims): tensor([ 0.2256,  0.6218, -0.2107, -0.0330, -0.1839])

vLLM Layer 13 Output:
  CLS token L2 norm: 21.946621
  CLS token (first 5 dims): tensor([ 0.5228,  1.4983, -0.1146,  0.0608,  0.2114])

vLLM Layer 14 Output:
  CLS token L2 norm: 23.971849
  CLS token (first 5 dims): tensor([0.8149, 1.5493, 0.0458, 0.2479, 0.2798])

vLLM Layer 15 Output:
  CLS token L2 norm: 21.835527
  CLS token (first 5 dims): tensor([0.5713, 1.8344, 0.2865, 0.1760, 0.0625])

vLLM Layer 16 Output:
  CLS token L2 norm: 23.510683
  CLS token (first 5 dims): tensor([ 0.9931,  2.5994, -0.1659,  0.4263, -0.3849])

vLLM Layer 17 Output:
  CLS token L2 norm: 29.513456
  CLS token (first 5 dims): tensor([ 0.5083,  2.9967,  0.4608,  0.1688, -0.7423])

vLLM Layer 18 Output:
  CLS token L2 norm: 34.839432
  CLS token (first 5 dims): tensor([ 0.5435,  3.6997, -0.2022,  0.1846, -0.4820])

vLLM Layer 19 Output:
  CLS token L2 norm: 41.958763
  CLS token (first 5 dims): tensor([ 0.8144,  4.2982,  1.2060,  0.0423, -1.3455])

vLLM Layer 20 Output:
  CLS token L2 norm: 48.433121
  CLS token (first 5 dims): tensor([ 0.4835,  4.6047,  1.2633,  1.2908, -1.9859])

vLLM Layer 21 Output:
  CLS token L2 norm: 58.428425
  CLS token (first 5 dims): tensor([-0.1005,  5.4260,  0.3762,  1.1649, -2.2600])

vLLM Layer 22 Output:
  CLS token L2 norm: 76.530487
  CLS token (first 5 dims): tensor([-0.0858,  6.6374, -0.0908,  0.4788, -3.5131])

vLLM Layer 23 Output:
  CLS token L2 norm: 87.558456
  CLS token (first 5 dims): tensor([-0.3373,  7.2414,  0.3762,  1.1714, -4.2532])

vLLM Layer 24 Output:
  CLS token L2 norm: 79.475525
  CLS token (last layer, first 10 dims): tensor([-0.4378,  8.2351,  0.2987,  1.1986, -2.9504,  1.3669, -0.0453, -1.0086,
         0.9013,  0.3142])

vLLM Final Norm Output:
  CLS token L2 norm: 33.906750
  CLS token (first 5 dims): tensor([-0.1495,  4.0176,  0.1019,  0.4079, -1.3003])

vLLM Classifier Output:
  Shape: torch.Size([1, 1])
  Logit value: tensor([[6.6145]])

vLLM Pooler Output:
  Logit value: 0.9986609220504761


================================================================================
COMPARISON: Transformers vs vLLM vs Sentence-Transformers
================================================================================

Final Outputs:
  Transformers logit: 6.614357
  vLLM logit: 6.614434
  Difference (Trans-vLLM): 0.000077
  Sentence-Transformers logit: 6.614345
  Difference (Trans-ST): 0.000012

Detailed comparisons:

Embeddings comparison:
  Transformers L2 norm: 13.144588
  vLLM L2 norm: 13.144588
  L2 difference: 0.000000
  ST L2 norm: 13.144588
  L2 difference (Trans-ST): 0.000001

Head layer comparison:
  Transformers L2 norm: 27.641144
  ST L2 norm: 27.641144
  L2 difference (Trans-ST): 0.000020

Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
@noooop (Contributor) commented Jun 25, 2025

Results are below; they are slightly different.

You might consider using a longer input; from my experience, the difference when using fp32 should be very small. Weird.

You also need to fix pre-commit.

@lsz05 (Contributor, Author) commented Jun 25, 2025

Results are below; they are slightly different.

You might consider using a longer input; from my experience, the difference when using fp32 should be very small. Weird.

You also need to fix pre-commit.

I use fp32 for transformers, sentence-transformers, and vllm.
The differences are very slight and I don't know why.

Script: https://gist.github.com/lsz05/a8820632f8d97aee2b6533d97252839d
Output: https://gist.github.com/lsz05/50a65242f602605eed955e8dd8b07cd2

@mgoin mgoin merged commit 23a04e0 into vllm-project:main Jun 25, 2025
72 checks passed
m-misiura pushed a commit to m-misiura/vllm that referenced this pull request Jun 26, 2025
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Jun 26, 2025
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
Signed-off-by: Will Eaton <weaton@redhat.com>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
wwl2755-google pushed a commit to wwl2755-google/vllm that referenced this pull request Jul 1, 2025
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
Labels
ready ONLY add when PR is ready to merge/full CI is needed

4 participants