Skip to content

Commit 27c118a

Browse files
authored
Adding the missing `head.` prefix to the weight names in ModernBertClassificationHead (#591)
1 parent b17d05f commit 27c118a

File tree

4 files changed

+53
-11
lines changed

4 files changed

+53
-11
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ Below are some examples of the currently supported models:
101101
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
102102
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
103103
| Re-Ranking | GTE | [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) |
104+
| Re-Ranking | ModernBert | [Alibaba-NLP/gte-reranker-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-reranker-modernbert-base) |
104105
| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
105106

106107
### Docker

backends/candle/src/models/modernbert.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -412,27 +412,28 @@ pub struct ModernBertClassificationHead {
412412
impl ModernBertClassificationHead {
413413
pub(crate) fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
414414
let dense_weight = vb
415-
.pp("dense")
415+
.pp("head.dense")
416416
.get((config.hidden_size, config.hidden_size), "weight")?;
417-
let dense_bias = vb.pp("dense").get(config.hidden_size, "bias").ok();
418417
let dense = Linear::new(
419418
dense_weight,
420-
dense_bias,
419+
None,
421420
Some(config.classifier_activation.clone()),
422421
);
423422

424-
let norm =
425-
LayerNormNoBias::load(vb.pp("norm"), config.hidden_size, config.norm_eps as f32)?;
423+
let norm = LayerNormNoBias::load(
424+
vb.pp("head.norm"),
425+
config.hidden_size,
426+
config.norm_eps as f32,
427+
)?;
426428

427-
let classifier_weight = vb.pp("dense").get(
429+
let classifier_weight = vb.pp("classifier").get(
428430
(config.num_labels.unwrap_or(1), config.hidden_size),
429431
"weight",
430432
)?;
431433
let classifier_bias = vb
432-
.pp("dense")
433-
.get(config.num_labels.unwrap_or(1), "bias")
434-
.ok();
435-
let classifier = Linear::new(classifier_weight, classifier_bias, None);
434+
.pp("classifier")
435+
.get(config.num_labels.unwrap_or(1), "bias")?;
436+
let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
436437

437438
Ok(Self {
438439
dense,

backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
source: backends/candle/tests/test_modernbert.rs
3+
expression: predictions_single
4+
---
5+
- - 2.2585099

backends/candle/tests/test_modernbert.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ mod common;
22

33
use crate::common::{sort_embeddings, SnapshotEmbeddings};
44
use anyhow::Result;
5-
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
5+
use common::{
6+
batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher, SnapshotScores,
7+
};
68
use text_embeddings_backend_candle::CandleBackend;
79
use text_embeddings_backend_core::{Backend, ModelType, Pool};
810

@@ -135,3 +137,36 @@ fn test_mini_pooled_raw() -> Result<()> {
135137

136138
Ok(())
137139
}
140+
141+
#[test]
142+
#[serial_test::serial]
143+
fn test_modernbert_classification() -> Result<()> {
144+
let model_root = download_artifacts("Alibaba-NLP/gte-reranker-modernbert-base", None).unwrap();
145+
let tokenizer = load_tokenizer(&model_root)?;
146+
147+
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
148+
149+
let input_single = batch(
150+
vec![tokenizer
151+
.encode(("What is Deep Learning?", "Deep Learning is not..."), true)
152+
.unwrap()],
153+
[0].to_vec(),
154+
vec![],
155+
);
156+
157+
let predictions: Vec<Vec<f32>> = backend
158+
.predict(input_single)?
159+
.into_iter()
160+
.map(|(_, v)| v)
161+
.collect();
162+
let predictions_single = SnapshotScores::from(predictions);
163+
164+
let matcher = relative_matcher();
165+
insta::assert_yaml_snapshot!(
166+
"modernbert_classification_single",
167+
predictions_single,
168+
&matcher
169+
);
170+
171+
Ok(())
172+
}

0 commit comments

Comments
 (0)