@@ -1,7 +1,9 @@
 use crate::infer::InferResponse;
 use crate::tokenization::Encoding;
+use std::alloc::{alloc, Layout};
 use std::cmp::max;
 use std::collections::VecDeque;
+use std::ptr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::{BackendError, Batch};
 use tokio::sync::oneshot;
@@ -23,8 +25,6 @@ pub struct Metadata {
     pub response_tx: oneshot::Sender<Result<InferResponse, BackendError>>,
     /// Span that will live as long as entry
     pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
-    pub temp_span: Option<Span>,
     /// Tokenization duration
     pub tokenization: Duration,
     /// Instant when this entry was queued
@@ -43,16 +43,22 @@ pub struct Queue {
 }
 
 impl Queue {
-    pub fn new(max_batch_tokens: usize, max_batch_requests: Option<usize>) -> Self {
+    pub fn new(
+        max_batch_tokens: usize,
+        max_batch_requests: Option<usize>,
+        max_concurrent_requests: usize,
+    ) -> Self {
         // Create channels
         let (queue_sender, queue_receiver) = flume::unbounded();
 
-        // Launch background queue task
-        tokio::spawn(queue_task(
-            max_batch_tokens,
-            max_batch_requests,
-            queue_receiver,
-        ));
+        tokio::task::spawn_blocking(move || {
+            queue_task(
+                max_batch_tokens,
+                max_batch_requests,
+                max_concurrent_requests,
+                queue_receiver,
+            )
+        });
 
         Self { queue_sender }
     }
@@ -89,16 +95,17 @@ impl Queue {
 }
 
 // Background task responsible for the queue state
-async fn queue_task(
+fn queue_task(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
+    max_concurrent_requests: usize,
     queue_receiver: flume::Receiver<QueueCommand>,
 ) {
-    let capacity = max_batch_requests.unwrap_or(512);
+    let capacity = max_batch_requests.unwrap_or(max_concurrent_requests);
 
-    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(512);
+    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(max_concurrent_requests);
 
-    while let Ok(cmd) = queue_receiver.recv_async().await {
+    while let Ok(cmd) = queue_receiver.recv() {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 let _span = span.entered();
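Aside: these two hunks move the queue loop off the async executor. `queue_task` becomes a plain synchronous function blocking on flume's `recv()`, so it is launched with `tokio::task::spawn_blocking` instead of `tokio::spawn`; a loop that blocks inside `tokio::spawn` would pin down a runtime worker thread. A minimal standalone sketch of the pattern, with a placeholder `Cmd` enum standing in for `QueueCommand`:

```rust
use tokio::task;

// Placeholder for the real `QueueCommand`.
enum Cmd {
    Ping,
}

#[tokio::main]
async fn main() {
    let (cmd_tx, cmd_rx) = flume::unbounded::<Cmd>();

    // The loop parks its own OS thread on the blocking `recv()`,
    // leaving the async runtime workers free.
    let handle = task::spawn_blocking(move || {
        while let Ok(cmd) = cmd_rx.recv() {
            match cmd {
                Cmd::Ping => println!("pong"),
            }
        }
    });

    cmd_tx.send(Cmd::Ping).unwrap();
    drop(cmd_tx); // closing the last sender ends the loop
    handle.await.unwrap();
}
```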
@@ -108,23 +115,23 @@ async fn queue_task(
             QueueCommand::NextBatch {
                 response_sender,
                 span,
-            } => {
+            } => unsafe {
                 let _span = span.entered();
 
-                let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty, tokens = tracing::field::Empty);
-                next_batch_span.follows_from(Span::current());
-
                 let mut metadata = Vec::with_capacity(capacity);
 
-                let mut input_ids = Vec::with_capacity(max_batch_tokens);
-                let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
-                let mut position_ids = Vec::with_capacity(max_batch_tokens);
+                let raw_input_ids = raw_u32_vec(max_batch_tokens);
+                let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
+                let raw_position_ids = raw_u32_vec(max_batch_tokens);
+
                 let mut cu_seq_lengths = Vec::with_capacity(capacity);
                 cu_seq_lengths.push(0);
 
                 let mut current_tokens = 0;
                 let mut max_length = 0;
 
+                let batch_time = Instant::now();
+
                 while let Some(mut entry) = entries.pop_front() {
                     // Filter entries where the response receiver was dropped (== entries where the request
                     // was dropped by the client)
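For orientation, the batch assembled below is one flat buffer per field plus cumulative offsets; with illustrative values (not from the diff):

```rust
// Two queued sequences of 3 and 5 tokens flatten into one batch:
//   input_ids      = [a0, a1, a2, b0, b1, b2, b3, b4]
//   cu_seq_lengths = [0, 3, 8]  // starts at the pushed 0; each entry
//                               // appends its cumulative token count
//   max_length     = 5          // longest single sequence
```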
@@ -141,37 +148,59 @@ async fn queue_task(
                     }
 
                     max_length = max(max_length, entry_tokens as u32);
-                    current_tokens += entry_tokens;
-
-                    // Create a new span to link the batch back to this entry
-                    let entry_batch_span = info_span!(parent: &entry.metadata.span, "infer");
-                    // Add relationships
-                    next_batch_span.follows_from(&entry_batch_span);
-                    entry_batch_span.follows_from(&next_batch_span);
 
-                    entry.metadata.batch_time = Some(Instant::now());
-                    entry.metadata.temp_span = Some(entry_batch_span);
-
-                    metadata.push(entry.metadata);
-                    input_ids.extend(entry.encoding.input_ids);
-                    token_type_ids.extend(entry.encoding.token_type_ids);
-                    position_ids.extend(entry.encoding.position_ids);
-                    cu_seq_lengths.push(current_tokens as u32);
+                    entry.metadata.batch_time = Some(batch_time);
+
+                    {
+                        let _span = info_span!("extend").entered();
+
+                        ptr::copy(
+                            entry.encoding.input_ids.as_mut_ptr(),
+                            raw_input_ids.add(current_tokens),
+                            entry.encoding.input_ids.len(),
+                        );
+                        ptr::copy(
+                            entry.encoding.token_type_ids.as_mut_ptr(),
+                            raw_token_type_ids.add(current_tokens),
+                            entry.encoding.token_type_ids.len(),
+                        );
+                        ptr::copy(
+                            entry.encoding.position_ids.as_mut_ptr(),
+                            raw_position_ids.add(current_tokens),
+                            entry.encoding.position_ids.len(),
+                        );
+
+                        // input_ids.extend_from_slice(entry.encoding.input_ids.as_slice());
+                        // token_type_ids.extend_from_slice(entry.encoding.token_type_ids.as_slice());
+                        // position_ids.extend_from_slice(entry.encoding.position_ids.as_slice());
+
+                        // for i in 0..entry.encoding.input_ids.len() {
+                        //     input_ids.push(entry.encoding.input_ids[i]);
+                        //     token_type_ids.push(entry.encoding.token_type_ids[i]);
+                        //     position_ids.push(entry.encoding.position_ids[i]);
+                        // }
+
+                        current_tokens += entry_tokens;
+                        metadata.push(entry.metadata);
+                        cu_seq_lengths.push(current_tokens as u32);
+                    }
 
                     if Some(metadata.len()) == max_batch_requests {
                         break;
                     }
                 }
 
+                let input_ids =
+                    Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
+                let token_type_ids =
+                    Vec::from_raw_parts(raw_token_type_ids, current_tokens, max_batch_tokens);
+                let position_ids =
+                    Vec::from_raw_parts(raw_position_ids, current_tokens, max_batch_tokens);
+
+                let batch_size = metadata.len();
                 let next_batch = if metadata.is_empty() {
                     None
                 } else {
-                    next_batch_span.record("batch_size", metadata.len() as u32);
-                    next_batch_span.record("tokens", current_tokens as u32);
-
-                    metrics::histogram!("te_batch_next_size", metadata.len() as f64);
-                    metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
-
                     Some((
                         metadata,
                         Batch {
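The raw-pointer path above trades `Vec::extend`'s bounds and growth checks for straight memcpys into a buffer sized once at `max_batch_tokens` (the commented-out `extend_from_slice` and push-loop variants appear to be the safe alternatives kept for comparison). Handing the buffer to `Vec::from_raw_parts` is only sound because the pointer came from the global allocator with layout `Layout::array::<u32>(max_batch_tokens)`, the first `current_tokens` elements were initialized by the copies, and `current_tokens <= max_batch_tokens` as long as the surrounding token-budget check holds. One gap: `alloc` signals failure with a null pointer, which `raw_u32_vec` never checks. A standalone sketch of the same pattern with that check added (hypothetical helper, not part of the diff):

```rust
use std::alloc::{alloc, handle_alloc_error, Layout};
use std::ptr;

/// Hypothetical helper mirroring the batching code: concatenate slices
/// into a freshly allocated buffer, then let a `Vec` own it.
unsafe fn concat_into_vec(chunks: &[&[u32]], capacity: usize) -> Vec<u32> {
    assert!(capacity > 0, "zero-sized allocs are undefined behavior");
    let layout = Layout::array::<u32>(capacity).unwrap();
    let buf = alloc(layout).cast::<u32>();
    if buf.is_null() {
        // `alloc` returns null on allocation failure.
        handle_alloc_error(layout);
    }
    let mut len = 0;
    for chunk in chunks {
        assert!(len + chunk.len() <= capacity, "stay within the allocation");
        // The regions cannot overlap, so the memcpy variant suffices
        // (the diff uses `ptr::copy`, the memmove variant, which is
        // equally correct here).
        ptr::copy_nonoverlapping(chunk.as_ptr(), buf.add(len), chunk.len());
        len += chunk.len();
    }
    // Sound because: the pointer came from the global allocator with a
    // `Layout::array::<u32>(capacity)` layout, `len <= capacity`, and
    // the first `len` elements are initialized.
    Vec::from_raw_parts(buf, len, capacity)
}

fn main() {
    let v = unsafe { concat_into_vec(&[&[1, 2, 3], &[4, 5]], 8) };
    assert_eq!(v, [1, 2, 3, 4, 5]);
}
```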
@@ -181,18 +210,25 @@ async fn queue_task(
                             cumulative_seq_lengths: cu_seq_lengths,
                             max_length,
                         },
-                        next_batch_span,
                     ))
                 };
 
                 let _ = response_sender.send(next_batch);
+
+                metrics::histogram!("te_batch_next_size", batch_size as f64);
+                metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
                 metrics::gauge!("te_queue_size", entries.len() as f64);
-            }
+            },
         }
     }
 }
 
-type NextBatch = (Vec<Metadata>, Batch, Span);
+unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
+    let layout = Layout::array::<u32>(capacity).unwrap();
+    alloc(layout).cast::<u32>()
+}
+
+type NextBatch = (Vec<Metadata>, Batch);
 
 #[derive(Debug)]
 enum QueueCommand {