
Commit b90ce0e

feat: prefetch batches
1 parent c7010dd commit b90ce0e

8 files changed: 220 additions & 226 deletions

Cargo.lock

Lines changed: 110 additions & 165 deletions
Some generated files are not rendered by default.

backends/Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ authors.workspace = true
 homepage.workspace = true

 [dependencies]
-flume = "^0.10"
+flume = "^0.11"
 clap = { version = "4.1.4", features = ["derive"], optional = true }
 text-embeddings-backend-core = { path = "core" }
 text-embeddings-backend-python = { path = "python", optional = true }

backends/candle/Cargo.toml

Lines changed: 2 additions & 3 deletions

@@ -16,14 +16,13 @@ candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn
 candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "07e1a5490211e25ed0d096a2b21d3c607666eaae", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
 lazy_static = "^1.4"
-flume = "^0.10"
 text-embeddings-backend-core = { path = "../core" }
 tracing = "^0.1"
-safetensors = "^0.3"
+safetensors = "^0.4"
 thiserror = "^1.0"
 serde = { version = "^1.0", features = ["serde_derive"] }
 serde_json = "^1.0"
-memmap2 = "^0.7"
+memmap2 = "^0.9"

 [build-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }

core/Cargo.toml

Lines changed: 2 additions & 2 deletions

@@ -6,11 +6,11 @@ authors.workspace = true
 homepage.workspace = true

 [dependencies]
-flume = "^0.10"
+flume = "^0.11"
 hf-hub = { version = "^0.3.0", features = ["tokio"] }
 metrics = "^0.21"
 text-embeddings-backend = { path = "../backends" }
 thiserror = "^1.0"
-tokenizers = { version = "^0.13", default-features=false, features=["onig"] }
+tokenizers = { version = "^0.14", default-features=false, features=["onig"] }
 tracing = "^0.1"
 tokio = { version = "^1.25", features = ["rt", "rt-multi-thread", "parking_lot", "sync"] }

core/src/infer.rs

Lines changed: 12 additions & 3 deletions

@@ -29,6 +29,12 @@ impl Infer {
     ) -> Self {
         let notify_batching_task = Arc::new(Notify::new());

+        // Create two batching tasks to prefetch batches
+        tokio::spawn(batching_task(
+            backend.clone(),
+            queue.clone(),
+            notify_batching_task.clone(),
+        ));
         tokio::spawn(batching_task(
             backend.clone(),
             queue.clone(),
@@ -86,7 +92,6 @@ impl Infer {
             metadata: Metadata {
                 response_tx,
                 span: Span::current(),
-                temp_span: None,
                 tokenization: start_time.elapsed(),
                 queue_time: Instant::now(),
                 batch_time: None,
@@ -133,12 +138,16 @@
     }
 }

+#[instrument(skip_all)]
 async fn batching_task(backend: Backend, queue: Queue, notify: Arc<Notify>) {
     loop {
         notify.notified().await;

         while let Some(batch) = queue.next_batch().await {
-            match backend.embed(batch.1).await {
+            let results = backend.embed(batch.1).await;
+
+            // Handle sending responses in another thread to not starve the model
+            tokio::task::spawn_blocking(move || match results {
                 Ok(embeddings) => {
                     batch.0.into_iter().zip(embeddings).for_each(|(m, e)| {
                         let _ = m.response_tx.send(Ok(InferResponse {
@@ -160,7 +169,7 @@ async fn batching_task(backend: Backend, queue: Queue, notify: Arc<Notify>) {
                     let _ = m.response_tx.send(Err(err.clone()));
                 });
             }
-            }
+            });
         }
     }
 }
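
The two spawns in Infer::new above are the core of the prefetching change: both batching tasks wait on the same Notify and pull from the same Queue, so one task can assemble and submit the next batch while the other is still waiting on the backend, and responses are handed off to a blocking thread so the embedding loop is not starved. Below is a minimal standalone sketch of that pattern; the flume channel mirrors the queue crate already used by the repo, while the batch type and the sleep standing in for backend.embed are illustrative assumptions, not the project's API.

use tokio::time::{sleep, Duration};

// Each worker drains batches from a shared queue; running two of them lets the
// next batch be picked up while the previous one is still "embedding".
async fn batching_task(id: usize, queue: flume::Receiver<Vec<u32>>) {
    while let Ok(batch) = queue.recv_async().await {
        // Stand-in for `backend.embed(batch).await`.
        sleep(Duration::from_millis(50)).await;

        // Send results from a blocking thread so this loop can immediately
        // pull the next batch, mirroring the `spawn_blocking` in infer.rs.
        tokio::task::spawn_blocking(move || {
            println!("worker {id} finished a batch of {} ids", batch.len());
        });
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = flume::unbounded();

    // Two identical workers over the same queue, as in `Infer::new`.
    for id in 0..2 {
        tokio::spawn(batching_task(id, rx.clone()));
    }

    for batch in [vec![1u32, 2, 3], vec![4, 5], vec![6]] {
        tx.send(batch).unwrap();
    }
    drop(tx);

    // Give the toy workers time to drain the queue before exiting.
    sleep(Duration::from_millis(300)).await;
}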

core/src/queue.rs

Lines changed: 80 additions & 44 deletions

@@ -1,7 +1,9 @@
 use crate::infer::InferResponse;
 use crate::tokenization::Encoding;
+use std::alloc::{alloc, Layout};
 use std::cmp::max;
 use std::collections::VecDeque;
+use std::ptr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::{BackendError, Batch};
 use tokio::sync::oneshot;
@@ -23,8 +25,6 @@ pub struct Metadata {
     pub response_tx: oneshot::Sender<Result<InferResponse, BackendError>>,
     /// Span that will live as long as entry
     pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
-    pub temp_span: Option<Span>,
     /// Tokenization duration
     pub tokenization: Duration,
     /// Instant when this entry was queued
@@ -43,16 +43,22 @@ pub struct Queue {
 }

 impl Queue {
-    pub fn new(max_batch_tokens: usize, max_batch_requests: Option<usize>) -> Self {
+    pub fn new(
+        max_batch_tokens: usize,
+        max_batch_requests: Option<usize>,
+        max_concurrent_requests: usize,
+    ) -> Self {
         // Create channels
         let (queue_sender, queue_receiver) = flume::unbounded();

-        // Launch background queue task
-        tokio::spawn(queue_task(
-            max_batch_tokens,
-            max_batch_requests,
-            queue_receiver,
-        ));
+        tokio::task::spawn_blocking(move || {
+            queue_task(
+                max_batch_tokens,
+                max_batch_requests,
+                max_concurrent_requests,
+                queue_receiver,
+            )
+        });

         Self { queue_sender }
     }
@@ -89,16 +95,17 @@ impl Queue {
     }
 }

 // Background task responsible of the queue state
-async fn queue_task(
+fn queue_task(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
+    max_concurrent_requests: usize,
     queue_receiver: flume::Receiver<QueueCommand>,
 ) {
-    let capacity = max_batch_requests.unwrap_or(512);
+    let capacity = max_batch_requests.unwrap_or(max_concurrent_requests);

-    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(512);
+    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(max_concurrent_requests);

-    while let Ok(cmd) = queue_receiver.recv_async().await {
+    while let Ok(cmd) = queue_receiver.recv() {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 let _span = span.entered();
@@ -108,23 +115,23 @@ async fn queue_task(
             QueueCommand::NextBatch {
                 response_sender,
                 span,
-            } => {
+            } => unsafe {
                 let _span = span.entered();

-                let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty, tokens = tracing::field::Empty);
-                next_batch_span.follows_from(Span::current());
-
                 let mut metadata = Vec::with_capacity(capacity);

-                let mut input_ids = Vec::with_capacity(max_batch_tokens);
-                let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
-                let mut position_ids = Vec::with_capacity(max_batch_tokens);
+                let raw_input_ids = raw_u32_vec(max_batch_tokens);
+                let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
+                let raw_position_ids = raw_u32_vec(max_batch_tokens);
+
                 let mut cu_seq_lengths = Vec::with_capacity(capacity);
                 cu_seq_lengths.push(0);

                 let mut current_tokens = 0;
                 let mut max_length = 0;

+                let batch_time = Instant::now();
+
                 while let Some(mut entry) = entries.pop_front() {
                     // Filter entries where the response receiver was dropped (== entries where the request
                     // was dropped by the client)
@@ -141,37 +148,59 @@ async fn queue_task(
                     }

                     max_length = max(max_length, entry_tokens as u32);
-                    current_tokens += entry_tokens;
-
-                    // Create a new span to link the batch back to this entry
-                    let entry_batch_span = info_span!(parent: &entry.metadata.span, "infer");
-                    // Add relationships
-                    next_batch_span.follows_from(&entry_batch_span);
-                    entry_batch_span.follows_from(&next_batch_span);

-                    entry.metadata.batch_time = Some(Instant::now());
-                    entry.metadata.temp_span = Some(entry_batch_span);
-
-                    metadata.push(entry.metadata);
-                    input_ids.extend(entry.encoding.input_ids);
-                    token_type_ids.extend(entry.encoding.token_type_ids);
-                    position_ids.extend(entry.encoding.position_ids);
-                    cu_seq_lengths.push(current_tokens as u32);
+                    entry.metadata.batch_time = Some(batch_time);
+
+                    {
+                        let _span = info_span!("extend").entered();
+
+                        ptr::copy(
+                            entry.encoding.input_ids.as_mut_ptr(),
+                            raw_input_ids.add(current_tokens),
+                            entry.encoding.input_ids.len(),
+                        );
+                        ptr::copy(
+                            entry.encoding.token_type_ids.as_mut_ptr(),
+                            raw_token_type_ids.add(current_tokens),
+                            entry.encoding.token_type_ids.len(),
+                        );
+                        ptr::copy(
+                            entry.encoding.position_ids.as_mut_ptr(),
+                            raw_position_ids.add(current_tokens),
+                            entry.encoding.position_ids.len(),
+                        );
+
+                        // input_ids.extend_from_slice(entry.encoding.input_ids.as_slice());
+                        // token_type_ids.extend_from_slice(entry.encoding.token_type_ids.as_slice());
+                        // position_ids.extend_from_slice(entry.encoding.position_ids.as_slice());
+
+                        // for i in 0..entry.encoding.input_ids.len() {
+                        //     input_ids.push(entry.encoding.input_ids[i]);
+                        //     token_type_ids.push(entry.encoding.token_type_ids[i]);
+                        //     position_ids.push(entry.encoding.position_ids[i]);
+                        // }
+
+                        current_tokens += entry_tokens;
+                        metadata.push(entry.metadata);
+                        cu_seq_lengths.push(current_tokens as u32);
+                    }

                     if Some(metadata.len()) == max_batch_requests {
                         break;
                     }
                 }

+                let input_ids =
+                    Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
+                let token_type_ids =
+                    Vec::from_raw_parts(raw_token_type_ids, current_tokens, max_batch_tokens);
+                let position_ids =
+                    Vec::from_raw_parts(raw_position_ids, current_tokens, max_batch_tokens);
+
+                let batch_size = metadata.len();
                 let next_batch = if metadata.is_empty() {
                     None
                 } else {
-                    next_batch_span.record("batch_size", metadata.len() as u32);
-                    next_batch_span.record("tokens", current_tokens as u32);
-
-                    metrics::histogram!("te_batch_next_size", metadata.len() as f64);
-                    metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
-
                     Some((
                         metadata,
                         Batch {
@@ -181,18 +210,25 @@ async fn queue_task(
                             cumulative_seq_lengths: cu_seq_lengths,
                             max_length,
                         },
-                        next_batch_span,
                     ))
                 };

                 let _ = response_sender.send(next_batch);
+
+                metrics::histogram!("te_batch_next_size", batch_size as f64);
+                metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
                 metrics::gauge!("te_queue_size", entries.len() as f64);
-            }
+            },
         }
     }
 }

-type NextBatch = (Vec<Metadata>, Batch, Span);
+unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
+    let layout = Layout::array::<u32>(capacity).unwrap();
+    alloc(layout).cast::<u32>()
+}
+
+type NextBatch = (Vec<Metadata>, Batch);

 #[derive(Debug)]
 enum QueueCommand {
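
The batch-building loop in queue_task above swaps Vec::with_capacity plus extend for manual writes: an uninitialized buffer is allocated with the same layout a Vec<u32> of that capacity would use, each entry's ids are copied in at the current offset with ptr::copy, and the buffer is finally adopted by Vec::from_raw_parts with the number of tokens actually written. A self-contained sketch of that pattern follows; the null check is an extra guard added for the sketch, and the buffer size and sample data are illustrative only.

use std::alloc::{alloc, handle_alloc_error, Layout};
use std::ptr;

/// Allocate an uninitialized u32 buffer with the layout a `Vec<u32>` of
/// `capacity` elements uses, so `Vec::from_raw_parts` can take it over later.
unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
    let layout = Layout::array::<u32>(capacity).unwrap();
    let ptr = alloc(layout).cast::<u32>();
    if ptr.is_null() {
        // Extra guard for the sketch; the commit assumes allocation succeeds.
        handle_alloc_error(layout);
    }
    ptr
}

fn main() {
    let max_batch_tokens = 8;
    let first = vec![101u32, 2023, 102];
    let second = vec![101u32, 102];

    let input_ids = unsafe {
        let raw = raw_u32_vec(max_batch_tokens);
        let mut current_tokens = 0;

        // Append each entry's ids at the current offset, as the queue task
        // does with `ptr::copy` instead of `Vec::extend`.
        for ids in [&first, &second] {
            ptr::copy(ids.as_ptr(), raw.add(current_tokens), ids.len());
            current_tokens += ids.len();
        }

        // Hand the buffer to a Vec: `current_tokens` initialized elements,
        // capacity matching the layout used for the allocation.
        Vec::from_raw_parts(raw, current_tokens, max_batch_tokens)
    };

    assert_eq!(input_ids, [101, 2023, 102, 101, 102]);
}

The layout produced by Layout::array::<u32>(capacity) is exactly what Vec::<u32>::with_capacity(capacity) would request from the global allocator, which is what makes the later Vec::from_raw_parts call sound as long as the length passed in never exceeds that capacity.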

router/Cargo.toml

Lines changed: 8 additions & 7 deletions

@@ -23,26 +23,27 @@ text-embeddings-backend = { path = "../backends", features = ["clap"] }
 text-embeddings-core = { path = "../core" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "^0.3"
-flume = "0.10.14"
+flume = "0.11.0"
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
 hf-hub = { version = "0.3.0", features = ["tokio"] }
 num_cpus = "1.16.0"
 metrics = "0.21.0"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
-opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.12.0"
+opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
+opentelemetry-otlp = "0.13.0"
 reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
-tokenizers = { version = "^0.13", default-features=false, features=["onig"] }
+tokenizers = { version = "0.14.1", default-features=false, features=["onig"] }
 tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tower-http = { version = "0.4.0", features = ["cors"] }
 tracing = "0.1.37"
-tracing-opentelemetry = "0.19.0"
+tracing-chrome = "0.7.1"
+tracing-opentelemetry = "0.21.0"
 tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
-utoipa = { version = "3.0.1", features = ["axum_extras"] }
-utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
+utoipa = { version = "4.0.0", features = ["axum_extras"] }
+utoipa-swagger-ui = { version = "4.0.0", features = ["axum"] }

 [build-dependencies]
 vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

router/src/main.rs

Lines changed: 5 additions & 1 deletion

@@ -209,7 +209,11 @@ async fn main() -> Result<()> {
     });

     // Queue logic
-    let queue = Queue::new(args.max_batch_tokens, max_batch_requests);
+    let queue = Queue::new(
+        args.max_batch_tokens,
+        max_batch_requests,
+        args.max_concurrent_requests,
+    );

     // Create infer task
     let infer = Infer::new(tokenization, queue, args.max_concurrent_requests, backend);
