Commit 4048136

feat: support roberta (#62)
1 parent 944f5c7 commit 4048136

File tree

3 files changed (+14, -12)

README.md

Lines changed: 5 additions & 6 deletions
````diff
@@ -53,8 +53,7 @@ such as:
 
 ### Supported Models
 
-You can use any JinaBERT model with Alibi or absolute positions or any BERT, CamemBERT or XLM-RoBERTa model with
-absolute positions in `text-embeddings-inference`.
+You can use any JinaBERT model with Alibi or absolute positions or any BERT, CamemBERT, RoBERTa, or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`.
 
 **Support for other model types will be added in the future.**
 
@@ -96,8 +95,8 @@ curl 127.0.0.1:8080/embed \
     -H 'Content-Type: application/json'
 ```
 
-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
-We also recommend using NVIDIA drivers with CUDA version 12.0 or higher.
+**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+We also recommend using NVIDIA drivers with CUDA version 12.0 or higher.
 
 To see all options to serve your models:
 
@@ -236,7 +235,7 @@ Text Embeddings Inference ships with multiple Docker images that you can use to
 | Ada Lovelace (RTX 4000 series, ...) | ghcr.io/huggingface/text-embeddings-inference:89-0.3.0 |
 | Hopper (H100) | ghcr.io/huggingface/text-embeddings-inference:hopper-0.3.0 (experimental) |
 
-**Warning**: Flash Attention is turned off by default for the Turing image as it suffers from precision issues.
+**Warning**: Flash Attention is turned off by default for the Turing image as it suffers from precision issues.
 You can turn Flash Attention v1 ON by using the `USE_FLASH_ATTENTION=True` environment variable.
 
 ### API documentation
@@ -329,7 +328,7 @@ cargo install --path router -F candle-cuda-turing --no-default-features
 cargo install --path router -F candle-cuda --no-default-features
 ```
 
-You can now launch Text Embeddings Inference on GPU with:
+You can now launch Text Embeddings Inference on GPU with:
 
 ```shell
 model=BAAI/bge-large-en-v1.5
````

backends/candle/src/lib.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -41,6 +41,7 @@ impl CandleBackend {
         if config.model_type != Some("bert".to_string())
             && config.model_type != Some("xlm-roberta".to_string())
             && config.model_type != Some("camembert".to_string())
+            && config.model_type != Some("roberta".to_string())
         {
             return Err(BackendError::Start(format!(
                 "Model {:?} is not supported",
```

router/src/main.rs

Lines changed: 8 additions & 6 deletions
```diff
@@ -215,12 +215,14 @@ async fn main() -> Result<()> {
     tokenizer.with_padding(None);
 
     // Position IDs offset. Used for Roberta and camembert.
-    let position_offset =
-        if &config.model_type == "xlm-roberta" || &config.model_type == "camembert" {
-            config.pad_token_id + 1
-        } else {
-            0
-        };
+    let position_offset = if &config.model_type == "xlm-roberta"
+        || &config.model_type == "camembert"
+        || &config.model_type == "roberta"
+    {
+        config.pad_token_id + 1
+    } else {
+        0
+    };
     let max_input_length = config.max_position_embeddings - position_offset;
 
     let tokenization_workers = args
```
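The offset exists because RoBERTa-family checkpoints reserve the lowest position IDs for padding: in the usual Hugging Face configuration `pad_token_id` is 1, so real token positions start at `pad_token_id + 1 = 2`, which is also why such models ship with `max_position_embeddings = 514` yet accept only 512 input tokens. A self-contained sketch of that arithmetic, assuming typical RoBERTa values; the `Config` struct and `position_offset` function here are illustrative, not the router's actual types:

```rust
// Hypothetical sketch of the position-offset arithmetic; names are illustrative.
struct Config {
    model_type: String,
    pad_token_id: usize,
    max_position_embeddings: usize,
}

/// RoBERTa-family models reserve positions 0..=pad_token_id, so real token
/// positions start at pad_token_id + 1; BERT-style models start at 0.
fn position_offset(config: &Config) -> usize {
    match config.model_type.as_str() {
        "xlm-roberta" | "camembert" | "roberta" => config.pad_token_id + 1,
        _ => 0,
    }
}

fn main() {
    // Typical RoBERTa values: pad_token_id = 1, max_position_embeddings = 514.
    let config = Config {
        model_type: "roberta".to_string(),
        pad_token_id: 1,
        max_position_embeddings: 514,
    };
    let offset = position_offset(&config);
    let max_input_length = config.max_position_embeddings - offset;
    assert_eq!(max_input_length, 512); // effective input budget for the model
}
```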
