Commit 8e85e9c

feat: rerank route (#84)
1 parent 36161c1 commit 8e85e9c

10 files changed: 467 additions and 33 deletions

README.md

Lines changed: 11 additions & 5 deletions

@@ -33,6 +33,7 @@ length of 512 tokens:
 - [Docker Images](#docker-images)
 - [API Documentation](#api-documentation)
 - [Using a private or gated model](#using-a-private-or-gated-model)
+- [Using Re-rankers models](#using-re-rankers-models)
 - [Using Sequence Classification models](#using-sequence-classification-models)
 - [Distributed Tracing](#distributed-tracing)
 - [Local Install](#local-install)
@@ -281,11 +282,14 @@ token=<your cli READ token>
 docker run --gpus all -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.4.0 --model-id $model
 ```

-### Using Sequence Classification models
+### Using Re-rankers models

 `text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models.
+Re-rankers models are Sequence Classification cross-encoders models with a single class that scores the similarity
+between a query and a passage.
+
 See [this blogpost](https://blog.llamaindex.ai/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83) by
-the LlamaIndex team to understand how you can use Sequence Classification models in your RAG pipeline to improve
+the LlamaIndex team to understand how you can use re-rankers models in your RAG pipeline to improve
 downstream performance.

 ```shell
@@ -296,15 +300,17 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.4.0 --model-id $model --revision $revision
 ```

-And then you can rank the similarity between a pair of inputs with:
+And then you can rank the similarity between a query and a list of passages with:

 ```bash
-curl 127.0.0.1:8080/predict \
+curl 127.0.0.1:8080/rerank \
     -X POST \
-    -d '{"inputs":["What is Deep Learning?", "Deep learning is..."], "raw_scores": true}' \
+    -d '{"query":"What is Deep Learning?", "passages": ["Deep Learning is not...", "Deep learning is..."]}' \
     -H 'Content-Type: application/json'
 ```

+### Using Sequence Classification models
+
 You can also use classic Sequence Classification models like `SamLowe/roberta-base-go_emotions`:

 ```shell
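
As a hedged illustration (not part of the commit): the `RerankRequest` schema added in `docs/openapi.json` below also documents optional `raw_scores`, `return_passages` and `truncate` flags, and the response is documented as a JSON array of `Rank` objects carrying an `index`, a `score` and an optional `passage`. A sketch of how that could be exercised, with invented scores and no guarantee about the ordering of the returned array:

```bash
# Sketch only: ask the new /rerank route to echo the passages back, using the
# return_passages flag from the RerankRequest schema added by this commit.
curl 127.0.0.1:8080/rerank \
    -X POST \
    -d '{"query":"What is Deep Learning?", "passages": ["Deep Learning is not...", "Deep learning is..."], "return_passages": true}' \
    -H 'Content-Type: application/json'
# The documented response shape is an array of Rank objects, for example (values invented):
# [{"index":0,"score":0.02,"passage":"Deep Learning is not..."},
#  {"index":1,"score":0.99,"passage":"Deep learning is..."}]
```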

backends/core/src/lib.rs

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ pub enum BackendError {
     NoBackend,
     #[error("Could not start backend: {0}")]
     Start(String),
-    #[error("Inference error: {0}")]
+    #[error("{0}")]
     Inference(String),
     #[error("Backend is unhealthy")]
     Unhealthy,

backends/src/lib.rs

Lines changed: 9 additions & 6 deletions

@@ -3,6 +3,7 @@ mod dtype;
 use std::path::PathBuf;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
+use std::time::{Duration, Instant};
 use text_embeddings_backend_core::Backend as CoreBackend;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
@@ -91,7 +92,7 @@ impl Backend {
     }

     #[instrument(skip_all)]
-    pub async fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
+    pub async fn embed(&self, batch: Batch) -> Result<(Vec<Embedding>, Duration), BackendError> {
         let (sender, receiver) = oneshot::channel();

         self.backend_sender
@@ -107,7 +108,7 @@ impl Backend {
     }

     #[instrument(skip_all)]
-    pub async fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError> {
+    pub async fn predict(&self, batch: Batch) -> Result<(Vec<Vec<f32>>, Duration), BackendError> {
         let (sender, receiver) = oneshot::channel();

         self.backend_sender
@@ -166,18 +167,19 @@ fn backend_blocking_task(
     command_receiver: flume::Receiver<BackendCommand>,
 ) {
     while let Ok(cmd) = command_receiver.recv() {
+        let start = Instant::now();
         match cmd {
             BackendCommand::Health(span, sender) => {
                 let _span = span.entered();
                 let _ = sender.send(backend.health());
             }
             BackendCommand::Embed(batch, span, sender) => {
                 let _span = span.entered();
-                let _ = sender.send(backend.embed(batch));
+                let _ = sender.send(backend.embed(batch).map(|e| (e, start.elapsed())));
             }
             BackendCommand::Predict(batch, span, sender) => {
                 let _span = span.entered();
-                let _ = sender.send(backend.predict(batch));
+                let _ = sender.send(backend.predict(batch).map(|e| (e, start.elapsed())));
             }
         }
     }
@@ -188,11 +190,12 @@ enum BackendCommand {
     Embed(
         Batch,
         Span,
-        oneshot::Sender<Result<Vec<Embedding>, BackendError>>,
+        oneshot::Sender<Result<(Vec<Embedding>, Duration), BackendError>>,
     ),
     Predict(
         Batch,
         Span,
-        oneshot::Sender<Result<Vec<Vec<f32>>, BackendError>>,
+        #[allow(clippy::type_complexity)]
+        oneshot::Sender<Result<(Vec<Vec<f32>>, Duration), BackendError>>,
     ),
 }

core/src/infer.rs

Lines changed: 5 additions & 7 deletions

@@ -90,7 +90,7 @@
     ) -> Result<InferResponse, TextEmbeddingsError> {
         if self.is_classifier() {
             metrics::increment_counter!("te_request_failure", "err" => "model_type");
-            let message = "model is not an embedding model".to_string();
+            let message = "Model is not an embedding model".to_string();
             tracing::error!("{message}");
             return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                 message,
@@ -185,8 +185,7 @@
     ) -> Result<InferResponse, TextEmbeddingsError> {
         if !self.is_classifier() {
             metrics::increment_counter!("te_request_failure", "err" => "model_type");
-            let message = "model is not a classifier model".to_string();
-            // tracing::error!("{message}");
+            let message = "Model is not a classifier model".to_string();
             return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                 message,
             )));
@@ -313,22 +312,21 @@ async fn backend_task(
     mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>,
 ) {
     while let Some((batch, _callback)) = embed_receiver.recv().await {
-        let inference_start = Instant::now();
         let results = match &backend.model_type {
             ModelType::Classifier => backend.predict(batch.1).await,
             ModelType::Embedding(_) => backend.embed(batch.1).await,
         };

         // Handle sending responses in another thread to avoid starving the backend
         tokio::task::spawn_blocking(move || match results {
-            Ok(embeddings) => {
+            Ok((embeddings, inference_duration)) => {
                 batch.0.into_iter().zip(embeddings).for_each(|(m, e)| {
                     let _ = m.response_tx.send(Ok(InferResponse {
                         results: e,
                         prompt_tokens: m.prompt_tokens,
                         tokenization: m.tokenization,
-                        queue: inference_start - m.queue_time,
-                        inference: inference_start.elapsed(),
+                        queue: m.queue_time.elapsed() - inference_duration,
+                        inference: inference_duration,
                     }));
                 });
             }

core/src/tokenization.rs

Lines changed: 6 additions & 0 deletions

@@ -180,6 +180,12 @@ impl From<String> for EncodingInput {
     }
 }

+impl From<(String, String)> for EncodingInput {
+    fn from(value: (String, String)) -> Self {
+        Self::Dual(value.0, value.1)
+    }
+}
+
 type TokenizerRequest = (
     EncodingInput,
     bool,

docs/openapi.json

Lines changed: 156 additions & 0 deletions

@@ -348,6 +348,94 @@
           }
         }
       }
+    },
+    "/rerank": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with",
+        "description": "Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with\na single class.",
+        "operationId": "rerank",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/RerankRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Ranks",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/RerankResponse"
+                }
+              }
+            }
+          },
+          "413": {
+            "description": "Batch size error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Rerank Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          }
+        }
+      }
     }
   },
   "components": {
@@ -797,6 +885,74 @@
             "example": "0.5"
           }
         }
+      },
+      "Rank": {
+        "type": "object",
+        "required": [
+          "index",
+          "score"
+        ],
+        "properties": {
+          "index": {
+            "type": "integer",
+            "example": "0",
+            "minimum": 0
+          },
+          "passage": {
+            "type": "string",
+            "default": "null",
+            "example": "Deep Learning is ...",
+            "nullable": true
+          },
+          "score": {
+            "type": "number",
+            "format": "float",
+            "example": "1.0"
+          }
+        }
+      },
+      "RerankRequest": {
+        "type": "object",
+        "required": [
+          "query",
+          "passages"
+        ],
+        "properties": {
+          "passages": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            },
+            "example": [
+              "Deep Learning is ..."
+            ]
+          },
+          "query": {
+            "type": "string",
+            "example": "What is Deep Learning?"
+          },
+          "raw_scores": {
+            "type": "boolean",
+            "default": "false",
+            "example": "false"
+          },
+          "return_passages": {
+            "type": "boolean",
+            "default": "false",
+            "example": "false"
+          },
+          "truncate": {
+            "type": "boolean",
+            "default": "false",
+            "example": "false"
+          }
+        }
+      },
+      "RerankResponse": {
+        "type": "array",
+        "items": {
+          "$ref": "#/components/schemas/Rank"
+        }
       }
     }
   },
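
The new `/rerank` path above documents a 424 response when the served model is not a single-class Sequence Classification model. A hedged sketch of probing that failure mode from the command line; the status code and body below are taken from the OpenAPI example and from the message string set in `core/src/infer.rs`, not from an actual run:

```bash
# Sketch only: calling /rerank while an embedding model is deployed.
# -i prints the response status line together with the body.
curl -i 127.0.0.1:8080/rerank \
    -X POST \
    -d '{"query":"What is Deep Learning?", "passages": ["Deep learning is..."]}' \
    -H 'Content-Type: application/json'
# Documented failure shape: HTTP 424 with an ErrorResponse body such as
# {"error": "Inference failed", "error_type": "backend"}
# (core/src/infer.rs sets the more specific message "Model is not a classifier model").
```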

docs/source/en/quick_tour.md

Lines changed: 11 additions & 6 deletions

@@ -53,11 +53,13 @@ curl 127.0.0.1:8080/embed \
     -H 'Content-Type: application/json'
 ```

-## Sequence Classification
+## Re-rankers
+
+Re-rankers models are Sequence Classification cross-encoders models with a single class that scores the similarity
+between a query and a passage.

-TEI can also be used to deploy Sequence Classification models.
 See [this blogpost](https://blog.llamaindex.ai/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83) by
-the LlamaIndex team to understand how you can use Sequence Classification models in your RAG pipeline to improve
+the LlamaIndex team to understand how you can use re-rankers models in your RAG pipeline to improve
 downstream performance.

 Let's say you want to use `BAAI/bge-reranker-large`:
@@ -70,15 +72,18 @@ volume=$PWD/data
 docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.4.0 --model-id $model --revision $revision
 ```

-Once you have deployed a model you can use the `predict` endpoint and rank the similarity between a pair of inputs:
+Once you have deployed a model you can use the `rerank` endpoint to rank the similarity between a query and a list
+of passages:

 ```bash
-curl 127.0.0.1:8080/predict \
+curl 127.0.0.1:8080/rerank \
     -X POST \
-    -d '{"inputs":["What is Deep Learning?", "Deep learning is..."], "raw_scores": true}' \
+    -d '{"query":"What is Deep Learning?", "passages": ["Deep Learning is not...", "Deep learning is..."], "raw_scores": false}' \
     -H 'Content-Type: application/json'
 ```

+## Sequence Classification
+
 You can also use classic Sequence Classification models like `SamLowe/roberta-base-go_emotions`:

 ```shell
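
A small follow-on sketch, not from the commit and assuming `jq` is installed: because `/rerank` answers with an array of `Rank` objects, the best passage can be selected by its `score` field without relying on any particular ordering of the array:

```bash
# Print the index of the highest-scoring passage from the /rerank response.
# max_by(.score) picks the Rank object with the largest score; .index extracts its position.
curl -s 127.0.0.1:8080/rerank \
    -X POST \
    -d '{"query":"What is Deep Learning?", "passages": ["Deep Learning is not...", "Deep learning is..."]}' \
    -H 'Content-Type: application/json' \
    | jq 'max_by(.score).index'
```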

docs/source/en/supported_models.md

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ Below are some examples of the currently supported models:
 To explore the list of best performing text embeddings models, visit the
 [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).

-## Supported sequence classification models
+## Supported re-rankers and sequence classification models

 Text Embeddings Inference currently supports CamemBERT, and XLM-RoBERTa Sequence Classification models with absolute positions.
