[Model] support modernbert #16648

Merged
14 commits merged on Apr 16, 2025
Changes from 11 commits
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -735,6 +735,11 @@ If your model is not in the above list, we will try to automatically convert the
* `BAAI/bge-reranker-v2-m3`, etc.
*
*
- * `ModernBertForSequenceClassification`
* ModernBert-based
* `Alibaba-NLP/gte-reranker-modernbert-base`, etc.
*
*
:::

(supported-mm-models)=
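As a usage sketch for the reranker documented above: a minimal offline example, assuming vLLM's cross-encoder scoring API (`LLM(..., task="score")` and `LLM.score()`); exact output fields may differ across vLLM versions.

from vllm import LLM

# Load the newly supported ModernBERT reranker as a cross-encoder.
llm = LLM(model="Alibaba-NLP/gte-reranker-modernbert-base", task="score")

query = "What is the capital of France?"
docs = ["Paris is the capital of France.",
        "The Nile is the longest river in Africa."]

# score() pairs the query with each candidate and returns one relevance
# score per (query, document) pair.
outputs = llm.score(query, docs)
for doc, out in zip(docs, outputs):
    print(f"{out.outputs.score:.4f}  {doc}")
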
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -273,6 +273,9 @@ def check_available_online(
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"ModernBertForSequenceClassification":
_HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base",
min_transformers_version="4.49"),
}

_MULTIMODAL_EXAMPLE_MODELS = {
351 changes: 351 additions & 0 deletions vllm/model_executor/models/modernbert.py
@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Iterable, Optional, Set, Tuple

import torch
from torch import nn
from transformers import ModernBertConfig

from vllm.attention import Attention, AttentionType
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding
from .utils import WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):

def __init__(self, config: ModernBertConfig):

super().__init__()
self.config = config
self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps,
bias=config.norm_bias)

def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        if inputs_embeds is not None:
return self.norm(inputs_embeds)
else:
inputs_embeds = self.tok_embeddings(input_ids)
embeddings = self.norm(inputs_embeds)
return embeddings


class ModernBertRotaryEmbedding(RotaryEmbedding):

def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
base: float):
super().__init__(
head_size=head_size,
rotary_dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
is_neox_style=True,
dtype=torch.float16)
self.config = config


class ModernBertAttention(nn.Module):

def __init__(self,
config: ModernBertConfig,
layer_id: Optional[int] = None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
self.deterministic_flash_attn = config.deterministic_flash_attn
self.num_heads = config.num_attention_heads
assert self.num_heads % tp_size == 0
self.head_dim = config.hidden_size // config.num_attention_heads
self.all_head_size = self.head_dim * self.num_heads
self.scaling = self.head_dim**-0.5
self.Wqkv = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.num_heads,
bias=config.attention_bias,
)

        # ModernBERT alternates global and sliding-window ("local") attention:
        # every `global_attn_every_n_layers`-th layer is global, marked here as
        # (-1, -1); the remaining layers record (left, right) window half-widths
        # derived from `local_attention`.
        if layer_id % config.global_attn_every_n_layers != 0:
            self.local_attention = (config.local_attention // 2,
                                    config.local_attention // 2)
        else:
            self.local_attention = (-1, -1)

rope_theta = config.global_rope_theta
if self.local_attention != (
-1, -1) and config.local_rope_theta is not None:
rope_theta = config.local_rope_theta
self.rotary_emb = ModernBertRotaryEmbedding(config=config,
head_size=self.head_dim,
dim=self.head_dim,
base=rope_theta)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
prefix=f"{layer_id}.attn",
attn_type=AttentionType.ENCODER_ONLY)
self.Wo = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=config.attention_bias)
self.pruned_heads = set()

def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> torch.Tensor:
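        # Fused QKV projection is split into equal q/k/v chunks; rotary position
        # embeddings are applied to q and k before the attention kernel.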
qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_outputs = self.attn(q, k, v)
hidden_states = attn_outputs
hidden_states, _ = self.Wo(hidden_states)
return hidden_states


class GELUActivation(nn.Module):

def __init__(self, use_gelu_python: bool = False):
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu

def _gelu_python(self, input: torch.Tensor) -> torch.Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.act(input)


class ModernBertMLP(nn.Module):

def __init__(self, config: ModernBertConfig):
super().__init__()
self.config = config
self.Wi = nn.Linear(config.hidden_size,
int(config.intermediate_size) * 2,
bias=config.mlp_bias)
self.act = GELUActivation()
self.Wo = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=config.mlp_bias)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
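        # Wi projects to twice the intermediate size; the halves form a gated GELU
        # (GeGLU): act(input) * gate, then Wo projects back to the hidden size.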
input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
return self.Wo(self.act(input) * gate)[0]


class ModernBertLayer(nn.Module):

def __init__(self,
config: ModernBertConfig,
prefix: str = "",
layer_id: Optional[int] = None):
super().__init__()
self.config = config
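        # Layer 0 skips the pre-attention LayerNorm: the embedding module has already
        # normalized its output, so an extra norm here would be redundant.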
if layer_id == 0:
self.attn_norm = nn.Identity()
else:
self.attn_norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
bias=config.norm_bias)
self.attn = ModernBertAttention(config=config, layer_id=layer_id)
self.mlp_norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
bias=config.norm_bias)
self.mlp = ModernBertMLP(config)

@torch.compile(dynamic=True)
def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.mlp(self.mlp_norm(hidden_states))

def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
):
attn_outputs = self.attn(self.attn_norm(hidden_states),
position_ids=position_ids)
hidden_states = hidden_states + attn_outputs
mlp_output = (self.compiled_mlp(hidden_states)
if self.config.reference_compile else self.mlp(
self.mlp_norm(hidden_states)))
hidden_states = hidden_states + mlp_output
return hidden_states


class ModernBertEncoderLayer(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.layers = nn.ModuleList([
ModernBertLayer(config=config, layer_id=layer_id)
for layer_id in range(config.num_hidden_layers)
])

def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
        for layer in self.layers:
hidden_states = layer(hidden_states, position_ids)
return hidden_states


class ModernBertModel(nn.Module):
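    # HF checkpoints keep the transformer blocks under the "layers." prefix, whereas
    # this implementation nests them under "encoder_layer.layers."; remap on load.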
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."})

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.embeddings = ModernBertEmbeddings(config)
self.encoder_layer = ModernBertEncoderLayer(vllm_config)
self.final_norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
bias=config.norm_bias)
self.gradient_checkpointing = False
self.dtype = torch.float16

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
inputs_embeds=inputs_embeds)

outputs = self.encoder_layer(
hidden_states=hidden_states,
position_ids=position_ids,
)
norm_outputs = self.final_norm(outputs)
return norm_outputs


class ModernBertPooler(nn.Module):

def __init__(self, config: ModernBertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias)
self.act = GELUActivation()
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
bias=config.norm_bias)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
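        # Mean-pool over the token dimension, then project through dense -> GELU ->
        # LayerNorm; the pooled vector feeds the classification head in the wrapper below.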
pooled_output = hidden_states
pooled_output = pooled_output.mean(dim=0, keepdim=False)
pooled_output = self.norm(self.act(self.dense(pooled_output)))
return pooled_output


class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.model = ModernBertModel(vllm_config,
maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = CrossEncodingPooler(config, self.classifier,
ModernBertPooler(config))

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

self_weights = []

        # Split the checkpoint: weights under the "model." prefix stream into the
        # backbone's load_weights(); everything else (classifier, HF "head.*") is
        # collected in self_weights and loaded below.
        def weight_filter():
for name, weight in weights:
if name.startswith("model."):
yield name[len("model."):], weight
else:
self_weights.append((name, weight))

self.model.load_weights(weight_filter())

params_dict = dict(self.named_parameters())

for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if name.startswith("head"):
param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
positions: torch.Tensor = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(
input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
)
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
@@ -161,6 +161,8 @@
"RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
}

_MULTIMODAL_MODELS = {
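With this registry entry, the `ModernBertForSequenceClassification` architecture string from a checkpoint's HF config resolves to the new implementation. A small sketch, assuming `ModelRegistry.resolve_model_cls` keeps its current signature:

from vllm import ModelRegistry

# The architecture name comes from the HF config's `architectures` field; the
# registry lazily imports vllm/model_executor/models/modernbert.py on demand.
model_cls, arch = ModelRegistry.resolve_model_cls(
    ["ModernBertForSequenceClassification"])
print(arch, model_cls.__name__)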