10 files changed: +45 −6 lines changed
```diff
@@ -126,6 +126,10 @@ impl Backend for CandleBackend
         Ok(())
     }
 
+    fn is_padded(&self) -> bool {
+        self.model.is_padded()
+    }
+
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
         let results = self.model.embed(batch).e()?;
         let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
```
```diff
@@ -19,6 +19,8 @@ mod jina;
 pub use flash_bert::FlashBertModel;
 
 pub(crate) trait Model {
+    fn is_padded(&self) -> bool;
+
     fn embed(&self, _batch: Batch) -> Result<Tensor> {
         candle::bail!("`embed` is not implemented for this model");
     }
```
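Note that `is_padded` has no default implementation, so every `Model` must declare explicitly whether its forward pass expects padded batches. A minimal sketch of what an implementer writes (the `MyUnpaddedModel` type is hypothetical, not part of this PR):

```rust
// Hypothetical model type, for illustration only; this would live
// alongside the other models in the same crate as the Model trait.
struct MyUnpaddedModel;

impl Model for MyUnpaddedModel {
    // This model consumes variable-length (unpadded) batches, so the
    // queue should budget by real token count, not max_length * rows.
    fn is_padded(&self) -> bool {
        false
    }
    // `embed` keeps the trait's default here, which bails at runtime.
}
```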
```diff
@@ -618,6 +618,10 @@ impl BertModel {
 }
 
 impl Model for BertModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
+
     fn embed(&self, batch: Batch) -> Result<Tensor> {
         self.forward(batch)
     }
```
```diff
@@ -447,6 +447,9 @@ impl FlashBertModel {
 }
 
 impl Model for FlashBertModel {
+    fn is_padded(&self) -> bool {
+        false
+    }
     fn embed(&self, batch: Batch) -> Result<Tensor> {
         self.forward(batch)
     }
```
```diff
@@ -595,6 +595,9 @@ impl JinaBertModel {
 }
 
 impl Model for JinaBertModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
     fn embed(&self, batch: Batch) -> Result<Tensor> {
         self.forward(batch)
     }
```
```diff
@@ -20,6 +20,8 @@ pub trait Backend {
         None
     }
 
+    fn is_padded(&self) -> bool;
+
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError>;
 
     fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError>;
```
```diff
@@ -65,6 +65,10 @@ impl Backend for PythonBackend {
         Ok(())
     }
 
+    fn is_padded(&self) -> bool {
+        false
+    }
+
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
         let results = self
             .tokio_runtime
```
```diff
@@ -21,6 +21,7 @@ pub struct Backend {
     backend_sender: mpsc::UnboundedSender<BackendCommand>,
     /// Health status
     health_receiver: watch::Receiver<bool>,
+    pub padded_model: bool,
     pub max_batch_size: Option<usize>,
     pub model_type: ModelType,
 }
@@ -42,6 +43,7 @@ impl Backend {
             uds_path,
             otlp_endpoint,
         )?;
+        let padded_model = backend.is_padded();
         let max_batch_size = backend.max_batch_size();
 
         let (health_sender, health_receiver) = watch::channel(false);
@@ -53,6 +55,7 @@ impl Backend {
         Ok(Self {
             backend_sender,
             health_receiver,
+            padded_model,
             max_batch_size,
             model_type,
         })
```
```diff
@@ -40,6 +40,7 @@ pub struct Queue {
 
 impl Queue {
     pub fn new(
+        padded_model: bool,
         max_batch_tokens: usize,
         max_batch_requests: Option<usize>,
         max_concurrent_requests: usize,
@@ -50,6 +51,7 @@ impl Queue {
         // Launch background queue task
         tokio::task::spawn_blocking(move || {
             queue_blocking_task(
+                padded_model,
                 max_batch_tokens,
                 max_batch_requests,
                 max_concurrent_requests,
@@ -93,6 +95,7 @@ impl Queue {
 
 // Background task responsible of the queue state
 fn queue_blocking_task(
+    padded_model: bool,
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
     max_concurrent_requests: usize,
@@ -136,7 +139,14 @@ fn queue_blocking_task(
 
             let entry_tokens = entry.encoding.input_ids.len();
 
-            if current_tokens + entry_tokens > max_batch_tokens {
+            let total_tokens = if padded_model {
+                (max(max_length, entry_tokens as u32) * (metadata.len() + 1) as u32)
+                    as usize
+            } else {
+                current_tokens + entry_tokens
+            };
+
+            if total_tokens > max_batch_tokens {
                 entries.push_front(entry);
                 break;
             }
```
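The hunk above is the behavioral core of the PR: for a padded model, admitting one more entry re-costs the entire batch at the new maximum sequence length, because every row gets padded up to it, while unpadded (e.g. flash-attention) backends still pay only for real tokens. A self-contained sketch of the two accounting modes (the standalone function is illustrative; in the PR this logic lives inline in `queue_blocking_task`, where `max_length` and `metadata` track the current batch):

```rust
use std::cmp::max;

/// Token cost of admitting one more entry into the in-flight batch.
/// `current_tokens`: running sum of real tokens already batched (unpadded case).
/// `max_length`: longest sequence currently in the batch (padded case).
/// `batch_size`: number of entries already in the batch.
fn batch_cost(
    padded_model: bool,
    current_tokens: usize,
    max_length: u32,
    batch_size: usize,
    entry_tokens: usize,
) -> usize {
    if padded_model {
        // Every row is padded to the longest sequence, so the whole
        // batch is re-costed at the new maximum length.
        (max(max_length, entry_tokens as u32) as usize) * (batch_size + 1)
    } else {
        // Unpadded backends only pay for the tokens that actually exist.
        current_tokens + entry_tokens
    }
}

fn main() {
    // Batch of 3 entries, longest 512 tokens, 900 real tokens so far;
    // now adding a 128-token entry:
    assert_eq!(batch_cost(true, 900, 512, 3, 128), 512 * 4); // 2048 padded tokens
    assert_eq!(batch_cost(false, 900, 512, 3, 128), 1028); // 900 + 128 real tokens
}
```

Without this change, a padded backend could blow past `max_batch_tokens` in practice, since the old check counted only real tokens while the backend actually allocates `max_length * batch_size`.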
```diff
@@ -331,14 +331,18 @@ async fn main() -> Result<()> {
         .await
         .context("Model backend is not healthy")?;
 
-    let max_batch_requests = backend.max_batch_size.map(|s| {
-        tracing::warn!("Backend does not support a batch size > {s}");
-        tracing::warn!("forcing `max_batch_requests={s}`");
-        s
-    });
+    let max_batch_requests = backend
+        .max_batch_size
+        .map(|s| {
+            tracing::warn!("Backend does not support a batch size > {s}");
+            tracing::warn!("forcing `max_batch_requests={s}`");
+            s
+        })
+        .or(args.max_batch_requests);
 
     // Queue logic
     let queue = Queue::new(
+        backend.padded_model,
        args.max_batch_tokens,
        max_batch_requests,
        args.max_concurrent_requests,
```
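The other change in this hunk: `max_batch_requests` now falls back to the CLI flag instead of becoming `None` when the backend imposes no cap, with `Option::or` giving the backend-enforced limit precedence. A quick illustration of that precedence (the values are made up):

```rust
fn main() {
    // A backend-imposed cap wins over the CLI flag.
    let backend_cap: Option<usize> = Some(4);
    let cli_flag: Option<usize> = Some(32);
    assert_eq!(backend_cap.or(cli_flag), Some(4));

    // Without a backend cap, the CLI flag is used.
    let no_cap: Option<usize> = None;
    assert_eq!(no_cap.or(cli_flag), Some(32));
}
```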