Commit 2fc7111

add docs
1 parent b90ce0e commit 2fc7111

File tree

core/src/infer.rs
core/src/queue.rs

2 files changed: +28 -39 lines changed


core/src/infer.rs

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ async fn batching_task(backend: Backend, queue: Queue, notify: Arc<Notify>) {
     while let Some(batch) = queue.next_batch().await {
         let results = backend.embed(batch.1).await;

-        // Handle sending responses in another thread to not starve the model
+        // Handle sending responses in another thread to avoid starving the backend
         tokio::task::spawn_blocking(move || match results {
             Ok(embeddings) => {
                 batch.0.into_iter().zip(embeddings).for_each(|(m, e)| {
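
The reworded comment documents the intent of the existing spawn_blocking call: responses are delivered on a blocking thread so the async batching loop can immediately await the next batch instead of waiting on every receiver. Below is a minimal sketch of that pattern, using hypothetical Metadata and channel types rather than the crate's real ones.

use tokio::sync::{mpsc, oneshot};

// Hypothetical stand-in for the crate's per-request metadata.
struct Metadata {
    response_tx: oneshot::Sender<Vec<f32>>,
}

async fn batching_loop(mut batches: mpsc::Receiver<Vec<Metadata>>) {
    while let Some(metadata) = batches.recv().await {
        // Pretend the backend produced one embedding per request in the batch.
        let embeddings: Vec<Vec<f32>> = vec![vec![0.0; 4]; metadata.len()];

        // Deliver responses on a blocking thread so this loop can immediately
        // go back to feeding the backend instead of waiting on the receivers.
        tokio::task::spawn_blocking(move || {
            metadata.into_iter().zip(embeddings).for_each(|(m, e)| {
                // The receiver may already be gone (client disconnect); ignore errors.
                let _ = m.response_tx.send(e);
            });
        });
    }
}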

core/src/queue.rs

Lines changed: 27 additions & 38 deletions
@@ -7,7 +7,7 @@ use std::ptr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::{BackendError, Batch};
 use tokio::sync::oneshot;
-use tracing::{info_span, instrument, Span};
+use tracing::{instrument, Span};

 /// Queue entry
 #[derive(Debug)]
@@ -51,8 +51,9 @@ impl Queue {
         // Create channels
         let (queue_sender, queue_receiver) = flume::unbounded();

+        // Launch background queue task
         tokio::task::spawn_blocking(move || {
-            queue_task(
+            queue_blocking_task(
                 max_batch_tokens,
                 max_batch_requests,
                 max_concurrent_requests,
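
For context, Queue::new follows a command-channel design: the Queue handle only pushes commands into an unbounded flume channel, and a single background task owns all queue state. Because that task presumably blocks on flume's synchronous recv() and does CPU-bound batch assembly, it is spawned with tokio::task::spawn_blocking (on tokio's managed blocking pool) rather than as an async task. A simplified sketch under those assumptions follows; the command variants and fields here are invented for illustration.

use tokio::sync::oneshot;

// Invented, simplified command set; the real QueueCommand carries entries,
// tracing spans and batch parameters.
enum QueueCommand {
    Append(String),
    NextBatch {
        response_sender: oneshot::Sender<Vec<String>>,
    },
}

#[derive(Clone)]
struct Queue {
    queue_sender: flume::Sender<QueueCommand>,
}

impl Queue {
    fn new() -> Self {
        // Create channels
        let (queue_sender, queue_receiver) = flume::unbounded();

        // Launch the background queue task on the blocking pool: it loops on a
        // blocking recv() and must not stall an async worker thread.
        tokio::task::spawn_blocking(move || queue_blocking_task(queue_receiver));

        Self { queue_sender }
    }

    fn append(&self, entry: String) {
        let _ = self.queue_sender.send(QueueCommand::Append(entry));
    }
}

// Background task owning the queue state.
fn queue_blocking_task(queue_receiver: flume::Receiver<QueueCommand>) {
    let mut entries: Vec<String> = Vec::new();
    while let Ok(cmd) = queue_receiver.recv() {
        match cmd {
            QueueCommand::Append(entry) => entries.push(entry),
            QueueCommand::NextBatch { response_sender } => {
                // Drain everything queued so far as one "batch".
                let _ = response_sender.send(std::mem::take(&mut entries));
            }
        }
    }
}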
@@ -95,7 +96,7 @@ impl Queue {
 }

 // Background task responsible of the queue state
-fn queue_task(
+fn queue_blocking_task(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
     max_concurrent_requests: usize,
@@ -118,12 +119,12 @@ fn queue_task(
             } => unsafe {
                 let _span = span.entered();

-                let mut metadata = Vec::with_capacity(capacity);
-
+                // Allocate raw memory
                 let raw_input_ids = raw_u32_vec(max_batch_tokens);
                 let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
                 let raw_position_ids = raw_u32_vec(max_batch_tokens);

+                let mut metadata = Vec::with_capacity(capacity);
                 let mut cu_seq_lengths = Vec::with_capacity(capacity);
                 cu_seq_lengths.push(0);

@@ -151,45 +152,33 @@ fn queue_task(

                     entry.metadata.batch_time = Some(batch_time);

-                    {
-                        let _span = info_span!("extend").entered();
-
-                        ptr::copy(
-                            entry.encoding.input_ids.as_mut_ptr(),
-                            raw_input_ids.add(current_tokens),
-                            entry.encoding.input_ids.len(),
-                        );
-                        ptr::copy(
-                            entry.encoding.token_type_ids.as_mut_ptr(),
-                            raw_token_type_ids.add(current_tokens),
-                            entry.encoding.token_type_ids.len(),
-                        );
-                        ptr::copy(
-                            entry.encoding.position_ids.as_mut_ptr(),
-                            raw_position_ids.add(current_tokens),
-                            entry.encoding.position_ids.len(),
-                        );
-
-                        // input_ids.extend_from_slice(entry.encoding.input_ids.as_slice());
-                        // token_type_ids.extend_from_slice(entry.encoding.token_type_ids.as_slice());
-                        // position_ids.extend_from_slice(entry.encoding.position_ids.as_slice());
-
-                        // for i in 0..entry.encoding.input_ids.len() {
-                        //     input_ids.push(entry.encoding.input_ids[i]);
-                        //     token_type_ids.push(entry.encoding.token_type_ids[i]);
-                        //     position_ids.push(entry.encoding.position_ids[i]);
-                        // }
-
-                        current_tokens += entry_tokens;
-                        metadata.push(entry.metadata);
-                        cu_seq_lengths.push(current_tokens as u32);
-                    }
+                    // Copy memory to the correct spot in the raw vectors
+                    ptr::copy(
+                        entry.encoding.input_ids.as_mut_ptr(),
+                        raw_input_ids.add(current_tokens),
+                        entry.encoding.input_ids.len(),
+                    );
+                    ptr::copy(
+                        entry.encoding.token_type_ids.as_mut_ptr(),
+                        raw_token_type_ids.add(current_tokens),
+                        entry.encoding.token_type_ids.len(),
+                    );
+                    ptr::copy(
+                        entry.encoding.position_ids.as_mut_ptr(),
+                        raw_position_ids.add(current_tokens),
+                        entry.encoding.position_ids.len(),
+                    );
+
+                    current_tokens += entry_tokens;
+                    metadata.push(entry.metadata);
+                    cu_seq_lengths.push(current_tokens as u32);

                     if Some(metadata.len()) == max_batch_requests {
                         break;
                     }
                 }

+                // Create final vectors from raw memory
                 let input_ids =
                     Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
                 let token_type_ids =
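
Neither raw_u32_vec nor the safety argument for Vec::from_raw_parts appears in this diff, so the following is a hedged, self-contained miniature of the whole round trip under the assumption that raw_u32_vec simply leaks a Vec::with_capacity buffer as a raw pointer: each sequence is appended at offset current_tokens with ptr::copy, and the final vector is rebuilt with a length that never exceeds the reserved capacity. Names and data are illustrative, not the crate's.

use std::ptr;

// Assumed shape of raw_u32_vec (the real helper is not shown in this diff):
// reserve capacity for `capacity` u32 values and leak the allocation as a raw
// pointer; Vec::from_raw_parts takes ownership back later.
fn raw_u32_vec(capacity: usize) -> *mut u32 {
    let mut v: Vec<u32> = Vec::with_capacity(capacity);
    let ptr = v.as_mut_ptr();
    std::mem::forget(v);
    ptr
}

fn main() {
    let max_batch_tokens = 16;
    let raw_input_ids = raw_u32_vec(max_batch_tokens);

    // Stand-ins for the entries' input_ids.
    let sequences: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5], vec![6]];
    let mut current_tokens = 0;
    let mut cu_seq_lengths = vec![0u32];

    for seq in &sequences {
        // Copy this sequence to the correct spot in the raw vector.
        unsafe { ptr::copy(seq.as_ptr(), raw_input_ids.add(current_tokens), seq.len()) };
        current_tokens += seq.len();
        cu_seq_lengths.push(current_tokens as u32);
    }

    // Safety: current_tokens <= max_batch_tokens, and the first current_tokens
    // elements were initialized by the copies above.
    let input_ids =
        unsafe { Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens) };

    assert_eq!(input_ids, vec![1, 2, 3, 4, 5, 6]);
    assert_eq!(cu_seq_lengths, vec![0, 3, 5, 6]);
}

In this sketch the source and destination buffers never overlap, so ptr::copy_nonoverlapping would work equally well; ptr::copy simply mirrors the hunk above.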
