@@ -1,11 +1,13 @@
 use crate::infer::InferResponse;
 use crate::tokenization::Encoding;
+use std::alloc::{alloc, Layout};
 use std::cmp::max;
 use std::collections::VecDeque;
+use std::ptr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::{BackendError, Batch};
 use tokio::sync::oneshot;
-use tracing::{info_span, instrument, Span};
+use tracing::{instrument, Span};
 
 /// Queue entry
 #[derive(Debug)]
@@ -23,8 +25,6 @@ pub struct Metadata {
     pub response_tx: oneshot::Sender<Result<InferResponse, BackendError>>,
     /// Span that will live as long as entry
     pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
-    pub temp_span: Option<Span>,
     /// Tokenization duration
     pub tokenization: Duration,
     /// Instant when this entry was queued
@@ -43,16 +43,23 @@ pub struct Queue {
 }
 
 impl Queue {
-    pub fn new(max_batch_tokens: usize, max_batch_requests: Option<usize>) -> Self {
+    pub fn new(
+        max_batch_tokens: usize,
+        max_batch_requests: Option<usize>,
+        max_concurrent_requests: usize,
+    ) -> Self {
         // Create channels
         let (queue_sender, queue_receiver) = flume::unbounded();
 
         // Launch background queue task
-        tokio::spawn(queue_task(
-            max_batch_tokens,
-            max_batch_requests,
-            queue_receiver,
-        ));
+        tokio::task::spawn_blocking(move || {
+            queue_blocking_task(
+                max_batch_tokens,
+                max_batch_requests,
+                max_concurrent_requests,
+                queue_receiver,
+            )
+        });
 
         Self { queue_sender }
     }
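The hunk above moves the queue loop off the async executor: `Queue::new` now hands a plain closure to `tokio::task::spawn_blocking`, and the loop drains the `flume` channel with the synchronous `recv()` instead of `recv_async().await`. Below is a minimal, self-contained sketch of that pattern; the `Command` type and ping/pong flow are hypothetical stand-ins for `QueueCommand`, not code from this crate:

```rust
use tokio::sync::oneshot;

// Hypothetical stand-in for this crate's QueueCommand.
enum Command {
    Ping(oneshot::Sender<&'static str>),
}

#[tokio::main]
async fn main() {
    let (cmd_tx, cmd_rx) = flume::unbounded::<Command>();

    // The synchronous `recv()` parks its thread while waiting, so the
    // loop runs on Tokio's blocking pool instead of an executor worker.
    tokio::task::spawn_blocking(move || {
        while let Ok(cmd) = cmd_rx.recv() {
            match cmd {
                Command::Ping(reply) => {
                    let _ = reply.send("pong");
                }
            }
        }
    });

    // Async callers keep using the cheap, non-blocking sender side.
    let (reply_tx, reply_rx) = oneshot::channel();
    cmd_tx.send(Command::Ping(reply_tx)).unwrap();
    assert_eq!(reply_rx.await.unwrap(), "pong");
}
```

The dedicated thread is what leaves the batching arm below free to do CPU-heavy, blocking work without starving other tasks.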
@@ -89,16 +96,17 @@ impl Queue {
 }
 
 // Background task responsible for the queue state
-async fn queue_task(
+fn queue_blocking_task(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
+    max_concurrent_requests: usize,
     queue_receiver: flume::Receiver<QueueCommand>,
 ) {
-    let capacity = max_batch_requests.unwrap_or(512);
+    let capacity = max_batch_requests.unwrap_or(max_concurrent_requests);
 
-    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(512);
+    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(max_concurrent_requests);
 
-    while let Ok(cmd) = queue_receiver.recv_async().await {
+    while let Ok(cmd) = queue_receiver.recv() {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 let _span = span.entered();
@@ -108,23 +116,23 @@ async fn queue_task(
             QueueCommand::NextBatch {
                 response_sender,
                 span,
-            } => {
+            } => unsafe {
                 let _span = span.entered();
 
-                let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty, tokens = tracing::field::Empty);
-                next_batch_span.follows_from(Span::current());
+                // Allocate raw memory
+                let raw_input_ids = raw_u32_vec(max_batch_tokens);
+                let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
+                let raw_position_ids = raw_u32_vec(max_batch_tokens);
 
                 let mut metadata = Vec::with_capacity(capacity);
-
-                let mut input_ids = Vec::with_capacity(max_batch_tokens);
-                let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
-                let mut position_ids = Vec::with_capacity(max_batch_tokens);
                 let mut cu_seq_lengths = Vec::with_capacity(capacity);
                 cu_seq_lengths.push(0);
 
                 let mut current_tokens = 0;
                 let mut max_length = 0;
 
+                let batch_time = Instant::now();
+
                 while let Some(mut entry) = entries.pop_front() {
                     // Filter entries where the response receiver was dropped (== entries where the request
                     // was dropped by the client)
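This arm now runs as an `unsafe` block because it works with raw allocations: `raw_u32_vec` (defined near the end of the diff) grabs one `max_batch_tokens`-sized buffer per tensor up front. A hedged sketch of that helper's contract follows; the null check is an addition of this sketch that the diff's version omits, and it assumes `capacity > 0`, since `alloc` with a zero-size layout is undefined behavior:

```rust
use std::alloc::{alloc, handle_alloc_error, Layout};

// Same shape as the diff's `raw_u32_vec`, plus an explicit
// out-of-memory check: `alloc` may return null, and offsetting into
// or writing through a null allocation would be undefined behavior.
unsafe fn checked_raw_u32_vec(capacity: usize) -> *mut u32 {
    // Layout::array only errors if the byte size overflows isize.
    let layout = Layout::array::<u32>(capacity).unwrap();
    let ptr = alloc(layout);
    if ptr.is_null() {
        handle_alloc_error(layout);
    }
    ptr.cast::<u32>()
}
```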
@@ -141,37 +149,47 @@ async fn queue_task(
                     }
 
                     max_length = max(max_length, entry_tokens as u32);
-                    current_tokens += entry_tokens;
 
-                    // Create a new span to link the batch back to this entry
-                    let entry_batch_span = info_span!(parent: &entry.metadata.span, "infer");
-                    // Add relationships
-                    next_batch_span.follows_from(&entry_batch_span);
-                    entry_batch_span.follows_from(&next_batch_span);
-
-                    entry.metadata.batch_time = Some(Instant::now());
-                    entry.metadata.temp_span = Some(entry_batch_span);
+                    entry.metadata.batch_time = Some(batch_time);
+
+                    // Copy memory to the correct spot in the raw vectors
+                    ptr::copy(
+                        entry.encoding.input_ids.as_mut_ptr(),
+                        raw_input_ids.add(current_tokens),
+                        entry.encoding.input_ids.len(),
+                    );
+                    ptr::copy(
+                        entry.encoding.token_type_ids.as_mut_ptr(),
+                        raw_token_type_ids.add(current_tokens),
+                        entry.encoding.token_type_ids.len(),
+                    );
+                    ptr::copy(
+                        entry.encoding.position_ids.as_mut_ptr(),
+                        raw_position_ids.add(current_tokens),
+                        entry.encoding.position_ids.len(),
+                    );
 
+                    current_tokens += entry_tokens;
                     metadata.push(entry.metadata);
-                    input_ids.extend(entry.encoding.input_ids);
-                    token_type_ids.extend(entry.encoding.token_type_ids);
-                    position_ids.extend(entry.encoding.position_ids);
                     cu_seq_lengths.push(current_tokens as u32);
 
                     if Some(metadata.len()) == max_batch_requests {
                         break;
                     }
                 }
 
+                // Create final vectors from raw memory
+                let input_ids =
+                    Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
+                let token_type_ids =
+                    Vec::from_raw_parts(raw_token_type_ids, current_tokens, max_batch_tokens);
+                let position_ids =
+                    Vec::from_raw_parts(raw_position_ids, current_tokens, max_batch_tokens);
+
+                let batch_size = metadata.len();
                 let next_batch = if metadata.is_empty() {
                     None
                 } else {
-                    next_batch_span.record("batch_size", metadata.len() as u32);
-                    next_batch_span.record("tokens", current_tokens as u32);
-
-                    metrics::histogram!("te_batch_next_size", metadata.len() as f64);
-                    metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
-
                     Some((
                         metadata,
                         Batch {
@@ -181,18 +199,25 @@ async fn queue_task(
                             cumulative_seq_lengths: cu_seq_lengths,
                             max_length,
                         },
-                        next_batch_span,
                     ))
                 };
 
                 let _ = response_sender.send(next_batch);
+
+                metrics::histogram!("te_batch_next_size", batch_size as f64);
+                metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
                 metrics::gauge!("te_queue_size", entries.len() as f64);
-            }
+            },
         }
     }
 }
 
-type NextBatch = (Vec<Metadata>, Batch, Span);
+unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
+    let layout = Layout::array::<u32>(capacity).unwrap();
+    alloc(layout).cast::<u32>()
+}
+
+type NextBatch = (Vec<Metadata>, Batch);
 
 #[derive(Debug)]
 enum QueueCommand {
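Taken together, the batching arm follows an alloc → `ptr::copy` → `Vec::from_raw_parts` lifecycle: each entry's token ids are copied at the running offset `current_tokens`, and the raw buffers are then handed back to owning `Vec`s whose length is the number of tokens actually written and whose capacity is the full allocation, so a normal `Vec` drop reclaims the memory. A condensed, illustrative sketch of the same flow (`concat_into_raw` is hypothetical, not part of this crate):

```rust
use std::alloc::{alloc, Layout};
use std::ptr;

// Concatenates u32 slices into one preallocated buffer, mirroring the
// diff's alloc -> ptr::copy -> Vec::from_raw_parts flow.
unsafe fn concat_into_raw(chunks: &[&[u32]], max_tokens: usize) -> Vec<u32> {
    let layout = Layout::array::<u32>(max_tokens).unwrap();
    let raw = alloc(layout).cast::<u32>();
    assert!(!raw.is_null(), "allocation failed");

    let mut written = 0;
    for chunk in chunks {
        // The caller must guarantee the chunks fit in the allocation,
        // just as the queue loop caps batches at max_batch_tokens.
        debug_assert!(written + chunk.len() <= max_tokens);
        // Sources never overlap the destination, so
        // `ptr::copy_nonoverlapping` would also be valid here;
        // `ptr::copy` (memmove semantics) is what the diff uses.
        ptr::copy(chunk.as_ptr(), raw.add(written), chunk.len());
        written += chunk.len();
    }

    // Safety: `raw` holds room for `max_tokens` u32s from the global
    // allocator, and the first `written <= max_tokens` are initialized.
    Vec::from_raw_parts(raw, written, max_tokens)
}

fn main() {
    let out = unsafe { concat_into_raw(&[&[1, 2], &[3, 4, 5]], 8) };
    assert_eq!(out, vec![1, 2, 3, 4, 5]);
    assert_eq!(out.capacity(), 8);
}
```

One consequence worth noting: the final vectors report `max_batch_tokens` as their capacity even for small batches, so each batch holds the full allocation until it is dropped.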