@@ -7,7 +7,7 @@ use std::ptr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::{BackendError, Batch};
 use tokio::sync::oneshot;
-use tracing::{info_span, instrument, Span};
+use tracing::{instrument, Span};
 
 /// Queue entry
 #[derive(Debug)]
@@ -51,8 +51,9 @@ impl Queue {
         // Create channels
         let (queue_sender, queue_receiver) = flume::unbounded();
 
+        // Launch background queue task
         tokio::task::spawn_blocking(move || {
-            queue_task(
+            queue_blocking_task(
                 max_batch_tokens,
                 max_batch_requests,
                 max_concurrent_requests,
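The hunk above keeps the queue loop on `tokio::task::spawn_blocking` rather than `tokio::spawn`: the task blocks on flume's synchronous `recv()`, so it must run on Tokio's blocking thread pool instead of an async worker, which the new `queue_blocking_task` name makes explicit. A minimal, self-contained sketch of that pattern follows; the channel's message type is a placeholder, not the actual queue command type:

use tokio::task;

// Placeholder message type standing in for the real queue command
fn blocking_loop(rx: flume::Receiver<u32>) {
    // flume's recv() parks the thread until a message arrives,
    // which would stall an executor thread if run on an async worker
    while let Ok(msg) = rx.recv() {
        println!("got {msg}");
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = flume::unbounded();
    // spawn_blocking moves the loop onto Tokio's dedicated blocking pool
    let handle = task::spawn_blocking(move || blocking_loop(rx));
    tx.send(1).unwrap();
    drop(tx); // closing the channel lets recv() return Err and the loop exit
    handle.await.unwrap();
}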
@@ -95,7 +96,7 @@ impl Queue {
 }
 
 // Background task responsible of the queue state
-fn queue_task(
+fn queue_blocking_task(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
     max_concurrent_requests: usize,
@@ -118,12 +119,12 @@ fn queue_task(
             } => unsafe {
                 let _span = span.entered();
 
-                let mut metadata = Vec::with_capacity(capacity);
-
+                // Allocate raw memory
                 let raw_input_ids = raw_u32_vec(max_batch_tokens);
                 let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
                 let raw_position_ids = raw_u32_vec(max_batch_tokens);
 
+                let mut metadata = Vec::with_capacity(capacity);
                 let mut cu_seq_lengths = Vec::with_capacity(capacity);
                 cu_seq_lengths.push(0);
 
@@ -151,45 +152,33 @@ fn queue_task(
 
                 entry.metadata.batch_time = Some(batch_time);
 
-                {
-                    let _span = info_span!("extend").entered();
-
-                    ptr::copy(
-                        entry.encoding.input_ids.as_mut_ptr(),
-                        raw_input_ids.add(current_tokens),
-                        entry.encoding.input_ids.len(),
-                    );
-                    ptr::copy(
-                        entry.encoding.token_type_ids.as_mut_ptr(),
-                        raw_token_type_ids.add(current_tokens),
-                        entry.encoding.token_type_ids.len(),
-                    );
-                    ptr::copy(
-                        entry.encoding.position_ids.as_mut_ptr(),
-                        raw_position_ids.add(current_tokens),
-                        entry.encoding.position_ids.len(),
-                    );
-
-                    // input_ids.extend_from_slice(entry.encoding.input_ids.as_slice());
-                    // token_type_ids.extend_from_slice(entry.encoding.token_type_ids.as_slice());
-                    // position_ids.extend_from_slice(entry.encoding.position_ids.as_slice());
-
-                    // for i in 0..entry.encoding.input_ids.len() {
-                    //     input_ids.push(entry.encoding.input_ids[i]);
-                    //     token_type_ids.push(entry.encoding.token_type_ids[i]);
-                    //     position_ids.push(entry.encoding.position_ids[i]);
-                    // }
-
-                    current_tokens += entry_tokens;
-                    metadata.push(entry.metadata);
-                    cu_seq_lengths.push(current_tokens as u32);
-                }
+                // Copy memory to the correct spot in the raw vectors
+                ptr::copy(
+                    entry.encoding.input_ids.as_mut_ptr(),
+                    raw_input_ids.add(current_tokens),
+                    entry.encoding.input_ids.len(),
+                );
+                ptr::copy(
+                    entry.encoding.token_type_ids.as_mut_ptr(),
+                    raw_token_type_ids.add(current_tokens),
+                    entry.encoding.token_type_ids.len(),
+                );
+                ptr::copy(
+                    entry.encoding.position_ids.as_mut_ptr(),
+                    raw_position_ids.add(current_tokens),
+                    entry.encoding.position_ids.len(),
+                );
+
+                current_tokens += entry_tokens;
+                metadata.push(entry.metadata);
+                cu_seq_lengths.push(current_tokens as u32);
 
                 if Some(metadata.len()) == max_batch_requests {
                     break;
                 }
             }
 
+            // Create final vectors from raw memory
             let input_ids =
                 Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
             let token_type_ids =
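The copy-into-raw-memory pattern above concatenates each entry's ids at offset `current_tokens` and finally adopts the buffer with `Vec::from_raw_parts`. Below is a minimal, self-contained sketch of that allocate/copy/adopt cycle; the body of `raw_u32_vec` is an assumption reconstructed from its call sites (the diff never shows its definition):

use std::ptr;

// Assumed shape of raw_u32_vec -- the diff only shows its call sites:
// allocate room for `capacity` u32s via Vec, then leak it as a raw pointer.
unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
    let mut v: Vec<u32> = Vec::with_capacity(capacity);
    let ptr = v.as_mut_ptr();
    std::mem::forget(v); // ownership is reclaimed by Vec::from_raw_parts below
    ptr
}

fn main() {
    let max_batch_tokens = 8;
    let encodings: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];

    let mut cu_seq_lengths = vec![0u32];
    let batch = unsafe {
        let raw = raw_u32_vec(max_batch_tokens);
        let mut current_tokens = 0;
        for ids in &encodings {
            // Append each sequence at its offset, like the ptr::copy calls above
            ptr::copy(ids.as_ptr(), raw.add(current_tokens), ids.len());
            current_tokens += ids.len();
            cu_seq_lengths.push(current_tokens as u32);
        }
        // Re-own the initialized prefix: len = tokens written, cap = allocation size
        Vec::from_raw_parts(raw, current_tokens, max_batch_tokens)
    };

    assert_eq!(batch, [1, 2, 3, 4, 5]);
    assert_eq!(cu_seq_lengths, [0, 3, 5]); // cumulative sequence boundaries
}

Nothing here bounds-checks: the surrounding loop has to guarantee the running token count never exceeds `max_batch_tokens` before each copy, which is what keeps the writes inside the allocation and the `unsafe` block sound.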