Skip to content

Commit e496fe7

Browse files
feat: add /similarity route (#331)
1 parent 7b9245d commit e496fe7

File tree

5 files changed

+321
-2
lines changed

5 files changed

+321
-2
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/openapi.json

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,93 @@
565565
}
566566
}
567567
},
568+
"/similarity": {
569+
"post": {
570+
"tags": [
571+
"Text Embeddings Inference"
572+
],
573+
"summary": "Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.",
574+
"operationId": "similarity",
575+
"requestBody": {
576+
"content": {
577+
"application/json": {
578+
"schema": {
579+
"$ref": "#/components/schemas/SimilarityRequest"
580+
}
581+
}
582+
},
583+
"required": true
584+
},
585+
"responses": {
586+
"200": {
587+
"description": "Sentence Similarity",
588+
"content": {
589+
"application/json": {
590+
"schema": {
591+
"$ref": "#/components/schemas/SimilarityResponse"
592+
}
593+
}
594+
}
595+
},
596+
"413": {
597+
"description": "Batch size error",
598+
"content": {
599+
"application/json": {
600+
"schema": {
601+
"$ref": "#/components/schemas/ErrorResponse"
602+
},
603+
"example": {
604+
"error": "Batch size error",
605+
"error_type": "validation"
606+
}
607+
}
608+
}
609+
},
610+
"422": {
611+
"description": "Tokenization error",
612+
"content": {
613+
"application/json": {
614+
"schema": {
615+
"$ref": "#/components/schemas/ErrorResponse"
616+
},
617+
"example": {
618+
"error": "Tokenization error",
619+
"error_type": "tokenizer"
620+
}
621+
}
622+
}
623+
},
624+
"424": {
625+
"description": "Embedding Error",
626+
"content": {
627+
"application/json": {
628+
"schema": {
629+
"$ref": "#/components/schemas/ErrorResponse"
630+
},
631+
"example": {
632+
"error": "Inference failed",
633+
"error_type": "backend"
634+
}
635+
}
636+
}
637+
},
638+
"429": {
639+
"description": "Model is overloaded",
640+
"content": {
641+
"application/json": {
642+
"schema": {
643+
"$ref": "#/components/schemas/ErrorResponse"
644+
},
645+
"example": {
646+
"error": "Model is overloaded",
647+
"error_type": "overloaded"
648+
}
649+
}
650+
}
651+
}
652+
}
653+
}
654+
},
568655
"/tokenize": {
569656
"post": {
570657
"tags": [
@@ -1441,6 +1528,91 @@
14411528
"$ref": "#/components/schemas/Rank"
14421529
}
14431530
},
1531+
"SimilarityInput": {
1532+
"type": "object",
1533+
"required": [
1534+
"source_sentence",
1535+
"sentences"
1536+
],
1537+
"properties": {
1538+
"sentences": {
1539+
"type": "array",
1540+
"items": {
1541+
"type": "string"
1542+
},
1543+
"description": "A list of strings which will be compared against the source_sentence.",
1544+
"example": [
1545+
"What is Machine Learning?"
1546+
]
1547+
},
1548+
"source_sentence": {
1549+
"type": "string",
1550+
"description": "The string that you wish to compare the other strings with. This can be a phrase, sentence,\nor longer passage, depending on the model being used.",
1551+
"example": "What is Deep Learning?"
1552+
}
1553+
}
1554+
},
1555+
"SimilarityParameters": {
1556+
"type": "object",
1557+
"required": [
1558+
"truncation_direction"
1559+
],
1560+
"properties": {
1561+
"prompt_name": {
1562+
"type": "string",
1563+
"description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
1564+
"default": "null",
1565+
"example": "null",
1566+
"nullable": true
1567+
},
1568+
"truncate": {
1569+
"type": "boolean",
1570+
"default": "false",
1571+
"example": "false",
1572+
"nullable": true
1573+
},
1574+
"truncation_direction": {
1575+
"allOf": [
1576+
{
1577+
"$ref": "#/components/schemas/TruncationDirection"
1578+
}
1579+
],
1580+
"default": "right"
1581+
}
1582+
}
1583+
},
1584+
"SimilarityRequest": {
1585+
"type": "object",
1586+
"required": [
1587+
"inputs"
1588+
],
1589+
"properties": {
1590+
"inputs": {
1591+
"$ref": "#/components/schemas/SimilarityInput"
1592+
},
1593+
"parameters": {
1594+
"allOf": [
1595+
{
1596+
"$ref": "#/components/schemas/SimilarityParameters"
1597+
}
1598+
],
1599+
"default": "null",
1600+
"nullable": true
1601+
}
1602+
}
1603+
},
1604+
"SimilarityResponse": {
1605+
"type": "array",
1606+
"items": {
1607+
"type": "number",
1608+
"format": "float"
1609+
},
1610+
"example": [
1611+
0.0,
1612+
1.0,
1613+
0.5
1614+
]
1615+
},
14441616
"SimpleToken": {
14451617
"type": "object",
14461618
"required": [

router/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ opentelemetry = "0.23.0"
3030
opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio"] }
3131
opentelemetry-otlp = "0.16.0"
3232
reqwest = { version = "0.12.5", features = [] }
33+
simsimd = "4.4.0"
3334
serde = { workspace = true }
3435
serde_json = { workspace = true }
3536
thiserror = { workspace = true }

router/src/http/server.rs

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use crate::http::types::{
44
EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType,
55
OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse,
66
OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank,
7-
RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput,
7+
RerankRequest, RerankResponse, Sequence, SimilarityInput, SimilarityParameters,
8+
SimilarityRequest, SimilarityResponse, SimpleToken, SparseValue, TokenizeInput,
89
TokenizeRequest, TokenizeResponse, TruncationDirection, VertexPrediction, VertexRequest,
910
VertexResponse,
1011
};
@@ -26,6 +27,7 @@ use futures::future::join_all;
2627
use futures::FutureExt;
2728
use http::header::AUTHORIZATION;
2829
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
30+
use simsimd::SpatialSimilarity;
2931
use std::net::SocketAddr;
3032
use std::time::{Duration, Instant};
3133
use text_embeddings_backend::BackendError;
@@ -455,6 +457,88 @@ async fn rerank(
455457
Ok((headers, Json(response)))
456458
}
457459

460+
/// Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.
461+
#[utoipa::path(
462+
post,
463+
tag = "Text Embeddings Inference",
464+
path = "/similarity",
465+
request_body = SimilarityRequest,
466+
responses(
467+
(status = 200, description = "Sentence Similarity", body = SimilarityResponse),
468+
(status = 424, description = "Embedding Error", body = ErrorResponse,
469+
example = json ! ({"error": "Inference failed", "error_type": "backend"})),
470+
(status = 429, description = "Model is overloaded", body = ErrorResponse,
471+
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
472+
(status = 422, description = "Tokenization error", body = ErrorResponse,
473+
example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})),
474+
(status = 413, description = "Batch size error", body = ErrorResponse,
475+
example = json ! ({"error": "Batch size error", "error_type": "validation"})),
476+
)
477+
)]
478+
#[instrument(
479+
skip_all,
480+
fields(total_time, tokenization_time, queue_time, inference_time,)
481+
)]
482+
async fn similarity(
483+
infer: Extension<Infer>,
484+
info: Extension<Info>,
485+
Json(req): Json<SimilarityRequest>,
486+
) -> Result<(HeaderMap, Json<SimilarityResponse>), (StatusCode, Json<ErrorResponse>)> {
487+
if req.inputs.sentences.is_empty() {
488+
let message = "`inputs.sentences` cannot be empty".to_string();
489+
tracing::error!("{message}");
490+
let err = ErrorResponse {
491+
error: message,
492+
error_type: ErrorType::Validation,
493+
};
494+
let counter = metrics::counter!("te_request_failure", "err" => "validation");
495+
counter.increment(1);
496+
Err(err)?;
497+
}
498+
// +1 because of the source sentence
499+
let batch_size = req.inputs.sentences.len() + 1;
500+
if batch_size > info.max_client_batch_size {
501+
let message = format!(
502+
"batch size {batch_size} > maximum allowed batch size {}",
503+
info.max_client_batch_size
504+
);
505+
tracing::error!("{message}");
506+
let err = ErrorResponse {
507+
error: message,
508+
error_type: ErrorType::Validation,
509+
};
510+
let counter = metrics::counter!("te_request_failure", "err" => "batch_size");
511+
counter.increment(1);
512+
Err(err)?;
513+
}
514+
515+
// Convert request to embed request
516+
let mut inputs = Vec::with_capacity(req.inputs.sentences.len() + 1);
517+
inputs.push(InputType::String(req.inputs.source_sentence));
518+
for s in req.inputs.sentences {
519+
inputs.push(InputType::String(s));
520+
}
521+
let parameters = req.parameters.unwrap_or_default();
522+
let embed_req = EmbedRequest {
523+
inputs: Input::Batch(inputs),
524+
truncate: parameters.truncate,
525+
truncation_direction: parameters.truncation_direction,
526+
prompt_name: parameters.prompt_name,
527+
normalize: false,
528+
};
529+
530+
// Get embeddings
531+
let (header_map, embed_response) = embed(infer, info, Json(embed_req)).await?;
532+
let embeddings = embed_response.0 .0;
533+
534+
// Compute cosine
535+
let distances = (1..batch_size)
536+
.map(|i| 1.0 - f32::cosine(&embeddings[0], &embeddings[i]).unwrap() as f32)
537+
.collect();
538+
539+
Ok((header_map, Json(SimilarityResponse(distances))))
540+
}
541+
458542
/// Get Embeddings. Returns a 424 status code if the model is not an embedding model.
459543
#[utoipa::path(
460544
post,
@@ -1472,6 +1556,7 @@ pub async fn run(
14721556
embed_all,
14731557
embed_sparse,
14741558
openai_embed,
1559+
similarity,
14751560
tokenize,
14761561
decode,
14771562
metrics,
@@ -1509,6 +1594,10 @@ pub async fn run(
15091594
TokenizeRequest,
15101595
TokenizeResponse,
15111596
TruncationDirection,
1597+
SimilarityInput,
1598+
SimilarityParameters,
1599+
SimilarityRequest,
1600+
SimilarityResponse,
15121601
SimpleToken,
15131602
InputType,
15141603
InputIds,
@@ -1587,6 +1676,7 @@ pub async fn run(
15871676
.route("/embed_sparse", post(embed_sparse))
15881677
.route("/predict", post(predict))
15891678
.route("/rerank", post(rerank))
1679+
.route("/similarity", post(similarity))
15901680
.route("/tokenize", post(tokenize))
15911681
.route("/decode", post(decode))
15921682
// OpenAI compat route
@@ -1634,7 +1724,11 @@ pub async fn run(
16341724
.route("/invocations", post(rerank))
16351725
}
16361726
ModelType::Embedding(model) => {
1637-
if model.pooling == "splade" {
1727+
if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) {
1728+
app.route("/", post(similarity))
1729+
// AWS Sagemaker route
1730+
.route("/invocations", post(similarity))
1731+
} else if model.pooling == "splade" {
16381732
app.route("/", post(embed_sparse))
16391733
// AWS Sagemaker route
16401734
.route("/invocations", post(embed_sparse))

0 commit comments

Comments
 (0)