Commit 70f1796

Merge pull request #10 from huggingface/feat/prefetch_batch
feat: prefetch batch
2 parents c7010dd + 2fc7111 commit 70f1796

File tree

8 files changed: +206 -223 lines

Cargo.lock

Lines changed: 110 additions & 165 deletions
Some generated files are not rendered by default.

backends/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ authors.workspace = true
 homepage.workspace = true

 [dependencies]
-flume = "^0.10"
+flume = "^0.11"
 clap = { version = "4.1.4", features = ["derive"], optional = true }
 text-embeddings-backend-core = { path = "core" }
 text-embeddings-backend-python = { path = "python", optional = true }

backends/candle/Cargo.toml

Lines changed: 2 additions & 3 deletions
@@ -16,14 +16,13 @@ candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn
 candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "07e1a5490211e25ed0d096a2b21d3c607666eaae", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
 lazy_static = "^1.4"
-flume = "^0.10"
 text-embeddings-backend-core = { path = "../core" }
 tracing = "^0.1"
-safetensors = "^0.3"
+safetensors = "^0.4"
 thiserror = "^1.0"
 serde = { version = "^1.0", features = ["serde_derive"] }
 serde_json = "^1.0"
-memmap2 = "^0.7"
+memmap2 = "^0.9"

 [build-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }

core/Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -6,11 +6,11 @@ authors.workspace = true
 homepage.workspace = true

 [dependencies]
-flume = "^0.10"
+flume = "^0.11"
 hf-hub = { version = "^0.3.0", features = ["tokio"] }
 metrics = "^0.21"
 text-embeddings-backend = { path = "../backends" }
 thiserror = "^1.0"
-tokenizers = { version = "^0.13", default-features=false, features=["onig"] }
+tokenizers = { version = "^0.14", default-features=false, features=["onig"] }
 tracing = "^0.1"
 tokio = { version = "^1.25", features = ["rt", "rt-multi-thread", "parking_lot", "sync"] }

core/src/infer.rs

Lines changed: 12 additions & 3 deletions
@@ -29,6 +29,12 @@ impl Infer {
     ) -> Self {
         let notify_batching_task = Arc::new(Notify::new());

+        // Create two batching tasks to prefetch batches
+        tokio::spawn(batching_task(
+            backend.clone(),
+            queue.clone(),
+            notify_batching_task.clone(),
+        ));
         tokio::spawn(batching_task(
             backend.clone(),
             queue.clone(),
@@ -86,7 +92,6 @@ impl Infer {
             metadata: Metadata {
                 response_tx,
                 span: Span::current(),
-                temp_span: None,
                 tokenization: start_time.elapsed(),
                 queue_time: Instant::now(),
                 batch_time: None,
@@ -133,12 +138,16 @@
     }
 }

+#[instrument(skip_all)]
 async fn batching_task(backend: Backend, queue: Queue, notify: Arc<Notify>) {
     loop {
         notify.notified().await;

         while let Some(batch) = queue.next_batch().await {
-            match backend.embed(batch.1).await {
+            let results = backend.embed(batch.1).await;
+
+            // Handle sending responses in another thread to avoid starving the backend
+            tokio::task::spawn_blocking(move || match results {
                 Ok(embeddings) => {
                     batch.0.into_iter().zip(embeddings).for_each(|(m, e)| {
                         let _ = m.response_tx.send(Ok(InferResponse {
@@ -160,7 +169,7 @@ async fn batching_task(backend: Backend, queue: Queue, notify: Arc<Notify>) {
                         let _ = m.response_tx.send(Err(err.clone()));
                     });
                 }
-            }
+            });
         }
     }
 }
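
As an aside, here is a minimal, self-contained sketch of the prefetch pattern these hunks introduce, assuming a tokio runtime with the full feature set; the queue, backend, and response types below are illustrative stand-ins, not the crate's actual Backend, Queue, or InferResponse. Two identical consumer tasks share one queue, so while one is awaiting the (simulated) backend the other can already pull the next batch, and response delivery is pushed onto a blocking thread so the loop can immediately ask for more work.

use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Notify};

// Stand-in for the crate's `batching_task`: pop work, "embed" it, and hand the
// result to a blocking thread so this loop can immediately fetch more work.
async fn batching_task(queue: Arc<Mutex<VecDeque<Vec<u32>>>>, notify: Arc<Notify>) {
    loop {
        notify.notified().await;
        loop {
            // Take the next batch; the lock guard is dropped before awaiting.
            let next = queue.lock().await.pop_front();
            let Some(batch) = next else { break };
            // Stand-in for `backend.embed(batch)`. While this await is pending,
            // the second task can already pop and prepare the following batch.
            tokio::time::sleep(Duration::from_millis(10)).await;
            // Stand-in for sending responses off the async runtime.
            tokio::task::spawn_blocking(move || drop(batch));
        }
    }
}

#[tokio::main]
async fn main() {
    let queue = Arc::new(Mutex::new(VecDeque::from([vec![1u32, 2], vec![3, 4]])));
    let notify = Arc::new(Notify::new());
    // Two identical tasks, mirroring the duplicated `tokio::spawn(batching_task(...))`.
    for _ in 0..2 {
        tokio::spawn(batching_task(queue.clone(), notify.clone()));
    }
    // Let both tasks reach `notified().await`, then wake them.
    tokio::time::sleep(Duration::from_millis(10)).await;
    notify.notify_waiters();
    tokio::time::sleep(Duration::from_millis(100)).await;
}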

core/src/queue.rs

Lines changed: 66 additions & 41 deletions
@@ -1,11 +1,13 @@
 use crate::infer::InferResponse;
 use crate::tokenization::Encoding;
+use std::alloc::{alloc, Layout};
 use std::cmp::max;
 use std::collections::VecDeque;
+use std::ptr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::{BackendError, Batch};
 use tokio::sync::oneshot;
-use tracing::{info_span, instrument, Span};
+use tracing::{instrument, Span};

 /// Queue entry
 #[derive(Debug)]
@@ -23,8 +25,6 @@ pub struct Metadata {
     pub response_tx: oneshot::Sender<Result<InferResponse, BackendError>>,
     /// Span that will live as long as entry
     pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
-    pub temp_span: Option<Span>,
    /// Tokenization duration
     pub tokenization: Duration,
     /// Instant when this entry was queued
@@ -43,16 +43,23 @@ pub struct Queue {
 }

 impl Queue {
-    pub fn new(max_batch_tokens: usize, max_batch_requests: Option<usize>) -> Self {
+    pub fn new(
+        max_batch_tokens: usize,
+        max_batch_requests: Option<usize>,
+        max_concurrent_requests: usize,
+    ) -> Self {
         // Create channels
         let (queue_sender, queue_receiver) = flume::unbounded();

         // Launch background queue task
-        tokio::spawn(queue_task(
-            max_batch_tokens,
-            max_batch_requests,
-            queue_receiver,
-        ));
+        tokio::task::spawn_blocking(move || {
+            queue_blocking_task(
+                max_batch_tokens,
+                max_batch_requests,
+                max_concurrent_requests,
+                queue_receiver,
+            )
+        });

         Self { queue_sender }
     }
@@ -89,16 +96,17 @@ impl Queue {
 }

 // Background task responsible of the queue state
-async fn queue_task(
+fn queue_blocking_task(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
+    max_concurrent_requests: usize,
     queue_receiver: flume::Receiver<QueueCommand>,
 ) {
-    let capacity = max_batch_requests.unwrap_or(512);
+    let capacity = max_batch_requests.unwrap_or(max_concurrent_requests);

-    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(512);
+    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(max_concurrent_requests);

-    while let Ok(cmd) = queue_receiver.recv_async().await {
+    while let Ok(cmd) = queue_receiver.recv() {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 let _span = span.entered();
@@ -108,23 +116,23 @@ async fn queue_task(
             QueueCommand::NextBatch {
                 response_sender,
                 span,
-            } => {
+            } => unsafe {
                 let _span = span.entered();

-                let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty, tokens = tracing::field::Empty);
-                next_batch_span.follows_from(Span::current());
+                // Allocate raw memory
+                let raw_input_ids = raw_u32_vec(max_batch_tokens);
+                let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
+                let raw_position_ids = raw_u32_vec(max_batch_tokens);

                 let mut metadata = Vec::with_capacity(capacity);
-
-                let mut input_ids = Vec::with_capacity(max_batch_tokens);
-                let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
-                let mut position_ids = Vec::with_capacity(max_batch_tokens);
                 let mut cu_seq_lengths = Vec::with_capacity(capacity);
                 cu_seq_lengths.push(0);

                 let mut current_tokens = 0;
                 let mut max_length = 0;

+                let batch_time = Instant::now();
+
                 while let Some(mut entry) = entries.pop_front() {
                     // Filter entries where the response receiver was dropped (== entries where the request
                     // was dropped by the client)
@@ -141,37 +149,47 @@
                     }

                     max_length = max(max_length, entry_tokens as u32);
-                    current_tokens += entry_tokens;

-                    // Create a new span to link the batch back to this entry
-                    let entry_batch_span = info_span!(parent: &entry.metadata.span, "infer");
-                    // Add relationships
-                    next_batch_span.follows_from(&entry_batch_span);
-                    entry_batch_span.follows_from(&next_batch_span);
-
-                    entry.metadata.batch_time = Some(Instant::now());
-                    entry.metadata.temp_span = Some(entry_batch_span);
+                    entry.metadata.batch_time = Some(batch_time);
+
+                    // Copy memory to the correct spot in the raw vectors
+                    ptr::copy(
+                        entry.encoding.input_ids.as_mut_ptr(),
+                        raw_input_ids.add(current_tokens),
+                        entry.encoding.input_ids.len(),
+                    );
+                    ptr::copy(
+                        entry.encoding.token_type_ids.as_mut_ptr(),
+                        raw_token_type_ids.add(current_tokens),
+                        entry.encoding.token_type_ids.len(),
+                    );
+                    ptr::copy(
+                        entry.encoding.position_ids.as_mut_ptr(),
+                        raw_position_ids.add(current_tokens),
+                        entry.encoding.position_ids.len(),
+                    );

+                    current_tokens += entry_tokens;
                     metadata.push(entry.metadata);
-                    input_ids.extend(entry.encoding.input_ids);
-                    token_type_ids.extend(entry.encoding.token_type_ids);
-                    position_ids.extend(entry.encoding.position_ids);
                     cu_seq_lengths.push(current_tokens as u32);

                     if Some(metadata.len()) == max_batch_requests {
                         break;
                     }
                 }

+                // Create final vectors from raw memory
+                let input_ids =
+                    Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
+                let token_type_ids =
+                    Vec::from_raw_parts(raw_token_type_ids, current_tokens, max_batch_tokens);
+                let position_ids =
+                    Vec::from_raw_parts(raw_position_ids, current_tokens, max_batch_tokens);
+
+                let batch_size = metadata.len();
                 let next_batch = if metadata.is_empty() {
                     None
                 } else {
-                    next_batch_span.record("batch_size", metadata.len() as u32);
-                    next_batch_span.record("tokens", current_tokens as u32);
-
-                    metrics::histogram!("te_batch_next_size", metadata.len() as f64);
-                    metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
-
                     Some((
                         metadata,
                         Batch {
@@ -181,18 +199,25 @@
                             cumulative_seq_lengths: cu_seq_lengths,
                             max_length,
                         },
-                        next_batch_span,
                     ))
                 };

                 let _ = response_sender.send(next_batch);
+
+                metrics::histogram!("te_batch_next_size", batch_size as f64);
+                metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
                 metrics::gauge!("te_queue_size", entries.len() as f64);
-            }
+            },
         }
     }
 }

-type NextBatch = (Vec<Metadata>, Batch, Span);
+unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
+    let layout = Layout::array::<u32>(capacity).unwrap();
+    alloc(layout).cast::<u32>()
+}
+
+type NextBatch = (Vec<Metadata>, Batch);

 #[derive(Debug)]
 enum QueueCommand {
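
For readers unfamiliar with this pattern, the following is a hedged, standalone sketch of the NextBatch buffer strategy used above; the pack helper and its data are illustrative, not the crate's Entry or Encoding types. One raw u32 buffer is allocated per tensor for the full token budget, each entry is ptr::copy'd at its running offset, and the initialized prefix is then adopted with Vec::from_raw_parts. Safety rests on the same invariants as the diff: the total copied length never exceeds the allocation and the resulting Vec's length covers only initialized elements.

use std::alloc::{alloc, Layout};
use std::ptr;

// Same helper as in the diff: an uninitialized, heap-allocated u32 buffer.
unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
    let layout = Layout::array::<u32>(capacity).unwrap();
    alloc(layout).cast::<u32>()
}

// Illustrative packer standing in for the NextBatch loop (capacity must be > 0
// and at least the sum of chunk lengths, like `max_batch_tokens` above).
fn pack(chunks: &[Vec<u32>], capacity: usize) -> Vec<u32> {
    let total: usize = chunks.iter().map(Vec::len).sum();
    assert!(capacity > 0 && total <= capacity);
    unsafe {
        let raw = raw_u32_vec(capacity);
        assert!(!raw.is_null(), "allocation failed");
        let mut offset = 0;
        for chunk in chunks {
            // Copy this chunk to its cumulative offset, the role played by
            // `ptr::copy(..., raw_input_ids.add(current_tokens), ...)` in the diff.
            ptr::copy(chunk.as_ptr(), raw.add(offset), chunk.len());
            offset += chunk.len();
        }
        // len = initialized prefix, capacity = full allocation backing the pointer.
        Vec::from_raw_parts(raw, offset, capacity)
    }
}

fn main() {
    let packed = pack(&[vec![1, 2, 3], vec![4, 5]], 8);
    assert_eq!(packed, vec![1, 2, 3, 4, 5]);
}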

router/Cargo.toml

Lines changed: 8 additions & 7 deletions
@@ -23,26 +23,27 @@ text-embeddings-backend = { path = "../backends", features = ["clap"] }
 text-embeddings-core = { path = "../core" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "^0.3"
-flume = "0.10.14"
+flume = "0.11.0"
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
 hf-hub = { version = "0.3.0", features = ["tokio"] }
 num_cpus = "1.16.0"
 metrics = "0.21.0"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
-opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.12.0"
+opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
+opentelemetry-otlp = "0.13.0"
 reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
-tokenizers = { version = "^0.13", default-features=false, features=["onig"] }
+tokenizers = { version = "0.14.1", default-features=false, features=["onig"] }
 tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tower-http = { version = "0.4.0", features = ["cors"] }
 tracing = "0.1.37"
-tracing-opentelemetry = "0.19.0"
+tracing-chrome = "0.7.1"
+tracing-opentelemetry = "0.21.0"
 tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
-utoipa = { version = "3.0.1", features = ["axum_extras"] }
-utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
+utoipa = { version = "4.0.0", features = ["axum_extras"] }
+utoipa-swagger-ui = { version = "4.0.0", features = ["axum"] }

 [build-dependencies]
 vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

router/src/main.rs

Lines changed: 5 additions & 1 deletion
@@ -209,7 +209,11 @@ async fn main() -> Result<()> {
     });

     // Queue logic
-    let queue = Queue::new(args.max_batch_tokens, max_batch_requests);
+    let queue = Queue::new(
+        args.max_batch_tokens,
+        max_batch_requests,
+        args.max_concurrent_requests,
+    );

     // Create infer task
     let infer = Infer::new(tokenization, queue, args.max_concurrent_requests, backend);
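
The new third argument replaces the queue's previously hardcoded capacity of 512. Below is a hedged sketch of how such a flag can be declared and threaded through with clap 4 (matching the "derive" and "env" features enabled in router/Cargo.toml); the field name mirrors the diff, but this Args struct and its default values are illustrative, not the router's actual definitions.

use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Upper bound used to size the queue's entry buffer (hypothetical default).
    #[arg(long, env, default_value_t = 512)]
    max_concurrent_requests: usize,
    /// Token budget per batch (hypothetical default).
    #[arg(long, env, default_value_t = 16384)]
    max_batch_tokens: usize,
}

fn main() {
    let args = Args::parse();
    // Instead of `VecDeque::with_capacity(512)`, the queue task can now size
    // its buffers from the configured concurrency limit.
    println!(
        "queue capacity = {}, max batch tokens = {}",
        args.max_concurrent_requests, args.max_batch_tokens
    );
}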
