Skip to content

Commit f9c86b3

Browse files
Merge pull request #5 from huggingface/roberta
feat: add support for XLM-RoBERTa
2 parents b718452 + 5c06e62 commit f9c86b3

File tree

4 files changed

+51
-18
lines changed

4 files changed

+51
-18
lines changed

README.md

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,25 @@ Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1
4848

4949
### Supported Models
5050

51-
You can use any BERT model with absolute positions in `text-embeddings-inference`. If the model does not have `safetensors` weights
52-
you can convert it using [this space](https://huggingface.co/spaces/safetensors/convert).
51+
You can use any BERT or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`.
52+
If the model does not have `safetensors` weights you can convert it using [this space](https://huggingface.co/spaces/safetensors/convert).
5353

5454
**Support for other model types will be added in the future.**
5555

56-
| MTEB Rank | Model Type | Model ID | Specific Revision |
57-
|-----------|------------|------------------------|--------------------------------------------------------------------------|
58-
| 1 | Bert | BAAI/bge-large-en-v1.5 | [refs/pr/5](https://huggingface.co/BAAI/bge-large-en-v1.5/discussions/5) |
59-
| 2 | | BAAI/bge-base-en-v1.5 | [refs/pr/1](https://huggingface.co/BAAI/bge-base-en-v1.5/discussions/1) |
60-
| 3 | | llmrails/ember-v1 | |
61-
| 4 | | thenlper/gte-large | |
62-
| 5 | | thenlper/gte-base | |
63-
| 6 | | intfloat/e5-large-v2 | |
64-
| 7 | | BAAI/bge-small-en-v1.5 | [refs/pr/3](https://huggingface.co/BAAI/bge-small-en-v1.5/discussions/3) |
65-
| 10 | | intfloat/e5-base-v2 | |
56+
Examples of supported models:
57+
58+
| MTEB Rank | Model Type | Model ID | Specific Revision |
59+
|-----------|--------------|--------------------------------|--------------------------------------------------------------------------|
60+
| 1 | Bert | BAAI/bge-large-en-v1.5 | [refs/pr/5](https://huggingface.co/BAAI/bge-large-en-v1.5/discussions/5) |
61+
| 2 | | BAAI/bge-base-en-v1.5 | [refs/pr/1](https://huggingface.co/BAAI/bge-base-en-v1.5/discussions/1) |
62+
| 3 | | llmrails/ember-v1 | |
63+
| 4 | | thenlper/gte-large | |
64+
| 5 | | thenlper/gte-base | |
65+
| 6 | | intfloat/e5-large-v2 | |
66+
| 7 | | BAAI/bge-small-en-v1.5 | [refs/pr/3](https://huggingface.co/BAAI/bge-small-en-v1.5/discussions/3) |
67+
| 10 | | intfloat/e5-base-v2 | |
68+
| 11 | XLM-RoBERTa | intfloat/multilingual-e5-large | |
69+
6670

6771
You can explore the list of best performing text embeddings models [here](https://huggingface.co/spaces/mteb/leaderboard).
6872

backends/candle/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ impl CandleBackend {
3939
};
4040

4141
// Check model type
42-
if config.model_type != Some("bert".to_string()) {
42+
if config.model_type != Some("bert".to_string())
43+
&& config.model_type != Some("xlm-roberta".to_string())
44+
{
4345
return Err(BackendError::Start(format!(
4446
"Model {:?} is not supported",
4547
config.model_type

core/src/tokenization.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ pub struct Tokenization {
1313
}
1414

1515
impl Tokenization {
16-
pub fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
16+
pub fn new(
17+
workers: usize,
18+
tokenizer: Tokenizer,
19+
max_input_length: usize,
20+
position_offset: usize,
21+
) -> Self {
1722
// Create channel
1823
let (sender, receiver) = flume::unbounded();
1924

@@ -24,7 +29,12 @@ impl Tokenization {
2429

2530
// Spawn worker
2631
tokio::task::spawn_blocking(move || {
27-
tokenizer_worker(tokenizer_clone, max_input_length, receiver_clone)
32+
tokenizer_worker(
33+
tokenizer_clone,
34+
max_input_length,
35+
position_offset,
36+
receiver_clone,
37+
)
2838
});
2939
}
3040

@@ -66,6 +76,7 @@ impl Tokenization {
6676
fn tokenizer_worker(
6777
tokenizer: Tokenizer,
6878
max_input_length: usize,
79+
position_offset: usize,
6980
receiver: flume::Receiver<TokenizerRequest>,
7081
) {
7182
// Loop over requests
@@ -74,8 +85,13 @@ fn tokenizer_worker(
7485
if !response_tx.is_closed() {
7586
// It's possible that the user dropped its request resulting in a send error.
7687
// We just discard the error
77-
let _ =
78-
response_tx.send(encode_input(inputs, truncate, max_input_length, &tokenizer));
88+
let _ = response_tx.send(encode_input(
89+
inputs,
90+
truncate,
91+
max_input_length,
92+
position_offset,
93+
&tokenizer,
94+
));
7995
}
8096
})
8197
}
@@ -86,6 +102,7 @@ fn encode_input(
86102
inputs: String,
87103
truncate: bool,
88104
max_input_length: usize,
105+
position_offset: usize,
89106
tokenizer: &Tokenizer,
90107
) -> Result<Encoding, TextEmbeddingsError> {
91108
// Get the number of tokens in the input
@@ -109,7 +126,8 @@ fn encode_input(
109126
Ok(Encoding {
110127
input_ids: encoding.get_ids().to_vec(),
111128
token_type_ids: encoding.get_type_ids().to_vec(),
112-
position_ids: (0..seq_len as u32).collect::<Vec<_>>(),
129+
position_ids: (position_offset as u32..(seq_len + position_offset) as u32)
130+
.collect::<Vec<_>>(),
113131
})
114132
}
115133

router/src/main.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ pub struct ModelConfig {
116116
pub model_type: String,
117117
#[serde(alias = "n_positions")]
118118
pub max_position_embeddings: usize,
119+
pub pad_token_id: usize,
119120
}
120121

121122
#[tokio::main]
@@ -167,11 +168,19 @@ async fn main() -> Result<()> {
167168
);
168169
tokenizer.with_padding(None);
169170

171+
// Position IDs offset. Used for Roberta.
172+
let position_offset = if config.pad_token_id == 0 {
173+
0
174+
} else {
175+
config.pad_token_id + 1
176+
};
177+
170178
// Tokenization logic
171179
let tokenization = Tokenization::new(
172180
args.tokenization_workers,
173181
tokenizer,
174182
config.max_position_embeddings,
183+
position_offset,
175184
);
176185

177186
// Create backend

0 commit comments

Comments
 (0)