
Support MRL (Matryoshka Representation Learning) #676


Open · wants to merge 6 commits into main

Changes from 5 commits
57 changes: 33 additions & 24 deletions core/src/infer.rs
@@ -151,20 +151,16 @@ impl Infer {
panic!("unexpected enum variant")
};

// Timings
let total_time = start_time.elapsed();

// Metrics
let counter = metrics::counter!("te_embed_success");
counter.increment(1);
let histogram = metrics::histogram!("te_embed_duration");
histogram.record(total_time.as_secs_f64());
let histogram = metrics::histogram!("te_embed_tokenization_duration");
histogram.record(response.metadata.tokenization.as_secs_f64());
let histogram = metrics::histogram!("te_embed_queue_duration");
histogram.record(response.metadata.queue.as_secs_f64());
let histogram = metrics::histogram!("te_embed_inference_duration");
histogram.record(response.metadata.inference.as_secs_f64());
metrics::counter!("te_embed_success").increment(1);
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64());
metrics::histogram!("te_embed_tokenization_duration")
.record(response.metadata.tokenization.as_secs_f64());
metrics::histogram!("te_embed_queue_duration")
.record(response.metadata.queue.as_secs_f64());
metrics::histogram!("te_embed_inference_duration")
.record(response.metadata.inference.as_secs_f64());

Ok(response)
}
@@ -224,6 +220,7 @@ impl Infer {
Ok(response)
}

#[allow(clippy::too_many_arguments)]
#[instrument(skip(self, inputs, permit))]
pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
&self,
@@ -232,20 +229,29 @@
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
normalize: bool,
dimensions: Option<usize>,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();

if self.is_splade() && normalize {
let counter = metrics::counter!("te_request_failure", "err" => "model_type");
counter.increment(1);
metrics::counter!("te_request_failure", "err" => "model_type").increment(1);
let message = "`normalize` is not available for SPLADE models".to_string();
tracing::error!("{message}");
return Err(TextEmbeddingsError::Backend(BackendError::Inference(
message,
)));
}

if let Some(dimensions) = dimensions {
if dimensions == 0 {
metrics::counter!("te_request_failure", "err" => "validation").increment(1);
let message = "`dimensions` should be positive".to_string();
tracing::error!("{message}");
return Err(TextEmbeddingsError::Validation(message));
}
}

let results = self
.embed(
inputs,
@@ -262,6 +268,11 @@
panic!("unexpected enum variant")
};

if let Some(mrl_dimensions) = dimensions {
let mrl_dimensions = mrl_dimensions.min(response.results.len());
response.results.truncate(mrl_dimensions);
}

if normalize {
// Normalize embedding
let scale = (1.0
@@ -283,16 +294,14 @@
let total_time = start_time.elapsed();

// Metrics
let counter = metrics::counter!("te_embed_success");
counter.increment(1);
let histogram = metrics::histogram!("te_embed_duration");
histogram.record(total_time.as_secs_f64());
let histogram = metrics::histogram!("te_embed_tokenization_duration");
histogram.record(response.metadata.tokenization.as_secs_f64());
let histogram = metrics::histogram!("te_embed_queue_duration");
histogram.record(response.metadata.queue.as_secs_f64());
let histogram = metrics::histogram!("te_embed_inference_duration");
histogram.record(response.metadata.inference.as_secs_f64());
metrics::counter!("te_embed_success").increment(1);
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64());
metrics::histogram!("te_embed_tokenization_duration")
.record(response.metadata.tokenization.as_secs_f64());
metrics::histogram!("te_embed_queue_duration")
.record(response.metadata.queue.as_secs_f64());
metrics::histogram!("te_embed_inference_duration")
.record(response.metadata.inference.as_secs_f64());

Ok(response)
}
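For intuition, here is a minimal, self-contained sketch of the Matryoshka step that `embed_pooled` now performs: the pooled embedding is truncated to the requested number of dimensions, capped at the model's native size, and then re-normalized so the shortened vector keeps unit L2 norm. This helper and the standalone setting are illustrative, not code from this PR.

/// Illustrative helper mirroring the new `dimensions` handling in `embed_pooled`:
/// truncate the pooled embedding, then re-normalize it (when requested).
fn truncate_and_normalize(mut embedding: Vec<f32>, dimensions: Option<usize>, normalize: bool) -> Vec<f32> {
    if let Some(dims) = dimensions {
        // Never keep more values than the model actually produced.
        embedding.truncate(dims.min(embedding.len()));
    }
    if normalize {
        // Re-scale so the truncated vector has unit L2 norm, as in the diff above.
        let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|v| *v /= norm);
        }
    }
    embedding
}

fn main() {
    let pooled = vec![0.5_f32, 0.5, 0.5, 0.5];
    let mrl = truncate_and_normalize(pooled, Some(2), true);
    let norm: f32 = mrl.iter().map(|v| v * v).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 1e-6); // still unit length after truncation
    println!("{mrl:?}"); // ≈ [0.7071, 0.7071]
}

Truncating before normalizing matters: the shortened prefix is re-scaled to unit length, so cosine similarities on the truncated embeddings remain well defined.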
1 change: 1 addition & 0 deletions proto/tei.proto
@@ -80,6 +80,7 @@ message EmbedRequest {
bool normalize = 3;
TruncationDirection truncation_direction = 4;
optional string prompt_name = 5;
optional uint32 dimensions = 6;
}

message EmbedResponse {
1 change: 1 addition & 0 deletions router/src/grpc/server.rs
@@ -91,6 +91,7 @@ impl TextEmbeddingsService {
truncation_direction,
request.prompt_name,
request.normalize,
request.dimensions.map(|v| v as usize),
permit,
)
.await
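For the gRPC surface, the following is a hedged sketch of the request type a prost-generated client would build; the struct below is a stand-in written for this example, not the actual generated code. An `optional uint32` proto field maps to `Option<u32>`, and the handler above forwards it as `request.dimensions.map(|v| v as usize)`.

// Stand-in for the prost-generated EmbedRequest (illustrative only):
// `optional uint32 dimensions = 6;` becomes an `Option<u32>` field, so an
// absent value keeps the embedding at the model's native size.
#[derive(Debug, Default)]
struct EmbedRequest {
    normalize: bool,
    dimensions: Option<u32>,
    // remaining tei.proto fields omitted from this sketch
}

fn main() {
    let request = EmbedRequest {
        normalize: true,
        dimensions: Some(256), // ask for a 256-dimensional Matryoshka embedding
        ..Default::default()
    };
    // The router forwards this as `request.dimensions.map(|v| v as usize)`,
    // matching the grpc/server.rs change above.
    println!("{request:?}");
}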
5 changes: 5 additions & 0 deletions router/src/http/server.rs
@@ -544,6 +544,7 @@ async fn similarity(
truncation_direction: parameters.truncation_direction,
prompt_name: parameters.prompt_name,
normalize: false,
dimensions: None,
};

// Get embeddings
@@ -611,6 +612,7 @@ async fn embed(
req.truncation_direction.into(),
req.prompt_name,
req.normalize,
req.dimensions,
permit,
)
.await
@@ -679,6 +681,7 @@ async fn embed(
req.truncation_direction.into(),
prompt_name,
req.normalize,
req.dimensions,
permit,
)
.await
@@ -1156,6 +1159,7 @@ async fn openai_embed(
tokenizers::TruncationDirection::Right,
None,
true,
req.dimensions,
permit,
)
.await
@@ -1228,6 +1232,7 @@ async fn openai_embed(
tokenizers::TruncationDirection::Right,
None,
true,
req.dimensions,
permit,
)
.await
11 changes: 11 additions & 0 deletions router/src/http/types.rs
@@ -323,6 +323,8 @@ pub(crate) struct OpenAICompatRequest {
#[schema(default = "float", example = "float")]
#[serde(default)]
pub encoding_format: EncodingFormat,
#[schema(default = "null", example = "null", nullable = true)]
pub dimensions: Option<usize>,
}

#[derive(Serialize, ToSchema)]
@@ -406,12 +408,15 @@ pub(crate) struct SimilarityResponse(pub Vec<f32>);
#[derive(Deserialize, ToSchema)]
pub(crate) struct EmbedRequest {
pub inputs: Input,

#[serde(default)]
#[schema(default = "false", example = "false", nullable = true)]
pub truncate: Option<bool>,

#[serde(default)]
#[schema(default = "right", example = "right")]
pub truncation_direction: TruncationDirection,

/// The name of the prompt that should be used for encoding. If not set, no prompt
/// will be applied.
///
@@ -423,9 +428,15 @@ pub(crate) struct EmbedRequest {
/// any text to encode.
#[schema(default = "null", example = "null", nullable = true)]
pub prompt_name: Option<String>,

#[serde(default = "default_normalize")]
#[schema(default = "true", example = "true")]
pub normalize: bool,

/// The number of dimensions the resulting output embeddings should have. If not set, the original
/// shape of the representation will be returned.
#[schema(default = "null", example = "null", nullable = true)]
pub dimensions: Option<usize>,
}

fn default_normalize() -> bool {
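On the HTTP side, the field is optional on both EmbedRequest and OpenAICompatRequest (where it mirrors the `dimensions` parameter of the OpenAI embeddings API). Below is a hedged example of a request body a client might send; the serde_json usage and the sample input text are illustrative, not taken from this PR.

use serde_json::json;

fn main() {
    // Illustrative /embed request body: `dimensions` is optional, and leaving
    // it out returns embeddings at the model's native size.
    let body = json!({
        "inputs": "What is Matryoshka Representation Learning?",
        "normalize": true,
        "dimensions": 256
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}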