Commit b9c4f3d: optimizations
1 parent: a6f8700
2 files changed: +115 -51

rust/src/embeddings/local/colpali_ort.rs (+5 -5)
@@ -101,7 +101,7 @@ impl OrtColPaliEmbedder {
     }
 }
 
-fn tokenize_batch(tokenizer: &Tokenizer, text_batch: &[String]) -> Result<Array2<i64>, E> {
+fn tokenize_batch(tokenizer: &Tokenizer, text_batch: &[&str]) -> Result<Array2<i64>, E> {
     let token_ids = tokenizer
         .encode_batch_fast(text_batch.to_vec(), true)
         .map_err(E::msg)?
@@ -138,7 +138,7 @@ fn tokenize(tokenizer: &Tokenizer, text: String) -> Result<Array2<i64>, E> {
     Ok(token_ids_array)
 }
 
-fn get_attention_mask(tokenizer: &Tokenizer, text_batch: &[String]) -> Result<Array2<i64>, E> {
+fn get_attention_mask(tokenizer: &Tokenizer, text_batch: &[&str]) -> Result<Array2<i64>, E> {
     let attention_mask = tokenizer
         .encode_batch(text_batch.to_vec(), true)
         .map_err(E::msg)?
@@ -195,7 +195,7 @@ impl OrtColPaliEmbedder {
 impl ColPaliEmbed for OrtColPaliEmbedder {
     fn embed(
         &self,
-        text_batch: &[String],
+        text_batch: &[&str],
         batch_size: Option<usize>,
     ) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
         let batch_size = batch_size.unwrap_or(32);
@@ -221,8 +221,8 @@ impl ColPaliEmbed for OrtColPaliEmbedder {
     }
 
     fn embed_query(&self, query: &str) -> anyhow::Result<Vec<EmbedData>> {
-        let token_ids = tokenize_batch(&self.tokenizer, &[query.to_string()])?;
-        let attention_mask = get_attention_mask(&self.tokenizer, &[query.to_string()])?;
+        let token_ids = tokenize_batch(&self.tokenizer, &[query])?;
+        let attention_mask = get_attention_mask(&self.tokenizer, &[query])?;
         let pixel_values: Array4<f32> =
             Array4::zeros((1, self.num_channels, self.image_size, self.image_size));
         let e = self
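The change from `&[String]` to `&[&str]` across `tokenize_batch`, `get_attention_mask`, and `embed` lets callers pass borrowed text without allocating, which is why `embed_query` can now hand over `&[query]` instead of `&[query.to_string()]`. A minimal sketch of the calling-convention difference; `embed_texts` is a hypothetical stand-in, not a function from this crate:

// Hypothetical function mirroring the new borrowed-slice signature.
fn embed_texts(text_batch: &[&str]) -> usize {
    text_batch.iter().map(|t| t.len()).sum()
}

fn main() {
    // String literals can be passed directly...
    let a = embed_texts(&["hello", "world"]);
    // ...and owned Strings are borrowed rather than cloned on the way in.
    let owned = vec![String::from("hello"), String::from("world")];
    let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
    let b = embed_texts(&refs);
    assert_eq!(a, b);
}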

rust/src/embeddings/local/ort_bert.rs (+110 -46)
@@ -1,10 +1,10 @@
 use super::bert::{BertEmbed, TokenizerConfig};
-use super::pooling::{ModelOutput, Pooling};
+use super::pooling::{ModelOutput, PooledOutputType, Pooling};
 use super::text_embedding::ONNXModel;
 use crate::embeddings::embed::EmbeddingResult;
 use crate::embeddings::local::text_embedding::models_map;
 use crate::embeddings::utils::{
-    get_attention_mask_ndarray, get_type_ids_ndarray, tokenize_batch_ndarray,
+    get_type_ids_ndarray, tokenize_batch_ndarray,
 };
 
 use crate::Dtype;
@@ -15,7 +15,6 @@ use ndarray::prelude::*;
 use ort::execution_providers::{CUDAExecutionProvider, CoreMLExecutionProvider, ExecutionProvider};
 use ort::session::builder::GraphOptimizationLevel;
 use ort::session::Session;
-use ort::value::Value;
 use rayon::prelude::*;
 use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
 
@@ -140,14 +139,22 @@ impl OrtBertEmbedder {
             println!("Session is using CUDAExecutionProvider");
         }
 
-        let threads = std::thread::available_parallelism().unwrap().get();
+        // Query the logical core count (hyperthreads included); fall back to 1
+        let threads = std::thread::available_parallelism()
+            .map(|p| p.get())
+            .unwrap_or(1);
+        // For CPU-bound workloads like ONNX inference, it's often better to use
+        // physical cores rather than logical cores to avoid context switching overhead
+        let optimal_threads = std::cmp::max(1, threads / 2);
+
         let model = Session::builder()?
             .with_execution_providers([
                 CUDAExecutionProvider::default().build(),
                 CoreMLExecutionProvider::default().build(),
             ])?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
-            .with_intra_threads(threads)?
+            .with_intra_threads(optimal_threads)? // Use optimal thread count
+            .with_inter_threads(1)? // Set inter-op parallelism to 1 when using GPU
             .commit_from_file(weights_filename)?;
 
         Ok(OrtBertEmbedder {
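The session builder above now derives its intra-op thread count by halving the logical CPU count as an approximation of physical cores. A self-contained sketch of just that heuristic, assuming an SMT-2 machine (two hardware threads per core); no `ort` dependency needed:

use std::thread;

fn main() {
    // available_parallelism() reports logical CPUs (hyperthreads included)
    // and returns a Result, so fall back to 1 instead of unwrapping.
    let logical = thread::available_parallelism().map(|p| p.get()).unwrap_or(1);
    // Halving approximates the physical core count on SMT-2 hardware;
    // max(1, ..) keeps single-core hosts from requesting zero threads.
    let intra_op = std::cmp::max(1, logical / 2);
    println!("logical cores: {logical}, intra-op threads: {intra_op}");
}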
@@ -161,55 +168,68 @@ impl OrtBertEmbedder {
 impl BertEmbed for OrtBertEmbedder {
     fn embed(
         &self,
-        text_batch: &[String],
+        text_batch: &[&str],
         batch_size: Option<usize>,
     ) -> Result<Vec<EmbeddingResult>, E> {
         let batch_size = batch_size.unwrap_or(32);
+
+        // Pre-compute input names once
+        let input_names: Vec<_> = self.model.inputs.iter().map(|input| input.name.as_str()).collect();
+        let output_name = self.model.outputs.first().unwrap().name.as_str();
+        let needs_token_type = input_names.iter().any(|&x| x == "token_type_ids");
+
         let encodings = text_batch
             .par_chunks(batch_size)
             .flat_map(|mini_text_batch| -> Result<Vec<Vec<f32>>, E> {
-                let input_ids: Array2<i64> =
-                    tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
-                let token_type_ids: Array2<i64> = Array2::zeros(input_ids.raw_dim());
-                let attention_mask: Array2<i64> = Array2::ones(input_ids.raw_dim());
-
-                let input_names = self
-                    .model
-                    .inputs
-                    .iter()
-                    .map(|input| input.name.as_str())
-                    .collect::<Vec<_>>();
-
-                let mut inputs =
-                    ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask]?;
-                if input_names.iter().any(|&x| x == "token_type_ids") {
-                    inputs.push((
-                        "token_type_ids".into(),
-                        Value::from_array(token_type_ids.clone())?.into(),
-                    ));
-                }
+                // Tokenize and prepare inputs
+                let (input_ids, attention_mask) = tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
+
+                // Build inputs more efficiently
+                let inputs = if needs_token_type {
+                    let token_type_ids = Array2::<i64>::zeros(input_ids.raw_dim());
+                    ort::inputs![
+                        "input_ids" => input_ids,
+                        "attention_mask" => attention_mask.clone(),
+                        "token_type_ids" => token_type_ids
+                    ]?
+                } else {
+                    ort::inputs![
+                        "input_ids" => input_ids,
+                        "attention_mask" => attention_mask.clone()
+                    ]?
+                };
+
+                // Run model and extract embeddings
                 let outputs = self.model.run(inputs)?;
-                let embeddings: Array3<f32> = outputs
-                    [self.model.outputs.first().unwrap().name.as_str()]
-                    .try_extract_tensor::<f32>()?
-                    .to_owned()
-                    .into_dimensionality::<ndarray::Ix3>()?;
-                let (_, _, _) = embeddings.dim();
-                let embeddings = self
-                    .pooling
-                    .pool(&ModelOutput::Array(embeddings))?
-                    .to_array()?;
+                let embeddings: Array3<f32> = outputs[output_name]
+                    .try_extract_tensor()?
+                    .to_owned()
+                    .into_dimensionality()?;
+
+                // Prepare attention mask for pooling
+                let attention_mask = if matches!(self.pooling, Pooling::Mean) {
+                    Some(PooledOutputType::from(attention_mask.mapv(|x| x as f32)))
+                } else {
+                    None
+                };
+
+                // Pool and normalize embeddings
+                let model_output = ModelOutput::Array(embeddings);
+                let pooled = self.pooling.pool(&model_output, attention_mask.as_ref())?;
+                let embeddings = pooled.to_array()?;
+
+                // Normalize in one step
                 let norms = embeddings.mapv(|x| x * x).sum_axis(Axis(1)).mapv(f32::sqrt);
-                let embeddings = &embeddings / &norms.insert_axis(Axis(1));
+                let normalized = embeddings / &norms.insert_axis(Axis(1));
 
-                Ok(embeddings.outer_iter().map(|row| row.to_vec()).collect())
+                Ok(normalized.outer_iter().map(|row| row.to_vec()).collect())
             })
             .flatten()
             .collect::<Vec<_>>();
 
         Ok(encodings
-            .iter()
-            .map(|x| EmbeddingResult::DenseVector(x.to_vec()))
+            .into_iter() // Use into_iter since we don't need the original vector
+            .map(|x| EmbeddingResult::DenseVector(x))
            .collect())
     }
 }
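A key part of the rewrite above is that mean pooling now receives the real attention mask (the `Pooling::Mean` branch) instead of an implicit all-ones mask, so padding tokens no longer dilute the sentence average. A self-contained ndarray sketch of masked mean pooling over a [batch, seq, hidden] tensor; this illustrates the idea only and is not the crate's `Pooling::pool` implementation:

use ndarray::{Array2, Array3, Axis};

// Average token embeddings, weighting each position by its attention
// mask so padded positions contribute nothing to the sentence vector.
fn masked_mean_pool(hidden: &Array3<f32>, mask: &Array2<f32>) -> Array2<f32> {
    // [batch, seq] -> [batch, seq, 1] so the mask broadcasts over hidden dims.
    let mask3 = mask.clone().insert_axis(Axis(2));
    let summed = (hidden * &mask3).sum_axis(Axis(1));         // [batch, hidden]
    let counts = mask.sum_axis(Axis(1)).insert_axis(Axis(1)); // [batch, 1]
    summed / &counts.mapv(|c| c.max(1.0)) // guard against all-padding rows
}

fn main() {
    // batch=1, seq=3 (last position is padding), hidden=2
    let hidden =
        Array3::from_shape_vec((1, 3, 2), vec![1.0, 2.0, 3.0, 4.0, 9.0, 9.0]).unwrap();
    let mask = Array2::from_shape_vec((1, 3), vec![1.0, 1.0, 0.0]).unwrap();
    // The padded token's 9.0s are excluded: the mean of (1,2) and (3,4) is (2,3).
    assert_eq!(
        masked_mean_pool(&hidden, &mask),
        Array2::from_shape_vec((1, 2), vec![2.0, 3.0]).unwrap()
    );
}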
@@ -276,7 +296,7 @@ impl OrtSparseBertEmbedder {
             (Some(max_len), Some(model_max_len)) => std::cmp::min(max_len, model_max_len),
             (Some(max_len), None) => max_len,
             (None, Some(model_max_len)) => model_max_len,
-            (None, None) => 128,
+            (None, None) => 256,
         };
         let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
         let pp = PaddingParams {
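The fallback maximum sequence length rises from 128 to 256 tokens when neither the tokenizer config nor the model metadata specifies one; the resolved value then feeds the padding/truncation setup begun by `let pp = PaddingParams {` above. A hedged sketch of a typical `tokenizers`-crate configuration with such a cap; the field defaults are the crate's, and `with_truncation` returns a `Result` in recent versions (discarded here):

use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

fn configure(tokenizer: &mut Tokenizer, max_length: usize) {
    // Pad each batch to its longest member rather than to a fixed length...
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        ..Default::default()
    }));
    // ...and truncate anything beyond the resolved cap (256 by default here).
    let _ = tokenizer.with_truncation(Some(TruncationParams {
        max_length,
        ..Default::default()
    }));
}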
@@ -300,14 +320,22 @@ impl OrtSparseBertEmbedder {
             println!("Session is using CUDAExecutionProvider");
         }
 
-        let threads = std::thread::available_parallelism().unwrap().get();
+        // Query the logical core count (hyperthreads included); fall back to 1
+        let threads = std::thread::available_parallelism()
+            .map(|p| p.get())
+            .unwrap_or(1);
+        // For CPU-bound workloads like ONNX inference, it's often better to use
+        // physical cores rather than logical cores to avoid context switching overhead
+        let optimal_threads = std::cmp::max(1, threads / 2);
+
         let model = Session::builder()?
             .with_execution_providers([
                 CUDAExecutionProvider::default().build(),
                 CoreMLExecutionProvider::default().build(),
             ])?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
-            .with_intra_threads(threads)?
+            .with_intra_threads(optimal_threads)? // Use optimal thread count
+            .with_inter_threads(1)? // Set inter-op parallelism to 1 when using GPU
             .commit_from_file(weights_filename)?;
 
         Ok(OrtSparseBertEmbedder { tokenizer, model })
@@ -317,14 +345,13 @@ impl OrtSparseBertEmbedder {
 impl BertEmbed for OrtSparseBertEmbedder {
     fn embed(
         &self,
-        text_batch: &[String],
+        text_batch: &[&str],
         batch_size: Option<usize>,
     ) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
         let batch_size = batch_size.unwrap_or(32);
         let encodings = text_batch.par_chunks(batch_size).flat_map(|mini_text_batch| -> Result<Vec<Vec<f32>>, E> {
-            let token_ids: Array2<i64> = tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
+            let (token_ids, attention_mask): (Array2<i64>, Array2<i64>) = tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
             let token_type_ids: Array2<i64> = get_type_ids_ndarray(&self.tokenizer, mini_text_batch)?;
-            let attention_mask = get_attention_mask_ndarray(&self.tokenizer, mini_text_batch)?;
             let outputs = self.model.run(ort::inputs!["input_ids" => token_ids, "input_mask" => attention_mask.clone(), "segment_ids" => token_type_ids]?).unwrap();
             let embeddings: Array3<f32> = outputs["output"]
                 .try_extract_tensor::<f32>()?
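`tokenize_batch_ndarray` has changed shape here: it returns both the token ids and the attention mask from a single `encode_batch` pass, so the separate `get_attention_mask_ndarray` call, which re-encoded the whole batch, is dropped. A sketch of what such a single-pass helper can look like; the name `tokenize_batch_once` and the uniform-padding assumption are illustrative, not the crate's exact code:

use anyhow::Error;
use ndarray::Array2;
use tokenizers::Tokenizer;

// Encode once, then materialize ids and mask as [batch, seq] i64 matrices.
// Assumes the tokenizer pads every sequence in the batch to one length.
fn tokenize_batch_once(
    tokenizer: &Tokenizer,
    text_batch: &[&str],
) -> Result<(Array2<i64>, Array2<i64>), Error> {
    let encodings = tokenizer
        .encode_batch(text_batch.to_vec(), true)
        .map_err(Error::msg)?;
    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
    let ids: Vec<i64> = encodings
        .iter()
        .flat_map(|e| e.get_ids().iter().map(|&id| id as i64))
        .collect();
    let mask: Vec<i64> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().iter().map(|&m| m as i64))
        .collect();
    Ok((
        Array2::from_shape_vec((encodings.len(), seq_len), ids)?,
        Array2::from_shape_vec((encodings.len(), seq_len), mask)?,
    ))
}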
@@ -344,3 +371,40 @@ impl BertEmbed for OrtSparseBertEmbedder {
         .collect())
     }
 }
+
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_ort_bert_embed() {
+        let embedder = OrtBertEmbedder::new(
+            None,
+            Some("sentence-transformers/all-MiniLM-L6-v2"),
+            None,
+            None,
+            Some("onnx/model.onnx"),
+        )
+        .unwrap();
+        let embeddings = embedder
+            .embed(&["Hello, world!", "I am a rust programmer"], Some(32))
+            .unwrap();
+        println!("embeddings: {:?}", embeddings);
+
+        let test_embeddings: Vec<f32> = vec![
+            -3.81771736e-02,
+            3.29111032e-02,
+            -5.45938499e-03,
+            1.43699143e-02,
+        ];
+        let embeddings = embeddings[0].to_dense().unwrap()[0..4].to_vec();
+        assert!(
+            embeddings
+                .iter()
+                .zip(test_embeddings.iter())
+                .all(|(a, b)| (a - b).abs() < 1e-5)
+        );
+    }
+}
