[Model] support modernbert #16648

Merged: 14 commits (Apr 16, 2025)
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -735,6 +735,11 @@ If your model is not in the above list, we will try to automatically convert the
  * `BAAI/bge-reranker-v2-m3`, etc.
  *
  *
- * `ModernBertForSequenceClassification`
  * ModernBert-based
  * `Alibaba-NLP/gte-reranker-modernbert-base`, etc.
  *
  *
:::

(supported-mm-models)=
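For context, the docs row above exposes the model through vLLM's existing cross-encoder path, so a minimal scoring sketch would look roughly like the snippet below; the `task="score"` argument, `LLM.score`, and the `.outputs.score` field are assumed from vLLM's pre-existing API rather than shown in this diff.

```python
from vllm import LLM

# Hedged sketch: relies on vLLM's existing cross-encoder scoring API,
# which the CrossEncodingPooler wired up in this PR plugs into.
llm = LLM(model="Alibaba-NLP/gte-reranker-modernbert-base", task="score")

query = "What is the capital of France?"
docs = ["Paris is the capital of France.", "The Nile flows through Egypt."]

outputs = llm.score(query, docs)  # one relevance score per (query, doc) pair
for doc, out in zip(docs, outputs):
    print(f"{out.outputs.score:.4f}  {doc}")
```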
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -273,6 +273,9 @@ def check_available_online(
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"ModernBertForSequenceClassification":
_HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base",
min_transformers_version="4.49"),
}

_MULTIMODAL_EXAMPLE_MODELS = {
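The `min_transformers_version="4.49"` field above gates the test example on a sufficiently new transformers release. A quick local sanity check (a sketch, not how the test harness does it; `packaging` is assumed to be installed) would be:

```python
# Sketch: verify the installed transformers is new enough for ModernBERT support.
from packaging.version import Version

import transformers

assert Version(transformers.__version__) >= Version("4.49"), (
    "ModernBERT support requires transformers>=4.49")
```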
325 changes: 325 additions & 0 deletions vllm/model_executor/models/modernbert.py
@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple

import torch
from torch import nn
from transformers import ModernBertConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding
from .utils import WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):

    def __init__(self, config: ModernBertConfig):

        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.layer_norm_eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds:
            return self.norm(inputs_embeds)
        else:
            inputs_embeds = self.tok_embeddings(input_ids)
            embeddings = self.norm(inputs_embeds)
            return embeddings


class ModernBertRotaryEmbedding(RotaryEmbedding):

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
                 base: float):
        super().__init__(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            dtype=torch.float16)
        self.config = config


class ModernBertAttention(nn.Module):

    def __init__(self,
                 config: ModernBertConfig,
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5
        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        if layer_id % config.global_attn_every_n_layers != 0:
            self.local_attention = (config.local_attention // 2,
                                    config.local_attention // 2)
        else:
            self.local_attention = (-1, -1)

        rope_theta = config.global_rope_theta
        if self.local_attention != (
                -1, -1) and config.local_rope_theta is not None:
            rope_theta = config.local_rope_theta
        self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                    head_size=self.head_dim,
                                                    dim=self.head_dim,
                                                    base=rope_theta)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)
        self.Wo = RowParallelLinear(config.hidden_size,
                                    config.hidden_size,
                                    bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_outputs = self.attn(q, k, v)
        hidden_states = attn_outputs
        hidden_states, _ = self.Wo(hidden_states)
        return hidden_states


class ModernBertMLP(nn.Module):

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(config.hidden_size,
                            int(config.intermediate_size) * 2,
                            bias=config.mlp_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.act(input) * gate)[0]


class ModernBertLayer(nn.Module):

    def __init__(self,
                 config: ModernBertConfig,
                 prefix: str = "",
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        if layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = nn.LayerNorm(config.hidden_size,
                                          eps=config.norm_eps,
                                          bias=config.norm_bias)
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size,
                                     eps=config.norm_eps,
                                     bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        attn_outputs = self.attn(self.attn_norm(hidden_states),
                                 position_ids=position_ids)
        hidden_states = hidden_states + attn_outputs
        mlp_output = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + mlp_output
        return hidden_states


class ModernBertEncoderLayer(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList([
            ModernBertLayer(config=config, layer_id=layer_id)
            for layer_id in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
class ModernBertModel(nn.Module):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."})

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(config.hidden_size,
                                       eps=config.norm_eps,
                                       bias=config.norm_bias)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            inputs_embeds=inputs_embeds)

        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )
        norm_outputs = self.final_norm(outputs)
        return norm_outputs


class ModernBertPooler(nn.Module):

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                               config.classifier_bias)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled_output = hidden_states
        pooled_output = pooled_output.mean(dim=0, keepdim=False)
        pooled_output = self.norm(self.act(self.dense(pooled_output)))
        return pooled_output


class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "modernbert"))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self._pooler = CrossEncodingPooler(config, self.classifier,
                                           ModernBertPooler(config))

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        self_weights = []

        def weight_filter():
            for name, weight in weights:
                if name.startswith("model."):
                    yield name[len("model."):], weight
                else:
                    self_weights.append((name, weight))

        self.model.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            if name.startswith("head"):
                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=positions,
        )
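To make the alternating attention pattern in `ModernBertAttention.__init__` concrete, the sketch below replays the same layer-indexing rule with representative ModernBERT-base-style config values (the numbers are assumptions, not read from this diff): every `global_attn_every_n_layers`-th layer gets full attention, the rest get a symmetric sliding window of `local_attention // 2` tokens on each side.

```python
# Sketch of the window selection in ModernBertAttention.__init__ above,
# with assumed ModernBERT-base-style config values.
global_attn_every_n_layers = 3
local_attention = 128
num_hidden_layers = 22

for layer_id in range(num_hidden_layers):
    if layer_id % global_attn_every_n_layers != 0:
        window = (local_attention // 2, local_attention // 2)  # sliding window
    else:
        window = (-1, -1)  # full (global) attention
    print(layer_id, window)
```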
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
@@ -161,6 +161,8 @@
"RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
}

_MULTIMODAL_MODELS = {
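Finally, the registry entry above maps the HF architecture name to the new `modernbert` module, and checkpoint loading then flows through `ModernBertForSequenceClassification.load_weights`, which routes `model.`-prefixed tensors to the backbone and keeps `classifier`/`head` tensors local. The sketch below replays that routing on a few illustrative checkpoint names (the names are assumptions for illustration, not taken from a real checkpoint):

```python
# Sketch of the name routing performed by load_weights/weight_filter above.
checkpoint_names = [
    "model.embeddings.tok_embeddings.weight",  # backbone
    "model.layers.0.attn.Wqkv.weight",         # backbone
    "classifier.weight",                       # stays on the wrapper
    "head.dense.weight",                       # remapped to _pooler.pooler.dense
]

backbone = [n[len("model."):] for n in checkpoint_names if n.startswith("model.")]
local = [n for n in checkpoint_names if not n.startswith("model.")]

print(backbone)  # handed to ModernBertModel.load_weights (then hf_to_vllm_mapper)
print(local)     # loaded into classifier / _pooler.pooler
```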