
Commit fc716d5

fix: padding support in batch tokens (#93)
1 parent d3e5b5a commit fc716d5

File tree

10 files changed: +45 −6 lines changed

backends/candle/src/lib.rs
backends/candle/src/models.rs
backends/candle/src/models/bert.rs
backends/candle/src/models/flash_bert.rs
backends/candle/src/models/jina.rs
backends/core/src/lib.rs
backends/python/src/lib.rs
backends/src/lib.rs
core/src/queue.rs
router/src/main.rs

backends/candle/src/lib.rs

Lines changed: 4 additions & 0 deletions
@@ -126,6 +126,10 @@ impl Backend for CandleBackend {
         Ok(())
     }
 
+    fn is_padded(&self) -> bool {
+        self.model.is_padded()
+    }
+
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
         let results = self.model.embed(batch).e()?;
         let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;

backends/candle/src/models.rs

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ mod jina;
 pub use flash_bert::FlashBertModel;
 
 pub(crate) trait Model {
+    fn is_padded(&self) -> bool;
+
     fn embed(&self, _batch: Batch) -> Result<Tensor> {
         candle::bail!("`embed` is not implemented for this model");
     }
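
A note on the new method's semantics, as I read them from the implementations below (the doc comment is mine, not part of the commit):

pub(crate) trait Model {
    /// `true` when the model runs on dense [batch_size, max_length] batches,
    /// so every sequence is padded to the longest one and padding tokens cost
    /// memory and compute; `false` when sequences are packed (as in the
    /// flash-attention model below) and only real tokens are processed.
    fn is_padded(&self) -> bool;

    // ... remaining methods unchanged ...
}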

backends/candle/src/models/bert.rs

Lines changed: 4 additions & 0 deletions
@@ -618,6 +618,10 @@ impl BertModel {
 }
 
 impl Model for BertModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
+
     fn embed(&self, batch: Batch) -> Result<Tensor> {
         self.forward(batch)
     }

backends/candle/src/models/flash_bert.rs

Lines changed: 3 additions & 0 deletions
@@ -447,6 +447,9 @@ impl FlashBertModel {
 }
 
 impl Model for FlashBertModel {
+    fn is_padded(&self) -> bool {
+        false
+    }
     fn embed(&self, batch: Batch) -> Result<Tensor> {
         self.forward(batch)
     }

backends/candle/src/models/jina.rs

Lines changed: 3 additions & 0 deletions
@@ -595,6 +595,9 @@ impl JinaBertModel {
 }
 
 impl Model for JinaBertModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
     fn embed(&self, batch: Batch) -> Result<Tensor> {
         self.forward(batch)
     }

backends/core/src/lib.rs

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ pub trait Backend {
         None
     }
 
+    fn is_padded(&self) -> bool;
+
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError>;
 
     fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError>;

backends/python/src/lib.rs

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,10 @@ impl Backend for PythonBackend {
         Ok(())
     }
 
+    fn is_padded(&self) -> bool {
+        false
+    }
+
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
         let results = self
             .tokio_runtime

backends/src/lib.rs

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@ pub struct Backend {
     backend_sender: mpsc::UnboundedSender<BackendCommand>,
     /// Health status
     health_receiver: watch::Receiver<bool>,
+    pub padded_model: bool,
     pub max_batch_size: Option<usize>,
     pub model_type: ModelType,
 }
@@ -42,6 +43,7 @@ impl Backend {
             uds_path,
             otlp_endpoint,
         )?;
+        let padded_model = backend.is_padded();
        let max_batch_size = backend.max_batch_size();
 
        let (health_sender, health_receiver) = watch::channel(false);
@@ -53,6 +55,7 @@
        Ok(Self {
            backend_sender,
            health_receiver,
+            padded_model,
            max_batch_size,
            model_type,
        })

core/src/queue.rs

Lines changed: 11 additions & 1 deletion
@@ -40,6 +40,7 @@ pub struct Queue {
 
 impl Queue {
     pub fn new(
+        padded_model: bool,
         max_batch_tokens: usize,
         max_batch_requests: Option<usize>,
         max_concurrent_requests: usize,
@@ -50,6 +51,7 @@ impl Queue {
         // Launch background queue task
         tokio::task::spawn_blocking(move || {
             queue_blocking_task(
+                padded_model,
                 max_batch_tokens,
                 max_batch_requests,
                 max_concurrent_requests,
@@ -93,6 +95,7 @@ impl Queue {
 
 // Background task responsible of the queue state
 fn queue_blocking_task(
+    padded_model: bool,
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
     max_concurrent_requests: usize,
@@ -136,7 +139,14 @@ fn queue_blocking_task(
 
                 let entry_tokens = entry.encoding.input_ids.len();
 
-                if current_tokens + entry_tokens > max_batch_tokens {
+                let total_tokens = if padded_model {
+                    (max(max_length, entry_tokens as u32) * (metadata.len() + 1) as u32)
+                        as usize
+                } else {
+                    current_tokens + entry_tokens
+                };
+
+                if total_tokens > max_batch_tokens {
                     entries.push_front(entry);
                     break;
                 }
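
To make the new accounting concrete, here is a small self-contained sketch of the budget check above (my own illustration with made-up numbers, not code from the repository). For a padded model the whole batch is padded to its longest sequence, so the check multiplies that maximum by the prospective batch size instead of summing the real token counts.

use std::cmp::max;

// Sketch of the check in `queue_blocking_task`: `batch_lengths` stands in for
// the token counts of the entries already selected for the batch.
fn would_exceed(
    padded_model: bool,
    batch_lengths: &[usize],
    entry_tokens: usize,
    max_batch_tokens: usize,
) -> bool {
    let current_tokens: usize = batch_lengths.iter().sum();
    let max_length = batch_lengths.iter().copied().max().unwrap_or(0);
    let total_tokens = if padded_model {
        // Dense batches are padded to the longest sequence in the batch.
        max(max_length, entry_tokens) * (batch_lengths.len() + 1)
    } else {
        // Packed batches only pay for the tokens that are actually there.
        current_tokens + entry_tokens
    };
    total_tokens > max_batch_tokens
}

fn main() {
    // Batch already holds 512- and 128-token sequences; the candidate has 1024.
    // Padded cost: max(512, 1024) * 3 = 3072. Unpadded cost: 512 + 128 + 1024 = 1664.
    assert!(would_exceed(true, &[512, 128], 1024, 2048));
    assert!(!would_exceed(false, &[512, 128], 1024, 2048));
}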

router/src/main.rs

Lines changed: 9 additions & 5 deletions
@@ -331,14 +331,18 @@ async fn main() -> Result<()> {
         .await
         .context("Model backend is not healthy")?;
 
-    let max_batch_requests = backend.max_batch_size.map(|s| {
-        tracing::warn!("Backend does not support a batch size > {s}");
-        tracing::warn!("forcing `max_batch_requests={s}`");
-        s
-    });
+    let max_batch_requests = backend
+        .max_batch_size
+        .map(|s| {
+            tracing::warn!("Backend does not support a batch size > {s}");
+            tracing::warn!("forcing `max_batch_requests={s}`");
+            s
+        })
+        .or(args.max_batch_requests);
 
     // Queue logic
     let queue = Queue::new(
+        backend.padded_model,
         args.max_batch_tokens,
         max_batch_requests,
         args.max_concurrent_requests,
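
Besides wiring `backend.padded_model` into `Queue::new`, this hunk also lets `args.max_batch_requests` take effect when the backend does not impose a cap of its own. A minimal sketch of that precedence (a hypothetical helper for illustration, not code from the repository):

// Hypothetical helper illustrating the precedence above: a backend-imposed cap
// wins; otherwise the user-supplied value, if any, applies; otherwise there is
// no request cap and batches are limited only by the token budget.
fn effective_max_batch_requests(
    backend_cap: Option<usize>,
    user_cap: Option<usize>,
) -> Option<usize> {
    backend_cap.or(user_cap)
}

fn main() {
    assert_eq!(effective_max_batch_requests(Some(4), Some(32)), Some(4));
    assert_eq!(effective_max_batch_requests(None, Some(32)), Some(32));
    assert_eq!(effective_max_batch_requests(None, None), None);
}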
