Skip to content

Commit 45df4fa

Browse files
kozistr and alvarobartt authored
Support MRL (Matryoshka Representation Learning) (#676)
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent 6e900af commit 45df4fa

14 files changed

+628
-3929
lines changed

core/src/infer.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,16 @@ impl Infer {
151151
panic!("unexpected enum variant")
152152
};
153153

154-
// Timings
155154
let total_time = start_time.elapsed();
156155

157-
// Metrics
158-
let counter = metrics::counter!("te_embed_success");
159-
counter.increment(1);
160-
let histogram = metrics::histogram!("te_embed_duration");
161-
histogram.record(total_time.as_secs_f64());
162-
let histogram = metrics::histogram!("te_embed_tokenization_duration");
163-
histogram.record(response.metadata.tokenization.as_secs_f64());
164-
let histogram = metrics::histogram!("te_embed_queue_duration");
165-
histogram.record(response.metadata.queue.as_secs_f64());
166-
let histogram = metrics::histogram!("te_embed_inference_duration");
167-
histogram.record(response.metadata.inference.as_secs_f64());
156+
metrics::counter!("te_embed_success").increment(1);
157+
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64());
158+
metrics::histogram!("te_embed_tokenization_duration")
159+
.record(response.metadata.tokenization.as_secs_f64());
160+
metrics::histogram!("te_embed_queue_duration")
161+
.record(response.metadata.queue.as_secs_f64());
162+
metrics::histogram!("te_embed_inference_duration")
163+
.record(response.metadata.inference.as_secs_f64());
168164

169165
Ok(response)
170166
}
@@ -224,6 +220,7 @@ impl Infer {
224220
Ok(response)
225221
}
226222

223+
#[allow(clippy::too_many_arguments)]
227224
#[instrument(skip(self, inputs, permit))]
228225
pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
229226
&self,
@@ -232,20 +229,31 @@ impl Infer {
232229
truncation_direction: TruncationDirection,
233230
prompt_name: Option<String>,
234231
normalize: bool,
232+
dimensions: Option<usize>,
235233
permit: OwnedSemaphorePermit,
236234
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
237235
let start_time = Instant::now();
238236

239237
if self.is_splade() && normalize {
240238
let counter = metrics::counter!("te_request_failure", "err" => "model_type");
241239
counter.increment(1);
240+
242241
let message = "`normalize` is not available for SPLADE models".to_string();
243242
tracing::error!("{message}");
244243
return Err(TextEmbeddingsError::Backend(BackendError::Inference(
245244
message,
246245
)));
247246
}
248247

248+
if let Some(dimensions) = dimensions {
249+
if dimensions == 0 {
250+
metrics::counter!("te_request_failure", "err" => "validation").increment(1);
251+
let message = "`dimensions` should be positive".to_string();
252+
tracing::error!("{message}");
253+
return Err(TextEmbeddingsError::Validation(message));
254+
}
255+
}
256+
249257
let results = self
250258
.embed(
251259
inputs,
@@ -262,6 +270,21 @@ impl Infer {
262270
panic!("unexpected enum variant")
263271
};
264272

273+
if let Some(mrl_dimensions) = dimensions {
274+
if mrl_dimensions > response.results.len() {
275+
metrics::counter!("te_request_failure", "err" => "validation").increment(1);
276+
277+
let message =
278+
"`dimensions` should be smaller than the maximum embedding dimension."
279+
.to_string();
280+
tracing::error!("{message}");
281+
282+
return Err(TextEmbeddingsError::Validation(message));
283+
}
284+
285+
response.results.truncate(mrl_dimensions);
286+
}
287+
265288
if normalize {
266289
// Normalize embedding
267290
let scale = (1.0

proto/tei.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ message EmbedRequest {
8080
bool normalize = 3;
8181
TruncationDirection truncation_direction = 4;
8282
optional string prompt_name = 5;
83+
optional uint32 dimensions = 6;
8384
}
8485

8586
message EmbedResponse {

router/src/grpc/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ impl TextEmbeddingsService {
9191
truncation_direction,
9292
request.prompt_name,
9393
request.normalize,
94+
request.dimensions.map(|v| v as usize),
9495
permit,
9596
)
9697
.await

router/src/http/server.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ async fn similarity(
544544
truncation_direction: parameters.truncation_direction,
545545
prompt_name: parameters.prompt_name,
546546
normalize: false,
547+
dimensions: None,
547548
};
548549

549550
// Get embeddings
@@ -611,6 +612,7 @@ async fn embed(
611612
req.truncation_direction.into(),
612613
req.prompt_name,
613614
req.normalize,
615+
req.dimensions,
614616
permit,
615617
)
616618
.await
@@ -679,6 +681,7 @@ async fn embed(
679681
req.truncation_direction.into(),
680682
prompt_name,
681683
req.normalize,
684+
req.dimensions,
682685
permit,
683686
)
684687
.await
@@ -1156,6 +1159,7 @@ async fn openai_embed(
11561159
tokenizers::TruncationDirection::Right,
11571160
None,
11581161
true,
1162+
req.dimensions,
11591163
permit,
11601164
)
11611165
.await
@@ -1228,6 +1232,7 @@ async fn openai_embed(
12281232
tokenizers::TruncationDirection::Right,
12291233
None,
12301234
true,
1235+
req.dimensions,
12311236
permit,
12321237
)
12331238
.await

router/src/http/types.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ pub(crate) struct OpenAICompatRequest {
323323
#[schema(default = "float", example = "float")]
324324
#[serde(default)]
325325
pub encoding_format: EncodingFormat,
326+
#[schema(default = "null", example = "null", nullable = true)]
327+
pub dimensions: Option<usize>,
326328
}
327329

328330
#[derive(Serialize, ToSchema)]
@@ -406,12 +408,15 @@ pub(crate) struct SimilarityResponse(pub Vec<f32>);
406408
#[derive(Deserialize, ToSchema)]
407409
pub(crate) struct EmbedRequest {
408410
pub inputs: Input,
411+
409412
#[serde(default)]
410413
#[schema(default = "false", example = "false", nullable = true)]
411414
pub truncate: Option<bool>,
415+
412416
#[serde(default)]
413417
#[schema(default = "right", example = "right")]
414418
pub truncation_direction: TruncationDirection,
419+
415420
/// The name of the prompt that should be used by for encoding. If not set, no prompt
416421
/// will be applied.
417422
///
@@ -423,9 +428,15 @@ pub(crate) struct EmbedRequest {
423428
/// any text to encode.
424429
#[schema(default = "null", example = "null", nullable = true)]
425430
pub prompt_name: Option<String>,
431+
426432
#[serde(default = "default_normalize")]
427433
#[schema(default = "true", example = "true")]
428434
pub normalize: bool,
435+
436+
/// The number of dimensions that the output embeddings should have. If not set, the original
437+
/// shape of the representation will be returned instead.
438+
#[schema(default = "null", example = "null", nullable = true)]
439+
pub dimensions: Option<usize>,
429440
}
430441

431442
fn default_normalize() -> bool {

0 commit comments

Comments
 (0)