Skip to content

Commit 0bfeb7e

Browse files
feat: GTE classification head (#441)
Co-authored-by: Hyeongchan Kim <kozistr@gmail.com>
1 parent 7c4f67e commit 0bfeb7e

18 files changed

+145
-32
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,12 @@ Text Embeddings Inference currently supports CamemBERT, and XLM-RoBERTa Sequence
9292

9393
Below are some examples of the currently supported models:
9494

95-
| Task | Model Type | Model ID |
96-
|--------------------|-------------|---------------------------------------------------------------------------------------------|
97-
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
98-
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
99-
| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
95+
| Task | Model Type | Model ID |
96+
|--------------------|-------------|-----------------------------------------------------------------------------------------------------------------|
97+
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
98+
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
99+
| Re-Ranking | GTE | [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) |
100+
| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
100101

101102
### Docker
102103

@@ -372,7 +373,7 @@ docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingf
372373

373374
### Using Re-rankers models
374375

375-
`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models.
376+
`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa, XLM-RoBERTa, and GTE Sequence Classification models.
376377
Re-ranker models are Sequence Classification cross-encoder models with a single class that scores the similarity
377378
between a query and a text.
378379

@@ -392,7 +393,7 @@ And then you can rank the similarity between a query and a list of texts with:
392393
```bash
393394
curl 127.0.0.1:8080/rerank \
394395
-X POST \
395-
-d '{"query":"What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
396+
-d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
396397
-H 'Content-Type: application/json'
397398
```
398399

backends/candle/src/models/flash_bert.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ impl Model for FlashBertModel {
529529
fn is_padded(&self) -> bool {
530530
false
531531
}
532+
532533
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
533534
self.forward(batch)
534535
}

backends/candle/src/models/flash_gte.rs

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,66 @@ impl GTELayer {
198198
}
199199
}
200200

