@@ -1,10 +1,10 @@
 use super::bert::{BertEmbed, TokenizerConfig};
-use super::pooling::{ModelOutput, Pooling};
+use super::pooling::{ModelOutput, PooledOutputType, Pooling};
 use super::text_embedding::ONNXModel;
 use crate::embeddings::embed::EmbeddingResult;
 use crate::embeddings::local::text_embedding::models_map;
 use crate::embeddings::utils::{
-    get_attention_mask_ndarray, get_type_ids_ndarray, tokenize_batch_ndarray,
+    get_type_ids_ndarray, tokenize_batch_ndarray,
 };

 use crate::Dtype;
@@ -15,7 +15,6 @@ use ndarray::prelude::*;
 use ort::execution_providers::{CUDAExecutionProvider, CoreMLExecutionProvider, ExecutionProvider};
 use ort::session::builder::GraphOptimizationLevel;
 use ort::session::Session;
-use ort::value::Value;
 use rayon::prelude::*;
 use tokenizers::{PaddingParams, Tokenizer, TruncationParams};

@@ -140,14 +139,22 @@ impl OrtBertEmbedder {
             println!("Session is using CUDAExecutionProvider");
         }

-        let threads = std::thread::available_parallelism().unwrap().get();
+        // Logical core count; fall back to 1 if it cannot be determined
+        let threads = std::thread::available_parallelism()
+            .map(|p| p.get())
+            .unwrap_or(1);
+        // For CPU-bound work like ONNX inference, physical cores usually beat
+        // logical cores (less context switching); halving assumes an SMT factor of 2
+        let optimal_threads = std::cmp::max(1, threads / 2);
+
         let model = Session::builder()?
             .with_execution_providers([
                 CUDAExecutionProvider::default().build(),
                 CoreMLExecutionProvider::default().build(),
             ])?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
-            .with_intra_threads(threads)?
+            .with_intra_threads(optimal_threads)? // intra-op threads do the compute
+            .with_inter_threads(1)? // one inter-op thread; ops are already parallelized internally
             .commit_from_file(weights_filename)?;

         Ok(OrtBertEmbedder {
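
Aside on the thread heuristic above: `std::thread::available_parallelism` reports logical cores, so the halving is a rough stand-in for counting physical cores and assumes an SMT factor of 2 (the `num_cpus` crate's `get_physical()` would measure this directly). A minimal sketch of the same heuristic in isolation:

```rust
/// Rough physical-core estimate: half the logical cores, never below 1.
/// Assumes two logical cores per physical core; on CPUs without SMT this
/// leaves half the cores idle.
fn intra_op_threads() -> usize {
    let logical = std::thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(1); // fall back to a single thread if the count is unknown
    std::cmp::max(1, logical / 2)
}
```
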
@@ -161,55 +168,68 @@ impl OrtBertEmbedder {
 impl BertEmbed for OrtBertEmbedder {
     fn embed(
         &self,
-        text_batch: &[String],
+        text_batch: &[&str],
         batch_size: Option<usize>,
     ) -> Result<Vec<EmbeddingResult>, E> {
         let batch_size = batch_size.unwrap_or(32);
+
+        // Pre-compute input/output names once instead of per mini-batch
+        let input_names: Vec<_> = self.model.inputs.iter().map(|input| input.name.as_str()).collect();
+        let output_name = self.model.outputs.first().unwrap().name.as_str();
+        let needs_token_type = input_names.iter().any(|&x| x == "token_type_ids");
+
         let encodings = text_batch
             .par_chunks(batch_size)
             .flat_map(|mini_text_batch| -> Result<Vec<Vec<f32>>, E> {
-                let input_ids: Array2<i64> =
-                    tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
-                let token_type_ids: Array2<i64> = Array2::zeros(input_ids.raw_dim());
-                let attention_mask: Array2<i64> = Array2::ones(input_ids.raw_dim());
-
-                let input_names = self
-                    .model
-                    .inputs
-                    .iter()
-                    .map(|input| input.name.as_str())
-                    .collect::<Vec<_>>();
-
-                let mut inputs =
-                    ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask]?;
-                if input_names.iter().any(|&x| x == "token_type_ids") {
-                    inputs.push((
-                        "token_type_ids".into(),
-                        Value::from_array(token_type_ids.clone())?.into(),
-                    ));
-                }
+                // Tokenize; the helper now also returns the real attention mask
+                let (input_ids, attention_mask) = tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
+
+                // Add token_type_ids only when the model expects that input
+                let inputs = if needs_token_type {
+                    let token_type_ids = Array2::<i64>::zeros(input_ids.raw_dim());
+                    ort::inputs![
+                        "input_ids" => input_ids,
+                        "attention_mask" => attention_mask.clone(),
+                        "token_type_ids" => token_type_ids
+                    ]?
+                } else {
+                    ort::inputs![
+                        "input_ids" => input_ids,
+                        "attention_mask" => attention_mask.clone()
+                    ]?
+                };
+
+                // Run the model and extract the token-level embeddings
                 let outputs = self.model.run(inputs)?;
-                let embeddings: Array3<f32> = outputs
-                    [self.model.outputs.first().unwrap().name.as_str()]
-                .try_extract_tensor::<f32>()?
-                .to_owned()
-                .into_dimensionality::<ndarray::Ix3>()?;
-                let (_, _, _) = embeddings.dim();
-                let embeddings = self
-                    .pooling
-                    .pool(&ModelOutput::Array(embeddings))?
-                    .to_array()?;
+                let embeddings: Array3<f32> = outputs[output_name]
+                    .try_extract_tensor()?
+                    .to_owned()
+                    .into_dimensionality()?;
+
+                // Mean pooling needs the attention mask so padding is ignored
+                let attention_mask = if matches!(self.pooling, Pooling::Mean) {
+                    Some(PooledOutputType::from(attention_mask.mapv(|x| x as f32)))
+                } else {
+                    None
+                };
+
+                // Pool token embeddings into one vector per input text
+                let model_output = ModelOutput::Array(embeddings);
+                let pooled = self.pooling.pool(&model_output, attention_mask.as_ref())?;
+                let embeddings = pooled.to_array()?;
+
+                // L2-normalize each row
                 let norms = embeddings.mapv(|x| x * x).sum_axis(Axis(1)).mapv(f32::sqrt);
-                let embeddings = &embeddings / &norms.insert_axis(Axis(1));
+                let normalized = embeddings / &norms.insert_axis(Axis(1));

-                Ok(embeddings.outer_iter().map(|row| row.to_vec()).collect())
+                Ok(normalized.outer_iter().map(|row| row.to_vec()).collect())
             })
             .flatten()
             .collect::<Vec<_>>();

         Ok(encodings
-            .iter()
-            .map(|x| EmbeddingResult::DenseVector(x.to_vec()))
+            .into_iter() // consume the vector instead of cloning each embedding
+            .map(EmbeddingResult::DenseVector)
             .collect())
     }
 }
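
For context on why the real attention mask now reaches `pool`: with padded batches, mean pooling must average only over non-padding tokens, which the old all-ones placeholder mask could not distinguish. Below is a minimal ndarray sketch of masked mean pooling, assuming shapes `(batch, seq_len, hidden)` for the token embeddings and `(batch, seq_len)` for the mask; it illustrates the computation, not the crate's actual `Pooling::pool` implementation:

```rust
use ndarray::{Array2, Array3, Axis};

/// Average token embeddings over the sequence axis, ignoring padding
/// positions (mask == 0.0).
fn masked_mean_pool(token_embeddings: &Array3<f32>, mask: &Array2<f32>) -> Array2<f32> {
    let mask3 = mask.clone().insert_axis(Axis(2)); // (batch, seq_len, 1)
    let summed = (token_embeddings * &mask3).sum_axis(Axis(1)); // (batch, hidden)
    let counts = mask.sum_axis(Axis(1)).insert_axis(Axis(1)); // (batch, 1) real-token counts
    summed / &counts.mapv(|c| c.max(1e-9)) // guard against all-padding rows
}
```
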
@@ -276,7 +296,7 @@ impl OrtSparseBertEmbedder {
             (Some(max_len), Some(model_max_len)) => std::cmp::min(max_len, model_max_len),
             (Some(max_len), None) => max_len,
             (None, Some(model_max_len)) => model_max_len,
-            (None, None) => 128,
+            (None, None) => 256,
         };
         let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
         let pp = PaddingParams {
@@ -300,14 +320,22 @@ impl OrtSparseBertEmbedder {
             println!("Session is using CUDAExecutionProvider");
         }

-        let threads = std::thread::available_parallelism().unwrap().get();
+        // Logical core count; fall back to 1 if it cannot be determined
+        let threads = std::thread::available_parallelism()
+            .map(|p| p.get())
+            .unwrap_or(1);
+        // For CPU-bound work like ONNX inference, physical cores usually beat
+        // logical cores (less context switching); halving assumes an SMT factor of 2
+        let optimal_threads = std::cmp::max(1, threads / 2);
+
         let model = Session::builder()?
             .with_execution_providers([
                 CUDAExecutionProvider::default().build(),
                 CoreMLExecutionProvider::default().build(),
             ])?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
-            .with_intra_threads(threads)?
+            .with_intra_threads(optimal_threads)? // intra-op threads do the compute
+            .with_inter_threads(1)? // one inter-op thread; ops are already parallelized internally
             .commit_from_file(weights_filename)?;

         Ok(OrtSparseBertEmbedder { tokenizer, model })
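
Both `embed` implementations now destructure `tokenize_batch_ndarray` into `(input_ids, attention_mask)`, replacing the removed `get_attention_mask_ndarray` helper. The updated helper itself is outside this diff; the sketch below is only the signature the call sites imply, assuming the tokenizer has `PaddingParams` set so every row in a batch shares one length (hypothetical code, not the crate's implementation):

```rust
use anyhow::Error as E;
use ndarray::Array2;
use tokenizers::Tokenizer;

/// Hypothetical shape of the updated helper: token ids and the real
/// attention mask, produced in one tokenization pass.
fn tokenize_batch_ndarray(
    tokenizer: &Tokenizer,
    batch: &[&str],
) -> Result<(Array2<i64>, Array2<i64>), E> {
    let encodings = tokenizer
        .encode_batch(batch.to_vec(), true)
        .map_err(E::msg)?;
    // With padding enabled on the tokenizer, all rows have the same length.
    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
    let ids: Vec<i64> = encodings
        .iter()
        .flat_map(|e| e.get_ids().iter().map(|&id| id as i64))
        .collect();
    let mask: Vec<i64> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().iter().map(|&m| m as i64))
        .collect();
    Ok((
        Array2::from_shape_vec((encodings.len(), seq_len), ids)?,
        Array2::from_shape_vec((encodings.len(), seq_len), mask)?,
    ))
}
```
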
@@ -317,14 +345,13 @@ impl OrtSparseBertEmbedder {
 impl BertEmbed for OrtSparseBertEmbedder {
     fn embed(
         &self,
-        text_batch: &[String],
+        text_batch: &[&str],
         batch_size: Option<usize>,
     ) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
         let batch_size = batch_size.unwrap_or(32);
         let encodings = text_batch.par_chunks(batch_size).flat_map(|mini_text_batch| -> Result<Vec<Vec<f32>>, E> {
-            let token_ids: Array2<i64> = tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
+            let (token_ids, attention_mask): (Array2<i64>, Array2<i64>) = tokenize_batch_ndarray(&self.tokenizer, mini_text_batch)?;
             let token_type_ids: Array2<i64> = get_type_ids_ndarray(&self.tokenizer, mini_text_batch)?;
-            let attention_mask = get_attention_mask_ndarray(&self.tokenizer, mini_text_batch)?;
             let outputs = self.model.run(ort::inputs!["input_ids" => token_ids, "input_mask" => attention_mask.clone(), "segment_ids" => token_type_ids]?).unwrap();
             let embeddings: Array3<f32> = outputs["output"]
                 .try_extract_tensor::<f32>()?
@@ -344,3 +371,40 @@ impl BertEmbed for OrtSparseBertEmbedder {
             .collect())
     }
 }
+
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_ort_bert_embed() {
+        let embedder = OrtBertEmbedder::new(
+            None,
+            Some("sentence-transformers/all-MiniLM-L6-v2"),
+            None,
+            None,
+            Some("onnx/model.onnx"),
+        )
+        .unwrap();
+        let embeddings = embedder
+            .embed(&["Hello, world!", "I am a rust programmer"], Some(32))
+            .unwrap();
+        println!("embeddings: {:?}", embeddings);
+
+        let test_embeddings: Vec<f32> = vec![
+            -3.81771736e-02,
+            3.29111032e-02,
+            -5.45938499e-03,
+            1.43699143e-02,
+        ];
+        let embeddings = embeddings[0].to_dense().unwrap()[0..4].to_vec();
+        assert!(
+            embeddings
+                .iter()
+                .zip(test_embeddings.iter())
+                .all(|(a, b)| (a - b).abs() < 1e-5)
+        );
+    }
+}
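
Since `embed` L2-normalizes every vector before returning, a unit-norm property check would complement the hard-coded prefix comparison in the test above. A hypothetical extra assertion (it would go before `embeddings` is shadowed, and is not part of the PR):

```rust
// Every embedding returned by embed() should have unit L2 norm.
for emb in &embeddings {
    let v = emb.to_dense().unwrap();
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 1e-4);
}
```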