
Commit a059696 (1 parent: 24533a0)

feat: add embed_raw route to get all embeddings without pooling (#154)

33 files changed: +26845 −312 lines

Cargo.lock

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default.

backends/candle/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ candle-flash-attn = { version = "^0.3", optional = true }
 candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "d5b873e4555b7f460ed639d96f26cb014f2daad7", optional = true }
 candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "c8a810ffe649c5f4634cbe1f0aaf02f6025fe5a5", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "0dd5bdceb9ba7cded921c62f9ddd66e7726327ba", optional = true }
+nohash-hasher = "^0.2"
 text-embeddings-backend-core = { path = "../core" }
 tracing = "^0.1"
 safetensors = "^0.4"
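
The only manifest change pulls in nohash-hasher, an identity hasher for integer keys: the backend keys its result maps by batch index, so real hashing is pure overhead. A minimal, self-contained sketch of that pattern (values illustrative, not from this commit):

    use nohash_hasher::BuildNoHashHasher;
    use std::collections::HashMap;

    fn main() {
        // Batch indices are small, dense usizes; an identity hasher
        // skips SipHash entirely on every insert and lookup.
        let mut by_index: HashMap<usize, Vec<f32>, BuildNoHashHasher<usize>> =
            HashMap::with_capacity_and_hasher(2, BuildNoHashHasher::default());
        by_index.insert(0, vec![0.1, 0.2]);
        by_index.insert(1, vec![0.3]);
        assert_eq!(by_index[&1], vec![0.3]);
    }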

backends/candle/src/lib.rs

Lines changed: 58 additions & 7 deletions
@@ -18,8 +18,12 @@ use crate::models::{BertModel, JinaBertModel, Model, PositionEmbeddingType};
 use candle::{DType, Device};
 use candle_nn::VarBuilder;
 use models::Config;
+use nohash_hasher::BuildNoHashHasher;
+use std::collections::HashMap;
 use std::path::PathBuf;
-use text_embeddings_backend_core::{Backend, BackendError, Batch, Embedding, ModelType};
+use text_embeddings_backend_core::{
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
+};
 
 pub struct CandleBackend {
     model: Box<dyn Model + Send>,
@@ -148,16 +152,63 @@ impl Backend for CandleBackend {
         self.model.is_padded()
     }
 
-    fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
-        let results = self.model.embed(batch).e()?;
-        let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
-        Ok(results)
+    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
+        let batch_size = batch.len();
+        let pooled_indices = batch.pooled_indices.clone();
+        let raw_indices = batch.raw_indices.clone();
+
+        // Used for indexing in the raw_embeddings tensor
+        let input_lengths: Vec<usize> = (0..batch.len())
+            .map(|i| {
+                (batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]) as usize
+            })
+            .collect();
+
+        // Run forward
+        let (pooled_embeddings, raw_embeddings) = self.model.embed(batch).e()?;
+
+        // Device => Host data transfer
+        let pooled_embeddings = match pooled_embeddings {
+            None => vec![],
+            Some(pooled_embeddings) => pooled_embeddings.to_dtype(DType::F32).e()?.to_vec2().e()?,
+        };
+
+        // This transfer is expensive...
+        let raw_embeddings = match raw_embeddings {
+            None => vec![],
+            Some(raw_embeddings) => raw_embeddings.to_dtype(DType::F32).e()?.to_vec2().e()?,
+        };
+
+        let mut embeddings =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+        for (i, e) in pooled_indices.into_iter().zip(pooled_embeddings) {
+            embeddings.insert(i as usize, Embedding::Pooled(e));
+        }
+
+        let mut cumulative_length = 0;
+        for i in raw_indices.into_iter() {
+            let length = input_lengths[i as usize];
+            let e = raw_embeddings[cumulative_length..cumulative_length + length].to_vec();
+            embeddings.insert(i as usize, Embedding::All(e));
+            cumulative_length += length;
+        }
+
+        Ok(embeddings)
     }
 
-    fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError> {
+    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
+        let batch_size = batch.len();
+
         let results = self.model.predict(batch).e()?;
         let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
-        Ok(results)
+
+        let mut predictions =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+        for (i, r) in results.into_iter().enumerate() {
+            predictions.insert(i, r);
+        }
+
+        Ok(predictions)
     }
 }
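
After the device-to-host transfer, raw token rows arrive packed back to back in raw_indices order, so the backend reslices them with a running offset derived from cumulative_seq_lengths. A tensor-free sketch of that bookkeeping (string labels stand in for embedding rows; all values illustrative):

    fn main() {
        // Three sequences of lengths 2, 3 and 1, stored as prefix sums.
        let cumulative_seq_lengths: Vec<u32> = vec![0, 2, 5, 6];
        let input_lengths: Vec<usize> = (0..3)
            .map(|i| (cumulative_seq_lengths[i + 1] - cumulative_seq_lengths[i]) as usize)
            .collect();
        assert_eq!(input_lengths, vec![2, 3, 1]);

        // Members 2 and 0 asked for raw embeddings; their token rows come
        // back packed in that order, so slicing walks a running offset.
        let raw_indices: Vec<u32> = vec![2, 0];
        let packed_rows = vec!["t2.0", "t0.0", "t0.1"];

        let mut cumulative_length = 0;
        for &i in &raw_indices {
            let length = input_lengths[i as usize];
            let rows = &packed_rows[cumulative_length..cumulative_length + length];
            println!("member {i}: {rows:?}"); // member 2: ["t2.0"], member 0: ["t0.0", "t0.1"]
            cumulative_length += length;
        }
    }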

backends/candle/src/models.rs

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ pub use flash_jina::FlashJinaBertModel;
 pub(crate) trait Model {
     fn is_padded(&self) -> bool;
 
-    fn embed(&self, _batch: Batch) -> Result<Tensor> {
+    fn embed(&self, _batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         candle::bail!("`embed` is not implemented for this model");
     }
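
The trait's default implementation now advertises the dual (pooled, raw) contract. A compile-checked sketch of the same shape with toy types (the aliases and names below are stand-ins, not the crate's own):

    type Tensor = Vec<f32>;
    type Result<T> = std::result::Result<T, String>;

    struct Batch;

    trait Model {
        fn is_padded(&self) -> bool;

        // Implementors return None on whichever side the batch did not
        // request; the default mirrors candle::bail! with a plain error.
        fn embed(&self, _batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
            Err("`embed` is not implemented for this model".to_string())
        }
    }

    struct Classifier;
    impl Model for Classifier {
        fn is_padded(&self) -> bool { true }
        // No embed override: calling embed still reports the error.
    }

    fn main() {
        assert!(Classifier.embed(Batch).is_err());
    }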

backends/candle/src/models/bert.rs

Lines changed: 84 additions & 16 deletions
@@ -515,10 +515,10 @@ impl BertModel {
         })
     }
 
-    pub fn forward(&self, batch: Batch) -> Result<Tensor> {
+    pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         let _enter = self.span.enter();
 
-        let batch_size = batch.cumulative_seq_lengths.len() - 1;
+        let batch_size = batch.len();
         let max_length = batch.max_length as usize;
 
         let shape = (batch_size, max_length);
@@ -634,25 +634,91 @@ impl BertModel {
             .embeddings
             .forward(&input_ids, &type_ids, &position_ids)?;
 
-        let mut outputs = self
+        let outputs = self
             .encoder
             .forward(&embedding_output, attention_bias.as_ref())?;
 
-        let results = match self.pool {
-            // CLS pooling
-            Pool::Cls => outputs.i((.., 0))?,
-            // Mean pooling
-            Pool::Mean => {
-                if let Some(attention_mask) = attention_mask {
-                    // Mask padded values
-                    outputs = outputs.broadcast_mul(&attention_mask)?;
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        let pooled_embeddings = if has_pooling_requests {
+            let pooled_indices_length = batch.pooled_indices.len();
+            let mut outputs = outputs.clone();
+
+            // Only use pooled_indices if at least one member of the batch asks for raw embeddings
+            let pooled_indices = if has_raw_requests {
+                let pooled_indices =
+                    Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;
+
+                // Select values in the batch
+                outputs = outputs.index_select(&pooled_indices, 0)?;
+                Some(pooled_indices)
+            } else {
+                None
+            };
+
+            let pooled_embeddings = match self.pool {
+                // CLS pooling
+                Pool::Cls => outputs.i((.., 0))?,
+                // Mean pooling
+                Pool::Mean => {
+                    if let Some(ref attention_mask) = attention_mask {
+                        let mut attention_mask = attention_mask.clone();
+
+                        if let Some(pooled_indices) = pooled_indices {
+                            // Select values in the batch
+                            attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
+                        };
+
+                        // Mask padded values
+                        outputs = outputs.broadcast_mul(&attention_mask)?;
+                    }
+
+                    (outputs.sum(1)?.broadcast_div(&input_lengths))?
                 }
+            };
+            Some(pooled_embeddings)
+        } else {
+            None
+        };
 
-                (outputs.sum(1)?.broadcast_div(&input_lengths))?
+        let raw_embeddings = if has_raw_requests {
+            // Reshape outputs
+            let (b, l, h) = outputs.shape().dims3()?;
+            let outputs = outputs.reshape((b * l, h))?;
+
+            // We need to remove the padding tokens only if batch_size > 1 and there are some
+            // members of the batch that require pooling,
+            // or if batch_size > 1 and the members of the batch have different lengths
+            if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
+                let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);
+
+                for i in batch.raw_indices.into_iter() {
+                    let start = i * batch.max_length;
+                    let i = i as usize;
+                    let length =
+                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
+
+                    for j in start..start + length {
+                        // Add indices for the tokens of this specific member of the batch
+                        final_indices.push(j);
+                    }
+                }
+
+                let final_indices_length = final_indices.len();
+                let final_indices =
+                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;
+
+                // Select the tokens with final indices
+                Some(outputs.index_select(&final_indices, 0)?)
+            } else {
+                Some(outputs)
             }
+        } else {
+            None
         };
 
-        Ok(results)
+        Ok((pooled_embeddings, raw_embeddings))
     }
 }
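
The raw branch above flattens the (batch, max_length, hidden) output to (batch * max_length, hidden) and then gathers only real-token rows via final_indices. A small sketch of that index arithmetic in isolation (no tensors; lengths illustrative):

    fn main() {
        // Two sequences padded to max_length = 4, with true lengths 2 and 3.
        let max_length: u32 = 4;
        let cumulative_seq_lengths: Vec<u32> = vec![0, 2, 5];
        let raw_indices: Vec<u32> = vec![0, 1];

        let mut final_indices: Vec<u32> = Vec::new();
        for i in raw_indices {
            // Member i's rows start at i * max_length in the flattened tensor.
            let start = i * max_length;
            let i = i as usize;
            let length = cumulative_seq_lengths[i + 1] - cumulative_seq_lengths[i];
            final_indices.extend(start..start + length);
        }

        // Padding rows 2, 3 and 7 are skipped.
        assert_eq!(final_indices, vec![0, 1, 4, 5, 6]);
    }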

@@ -661,16 +727,18 @@ impl Model for BertModel {
         true
     }
 
-    fn embed(&self, batch: Batch) -> Result<Tensor> {
+    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
 
     fn predict(&self, batch: Batch) -> Result<Tensor> {
         match &self.classifier {
             None => candle::bail!("`predict` is not implemented for this model"),
             Some(classifier) => {
-                let hidden_states = self.forward(batch)?;
-                classifier.forward(&hidden_states)
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
             }
         }
     }
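
Downstream of these changes, callers of embed receive a per-index map holding one of two variants, while predict keeps consuming only the pooled half. A toy sketch of consuming such a map (the enum below mirrors, but is not, text_embeddings_backend_core's Embedding):

    use std::collections::HashMap;

    // Stand-ins for the core crate's Embedding/Embeddings types.
    enum Embedding {
        Pooled(Vec<f32>),    // one vector per sequence
        All(Vec<Vec<f32>>),  // one vector per token (the new embed_raw path)
    }
    type Embeddings = HashMap<usize, Embedding>;

    fn main() {
        let mut out: Embeddings = HashMap::new();
        out.insert(0, Embedding::Pooled(vec![0.1, 0.2]));
        out.insert(1, Embedding::All(vec![vec![0.3, 0.4], vec![0.5, 0.6]]));

        for (i, e) in &out {
            match e {
                Embedding::Pooled(v) => println!("member {i}: pooled, dim {}", v.len()),
                Embedding::All(t) => println!("member {i}: {} token embeddings", t.len()),
            }
        }
    }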