201+
pub struct GTEClassificationHead {
202+
pooler: Option<Linear>,
203+
classifier: Linear,
204+
span: tracing::Span,
205+
}
206+
207+
impl GTEClassificationHead {
208+
#[allow(dead_code)]
209+
pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
210+
let n_classes = match &config.id2label {
211+
None => candle::bail!("`id2label` must be set for classifier models"),
212+
Some(id2label) => id2label.len(),
213+
};
214+
215+
let pooler = if let Ok(pooler_weight) = vb
216+
.pp("pooler.dense")
217+
.get((config.hidden_size, config.hidden_size), "weight")
218+
{
219+
let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
220+
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
221+
} else {
222+
None
223+
};
224+
225+
let classifier_weight = vb
226+
.pp("classifier")
227+
.get((n_classes, config.hidden_size), "weight")?;
228+
let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
229+
let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
230+
231+
Ok(Self {
232+
classifier,
233+
pooler,
234+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
235+
})
236+
}
237+
238+
pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
239+
let _enter = self.span.enter();
240+
241+
let mut hidden_states = hidden_states.unsqueeze(1)?;
242+
if let Some(pooler) = self.pooler.as_ref() {
243+
hidden_states = pooler.forward(&hidden_states)?;
244+
hidden_states = hidden_states.tanh()?;
245+
}
246+
247+
let hidden_states = self.classifier.forward(&hidden_states)?;
248+
let hidden_states = hidden_states.squeeze(1)?;
249+
Ok(hidden_states)
250+
}
251+
}
252+
201253
pub struct FlashGTEModel {
202254
word_embeddings: Embedding,
203255
token_type_embeddings: Option<Embedding>,
204256
layers: Vec<GTELayer>,
205257
embeddings_norm: LayerNorm,
206258
cos_cache: Tensor,
207259
sin_cache: Tensor,
260+
classifier: Option<GTEClassificationHead>,
208261
pool: Pool,
209262
pub device: Device,
210263

@@ -233,11 +286,14 @@ impl FlashGTEModel {
233286
candle::bail!("Only `PositionEmbeddingType::Rope` is supported");
234287
}
235288

236-
let pool = match model_type {
289+
let (pool, classifier) = match model_type {
237290
ModelType::Classifier => {
238-
candle::bail!("`classifier` model type is not supported for GTE")
291+
let pool = Pool::Cls;
292+
293+
let classifier = GTEClassificationHead::load(vb.clone(), config)?;
294+
(pool, Some(classifier))
239295
}
240-
ModelType::Embedding(pool) => pool,
296+
ModelType::Embedding(pool) => (pool, None),
241297
};
242298

243299
let word_embeddings = Embedding::new(
@@ -292,6 +348,7 @@ impl FlashGTEModel {
292348
embeddings_norm,
293349
cos_cache,
294350
sin_cache,
351+
classifier,
295352
pool,
296353
device: vb.device().clone(),
297354
span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -457,7 +514,20 @@ impl Model for FlashGTEModel {
457514
fn is_padded(&self) -> bool {
458515
false
459516
}
517+
460518
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
461519
self.forward(batch)
462520
}
521+
522+
fn predict(&self, batch: Batch) -> Result<Tensor> {
523+
match &self.classifier {
524+
None => candle::bail!("`predict` is not implemented for this model"),
525+
Some(classifier) => {
526+
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
527+
let pooled_embeddings =
528+
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
529+
classifier.forward(&pooled_embeddings)
530+
}
531+
}
532+
}
463533
}

backends/candle/src/models/gte.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::layers::HiddenAct;
22
use crate::models::PositionEmbeddingType;
33
use serde::Deserialize;
4+
use std::collections::HashMap;
45

56
#[derive(Debug, Clone, PartialEq, Deserialize)]
67
pub struct NTKScaling {
@@ -32,4 +33,5 @@ pub struct GTEConfig {
3233
pub logn_attention_scale: bool,
3334
#[serde(default)]
3435
pub logn_attention_clip1: bool,
36+
pub id2label: Option<HashMap<String, String>>,
3537
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
source: backends/candle/tests/test_flash_gte.rs
3+
assertion_line: 86
4+
expression: predictions_single
5+
---
6+
- - -0.74365234

backends/candle/tests/test_bert.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ fn test_mini() -> Result<()> {
1313
let tokenizer = load_tokenizer(&model_root)?;
1414

1515
let backend = CandleBackend::new(
16-
model_root,
16+
&model_root,
1717
"float32".to_string(),
1818
ModelType::Embedding(Pool::Mean),
1919
)?;
@@ -73,7 +73,7 @@ fn test_mini_pooled_raw() -> Result<()> {
7373
let tokenizer = load_tokenizer(&model_root)?;
7474

7575
let backend = CandleBackend::new(
76-
model_root,
76+
&model_root,
7777
"float32".to_string(),
7878
ModelType::Embedding(Pool::Cls),
7979
)?;
@@ -142,7 +142,7 @@ fn test_emotions() -> Result<()> {
142142
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
143143
let tokenizer = load_tokenizer(&model_root)?;
144144

145-
let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
145+
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
146146

147147
let input_batch = batch(
148148
vec![
@@ -192,7 +192,7 @@ fn test_bert_classification() -> Result<()> {
192192
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
193193
let tokenizer = load_tokenizer(&model_root)?;
194194

195-
let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
195+
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
196196

197197
let input_single = batch(
198198
vec![tokenizer

backends/candle/tests/test_flash_bert.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ fn test_flash_mini() -> Result<()> {
1919
let tokenizer = load_tokenizer(&model_root)?;
2020

2121
let backend = CandleBackend::new(
22-
model_root,
22+
&model_root,
2323
"float16".to_string(),
2424
ModelType::Embedding(Pool::Mean),
2525
)?;
@@ -83,7 +83,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
8383
let tokenizer = load_tokenizer(&model_root)?;
8484

8585
let backend = CandleBackend::new(
86-
model_root,
86+
&model_root,
8787
"float16".to_string(),
8888
ModelType::Embedding(Pool::Cls),
8989
)?;
@@ -156,7 +156,7 @@ fn test_flash_emotions() -> Result<()> {
156156
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
157157
let tokenizer = load_tokenizer(&model_root)?;
158158

159-
let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
159+
let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;
160160

161161
let input_batch = batch(
162162
vec![
@@ -210,7 +210,7 @@ fn test_flash_bert_classification() -> Result<()> {
210210
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
211211
let tokenizer = load_tokenizer(&model_root)?;
212212

213-
let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
213+
let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;
214214

215215
let input_single = batch(
216216
vec![tokenizer

backends/candle/tests/test_flash_gte.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#![allow(dead_code, unused_imports)]
22
mod common;
33

4-
use crate::common::{sort_embeddings, SnapshotEmbeddings};
4+
use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
55
use anyhow::Result;
6-
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
6+
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
77
use text_embeddings_backend_candle::CandleBackend;
88
use text_embeddings_backend_core::{Backend, ModelType, Pool};
99

@@ -15,7 +15,7 @@ fn test_flash_gte() -> Result<()> {
1515
let tokenizer = load_tokenizer(&model_root)?;
1616

1717
let backend = CandleBackend::new(
18-
model_root,
18+
&model_root,
1919
"float16".to_string(),
2020
ModelType::Embedding(Pool::Cls),
2121
)?;
@@ -51,3 +51,36 @@ fn test_flash_gte() -> Result<()> {
5151

5252
Ok(())
5353
}
54+
55+
#[test]
56+
#[serial_test::serial]
57+
#[cfg(all(
58+
feature = "cuda",
59+
any(feature = "flash-attn", feature = "flash-attn-v1")
60+
))]
61+
fn test_flash_gte_classification() -> Result<()> {
62+
let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?;
63+
let tokenizer = load_tokenizer(&model_root)?;
64+
65+
let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;
66+
67+
let input_single = batch(
68+
vec![tokenizer
69+
.encode(("What is Deep Learning?", "Deep Learning is not..."), true)
70+
.unwrap()],
71+
[0].to_vec(),
72+
vec![],
73+
);
74+
75+
let predictions: Vec<Vec<f32>> = backend
76+
.predict(input_single)?
77+
.into_iter()
78+
.map(|(_, v)| v)
79+
.collect();
80+
let predictions_single = SnapshotScores::from(predictions);
81+
82+
let matcher = relative_matcher();
83+
insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher);
84+
85+
Ok(())
86+
}

backends/candle/tests/test_flash_jina.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ fn test_flash_jina_small() -> Result<()> {
1515
let tokenizer = load_tokenizer(&model_root)?;
1616

1717
let backend = CandleBackend::new(
18-
model_root,
18+
&model_root,
1919
"float16".to_string(),
2020
ModelType::Embedding(Pool::Mean),
2121
)?;

0 commit comments

Comments
 (0)