Commit 54557b2

xsank唯勤 authored and committed
[Model] support modernbert (vllm-project#16648)
Signed-off-by: 唯勤 <xsank.mz@alibaba-inc.com> Co-authored-by: 唯勤 <xsank.mz@alibaba-inc.com>
1 parent ac7d663 commit 54557b2

4 files changed: +335 −0 lines changed

docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
@@ -740,6 +740,11 @@ If your model is not in the above list, we will try to automatically convert the
   * `BAAI/bge-reranker-v2-m3`, etc.
   *
   *
+- * `ModernBertForSequenceClassification`
+  * ModernBert-based
+  * `Alibaba-NLP/gte-reranker-modernbert-base`, etc.
+  *
+  *
 :::
 
 (supported-mm-models)=
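
The documentation row above advertises the ModernBERT-based reranker as a cross-encoder. As a rough usage sketch (not part of this commit; it assumes vLLM's offline LLM entrypoint with task="score" and the LLM.score() API), the newly listed checkpoint could be exercised like this:

# Usage sketch only (not part of this diff); assumes vLLM's offline scoring API.
from vllm import LLM

# Load the ModernBERT-based cross-encoder listed above as a scoring model.
llm = LLM(model="Alibaba-NLP/gte-reranker-modernbert-base", task="score")

query = "What is the capital of France?"
docs = ["Paris is the capital of France.",
        "The Nile is the longest river in Africa."]

# score() pairs the query with each document and returns one relevance
# score per pair.
outputs = llm.score(query, docs)
for doc, out in zip(docs, outputs):
    print(f"{out.outputs.score:.4f}  {doc}")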

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -275,6 +275,9 @@ def check_available_online(
     "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
     "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
     "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
+    "ModernBertForSequenceClassification":
+    _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base",
+                    min_transformers_version="4.49"),
 }
 
 _MULTIMODAL_EXAMPLE_MODELS = {
vllm/model_executor/models/modernbert.py

Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple

import torch
from torch import nn
from transformers import ModernBertConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding
from .utils import WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):

    def __init__(self, config: ModernBertConfig):

        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.layer_norm_eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds:
            return self.norm(inputs_embeds)
        else:
            inputs_embeds = self.tok_embeddings(input_ids)
            embeddings = self.norm(inputs_embeds)
            return embeddings


class ModernBertRotaryEmbedding(RotaryEmbedding):

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
                 base: float):
        super().__init__(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            dtype=torch.float16)
        self.config = config


class ModernBertAttention(nn.Module):

    def __init__(self,
                 config: ModernBertConfig,
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5
        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        if layer_id % config.global_attn_every_n_layers != 0:
            self.local_attention = (config.local_attention // 2,
                                    config.local_attention // 2)
        else:
            self.local_attention = (-1, -1)

        rope_theta = config.global_rope_theta
        if self.local_attention != (
                -1, -1) and config.local_rope_theta is not None:
            rope_theta = config.local_rope_theta
        self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                    head_size=self.head_dim,
                                                    dim=self.head_dim,
                                                    base=rope_theta)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)
        self.Wo = RowParallelLinear(config.hidden_size,
                                    config.hidden_size,
                                    bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_outputs = self.attn(q, k, v)
        hidden_states = attn_outputs
        hidden_states, _ = self.Wo(hidden_states)
        return hidden_states


class ModernBertMLP(nn.Module):

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(config.hidden_size,
                            int(config.intermediate_size) * 2,
                            bias=config.mlp_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.act(input) * gate)[0]


class ModernBertLayer(nn.Module):

    def __init__(self,
                 config: ModernBertConfig,
                 prefix: str = "",
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        if layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = nn.LayerNorm(config.hidden_size,
                                          eps=config.norm_eps,
                                          bias=config.norm_bias)
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size,
                                     eps=config.norm_eps,
                                     bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        attn_outputs = self.attn(self.attn_norm(hidden_states),
                                 position_ids=position_ids)
        hidden_states = hidden_states + attn_outputs
        mlp_output = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + mlp_output
        return hidden_states


class ModernBertEncoderLayer(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList([
            ModernBertLayer(config=config, layer_id=layer_id)
            for layer_id in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
class ModernBertModel(nn.Module):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."})

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(config.hidden_size,
                                       eps=config.norm_eps,
                                       bias=config.norm_bias)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            inputs_embeds=inputs_embeds)

        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )
        norm_outputs = self.final_norm(outputs)
        return norm_outputs


class ModernBertPooler(nn.Module):

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                               config.classifier_bias)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled_output = hidden_states
        pooled_output = pooled_output.mean(dim=0, keepdim=False)
        pooled_output = self.norm(self.act(self.dense(pooled_output)))
        return pooled_output


class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "modernbert"))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self._pooler = CrossEncodingPooler(config, self.classifier,
                                           ModernBertPooler(config))

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        self_weights = []

        def weight_filter():
            for name, weight in weights:
                if name.startswith("model."):
                    yield name[len("model."):], weight
                else:
                    self_weights.append((name, weight))

        self.model.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            if name.startswith("head"):
                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=positions,
        )

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -162,6 +162,8 @@
                                          "RobertaForSequenceClassification"),
     "XLMRobertaForSequenceClassification": ("roberta",
                                             "RobertaForSequenceClassification"),
+    "ModernBertForSequenceClassification": ("modernbert",
+                                            "ModernBertForSequenceClassification"),
 }
 
 _MULTIMODAL_MODELS = {
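
With this registry entry, the ModernBertForSequenceClassification architecture string from a checkpoint's HF config resolves to the new modernbert module. A minimal check (a sketch only, assuming vLLM's ModelRegistry.get_supported_archs() helper) would be:

# Sketch only; assumes ModelRegistry.get_supported_archs() is available.
from vllm import ModelRegistry

assert "ModernBertForSequenceClassification" in ModelRegistry.get_supported_archs()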

0 commit comments