
Commit 5104236

Implement the ModernBert model (#459)

kozistral and alvarobartt authored
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>

1 parent 7f7832e · commit 5104236
23 files changed: +28,468 −6 lines

README.md
Lines changed: 2 additions & 1 deletion

@@ -66,7 +66,7 @@ Ember, GTE and E5. TEI implements many features such as:
 #### Text Embeddings
 
 Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
-model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, and MPNet.
+model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, MPNet, and ModernBERT.
 
 Below are some examples of the currently supported models:
 
@@ -85,6 +85,7 @@ Below are some examples of the currently supported models:
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
 | N/A | 0.1B | MPNet | [sentence-transformers/all-mpnet-base-v2](https://hf.co/sentence-transformers/all-mpnet-base-v2) |
+| N/A | 0.4B | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
 
 To explore the list of best performing text embeddings models, visit the
 [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).

backends/candle/src/flash_attn.rs
Lines changed: 9 additions & 2 deletions

@@ -32,14 +32,15 @@ pub(crate) fn flash_attn_varlen(
     softmax_scale: f32,
     causal: bool,
     window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
 ) -> Result<Tensor, candle::Error> {
     let runtime_compute_cap = get_runtime_compute_cap();
 
     if runtime_compute_cap == 75 {
         if alibi_slopes.is_some() {
             candle::bail!("Flash attention v1 does not support alibi");
         }
-        if window_size_left.is_some() {
+        if window_size_left.is_some() | window_size_right.is_some() {
             candle::bail!("Flash attention v1 does not support attention windowing");
         }
 
@@ -65,7 +66,13 @@ pub(crate) fn flash_attn_varlen(
     {
         use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};
 
-        let window_size_right = if causal { Some(0) } else { None };
+        let window_size_right = if causal {
+            Some(0)
+        } else if window_size_right.is_some() {
+            window_size_right
+        } else {
+            None
+        };
 
         let attention = if let Some(alibi_slopes) = alibi_slopes {
             flash_attn_varlen_alibi_windowed(
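ModernBERT is why `window_size_right` is now explicit: it alternates global-attention layers with local sliding-window layers, and a symmetric local window needs a right bound other than the `Some(0)` (causal) or `None` (fully bidirectional) the old code derived on its own. A minimal sketch of how a caller might pick the two window arguments per layer; the `is_global` flag and `local_attention` width are illustrative names, not taken from this commit:

```rust
/// Illustrative helper: choose flash-attention window bounds for a layer.
/// `local_attention` is the total sliding-window width in tokens; local
/// layers split it evenly around each query position, while global layers
/// pass `None` on both sides to keep full bidirectional attention.
fn window_args(is_global: bool, local_attention: usize) -> (Option<usize>, Option<usize>) {
    if is_global {
        (None, None)
    } else {
        let half = local_attention / 2;
        (Some(half), Some(half))
    }
}
```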

backends/candle/src/layers/layer_norm.rs
Lines changed: 79 additions & 0 deletions

@@ -1,6 +1,84 @@
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
+#[derive(Debug)]
+pub struct LayerNormNoBias {
+    weight: Tensor,
+    epsilon: f32,
+    span: tracing::Span,
+}
+
+impl LayerNormNoBias {
+    pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
+        Ok(Self {
+            weight: vb
+                .get(hidden_size, "weight")
+                .or_else(|_| vb.get(hidden_size, "gamma"))?,
+            epsilon,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm-no-bias"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        match hidden_states.device() {
+            Device::Cpu | Device::Metal(_) => {
+                let mut hidden_states = hidden_states.clone();
+                if let Some(residual) = residual {
+                    hidden_states = hidden_states.add(residual)?;
+                }
+                let hidden_states_dtype = hidden_states.dtype();
+                let internal_dtype = match hidden_states_dtype {
+                    DType::F16 | DType::BF16 => DType::F32,
+                    d => d,
+                };
+                let hidden_size = hidden_states.dim(D::Minus1)?;
+                let hidden_states = hidden_states.to_dtype(internal_dtype)?;
+                let mean_hidden_states =
+                    (hidden_states.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states = hidden_states.broadcast_sub(&mean_hidden_states)?;
+                let norm_hidden_states =
+                    (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states_normed = hidden_states
+                    .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
+                let hidden_states = hidden_states_normed
+                    .to_dtype(hidden_states_dtype)?
+                    .broadcast_mul(&self.weight)?;
+
+                Ok(hidden_states)
+            }
+            Device::Cuda(_) => {
+                #[cfg(feature = "cuda")]
+                {
+                    use candle_layer_norm::{fused_add_layer_norm, layer_norm};
+
+                    let original_shape = hidden_states.shape();
+                    let hidden_states = hidden_states.flatten_to(D::Minus2)?;
+
+                    let result = if let Some(residual) = residual {
+                        let residual = residual.flatten_to(D::Minus2)?;
+
+                        let (result, _) = fused_add_layer_norm(
+                            &hidden_states,
+                            &residual,
+                            &self.weight,
+                            None,
+                            self.epsilon,
+                        )?;
+                        Ok(result)
+                    } else {
+                        layer_norm(&hidden_states, &self.weight, None, self.epsilon)
+                    }?;
+                    result.reshape(original_shape)
+                }
+                #[cfg(not(feature = "cuda"))]
+                candle::bail!("`cuda` feature is not enabled")
+            }
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
@@ -49,6 +127,7 @@ impl LayerNorm {
                 let hidden_states = hidden_states_normed
                     .to_dtype(hidden_states_dtype)?
                     .broadcast_mul(&self.weight)?;
+
                 hidden_states.broadcast_add(&self.bias)
             }
             Device::Cuda(_) => {
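ModernBERT's normalization layers carry a scale but no bias, which the existing `LayerNorm` (it unconditionally adds `self.bias`) cannot represent. `LayerNormNoBias` computes y = (x - mean(x)) / sqrt(var(x) + eps) * weight, upcasting half-precision inputs to f32, optionally fusing a residual add, and routing to `candle_layer_norm`'s fused kernels on CUDA. A small usage sketch; the weight prefix `"final_norm"` and the values 1024 / 1e-5 are placeholders for a real weight path, hidden size, and config epsilon:

```rust
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

use crate::layers::LayerNormNoBias;

// Hypothetical call site inside a model; names and values are placeholders.
fn apply_final_norm(vb: VarBuilder, hidden: &Tensor, residual: &Tensor) -> Result<Tensor> {
    let norm = LayerNormNoBias::load(vb.pp("final_norm"), 1024, 1e-5)?;
    // Plain normalization of the hidden states...
    let _normed = norm.forward(hidden, None)?;
    // ...or with the residual addition fused into the same call.
    norm.forward(hidden, Some(residual))
}
```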

backends/candle/src/layers/mod.rs
Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ mod rms_norm;
 mod rotary;
 
 pub use cublaslt::get_cublas_lt_wrapper;
-pub use layer_norm::LayerNorm;
+pub use layer_norm::{LayerNorm, LayerNormNoBias};
 pub use linear::{HiddenAct, Linear};
 #[allow(unused_imports)]
 pub use rms_norm::RMSNorm;

backends/candle/src/lib.rs
Lines changed: 17 additions & 2 deletions

@@ -12,8 +12,8 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
-    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
-    Qwen2Config,
+    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
+    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -104,6 +104,8 @@ enum Config {
     Qwen2(Qwen2Config),
     #[serde(rename = "mpnet")]
     MPNet(MPNetConfig),
+    #[serde(rename(deserialize = "modernbert"))]
+    ModernBert(ModernBertConfig),
 }
 
 pub struct CandleBackend {
@@ -274,6 +276,19 @@ impl CandleBackend {
                 tracing::info!("Starting MPNet model on {:?}", device);
                 Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
             }
+            (Config::ModernBert(config), _) => match device {
+                Device::Metal(_) => {
+                    return Err(BackendError::Start(
+                        "ModernBert is not currently supported on MPS device".to_string(),
+                    ));
+                }
+                _ => {
+                    tracing::info!("Starting ModernBert model on {:?}", device);
+                    Ok(Box::new(
+                        ModernBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                }
+            },
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
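The backend deserializes the model's config.json into this `Config` enum keyed on its `model_type` field, which is what the `rename(deserialize = "modernbert")` attribute matches. A reduced sketch of the pattern; `ModernBertConfigSketch` and its two fields are stand-ins, since the real `ModernBertConfig` carries many more fields:

```rust
use serde::Deserialize;

// Stand-in for the real ModernBertConfig, which has many more fields.
#[derive(Debug, Deserialize)]
struct ModernBertConfigSketch {
    hidden_size: usize,
    num_hidden_layers: usize,
}

// Internally tagged enum: serde routes on `model_type` and fills the
// variant's struct from the remaining fields of the same JSON object.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")]
enum ConfigSketch {
    #[serde(rename(deserialize = "modernbert"))]
    ModernBert(ModernBertConfigSketch),
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{"model_type": "modernbert", "hidden_size": 1024, "num_hidden_layers": 28}"#;
    let config: ConfigSketch = serde_json::from_str(raw)?;
    println!("{config:?}"); // ConfigSketch::ModernBert(...)
    Ok(())
}
```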

backends/candle/src/models/flash_bert.rs
Lines changed: 1 addition & 0 deletions

@@ -104,6 +104,7 @@ impl BertAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;

backends/candle/src/models/flash_distilbert.rs
Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ impl DistilBertAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;

backends/candle/src/models/flash_gte.rs
Lines changed: 1 addition & 0 deletions

@@ -87,6 +87,7 @@ impl GTEAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;

backends/candle/src/models/flash_jina.rs
Lines changed: 1 addition & 0 deletions

@@ -106,6 +106,7 @@ impl JinaAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;

backends/candle/src/models/flash_jina_code.rs
Lines changed: 1 addition & 0 deletions

@@ -142,6 +142,7 @@ impl JinaCodeAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;
