Skip to content

Commit ce2f210

Browse files
feat(router): add truncation direction parameter (#299)
1 parent 30a5f7e commit ce2f210

File tree

9 files changed

+228
-52
lines changed

9 files changed

+228
-52
lines changed

Cargo.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,18 @@ candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev =
2424
candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "33b7ecf9ed82bb7c20f1a94555218fabfbaa2fe3", package = "candle-flash-attn" }
2525
hf-hub = { git = "https://github.com/huggingface/hf-hub", rev = "b167f69692be5f49eb8003788f7f8a499a98b096" }
2626

27-
2827
[profile.release]
2928
debug = 0
30-
incremental = true
3129
lto = "fat"
3230
opt-level = 3
3331
codegen-units = 1
3432
strip = "symbols"
3533
panic = "abort"
34+
35+
[profile.release-debug]
36+
inherits = "release"
37+
debug = 1
38+
lto = "thin"
39+
codegen-units = 16
40+
strip = "none"
41+
incremental = true

Dockerfile-cuda

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ ARG CUDA_COMPUTE_CAP=80
3535
ARG GIT_SHA
3636
ARG DOCKER_LABEL
3737

38+
# Limit parallelism
39+
ARG RAYON_NUM_THREADS
40+
ARG CARGO_BUILD_JOBS
41+
ARG CARGO_BUILD_INCREMENTAL
42+
3843
# sccache specific variables
3944
ARG ACTIONS_CACHE_URL
4045
ARG ACTIONS_RUNTIME_TOKEN

Dockerfile-cuda-all

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ ARG ACTIONS_CACHE_URL
4040
ARG ACTIONS_RUNTIME_TOKEN
4141
ARG SCCACHE_GHA_ENABLED
4242

43-
# limit the number of kernels built at the same time
43+
# Limit parallelism
4444
ARG RAYON_NUM_THREADS=4
45+
ARG CARGO_BUILD_JOBS
46+
ARG CARGO_BUILD_INCREMENTAL
4547

4648
WORKDIR /usr/src
4749

core/src/infer.rs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::TextEmbeddingsError;
44
use std::sync::Arc;
55
use std::time::{Duration, Instant};
66
use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType};
7+
use tokenizers::TruncationDirection;
78
use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore};
89
use tracing::instrument;
910

@@ -117,6 +118,7 @@ impl Infer {
117118
&self,
118119
inputs: I,
119120
truncate: bool,
121+
truncation_direction: TruncationDirection,
120122
permit: OwnedSemaphorePermit,
121123
) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
122124
let start_time = Instant::now();
@@ -131,7 +133,14 @@ impl Infer {
131133
}
132134

133135
let results = self
134-
.embed(inputs, truncate, false, &start_time, permit)
136+
.embed(
137+
inputs,
138+
truncate,
139+
truncation_direction,
140+
false,
141+
&start_time,
142+
permit,
143+
)
135144
.await?;
136145

137146
let InferResult::AllEmbedding(response) = results else {
@@ -165,6 +174,7 @@ impl Infer {
165174
&self,
166175
inputs: I,
167176
truncate: bool,
177+
truncation_direction: TruncationDirection,
168178
permit: OwnedSemaphorePermit,
169179
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
170180
let start_time = Instant::now();
@@ -179,7 +189,14 @@ impl Infer {
179189
}
180190

181191
let results = self
182-
.embed(inputs, truncate, true, &start_time, permit)
192+
.embed(
193+
inputs,
194+
truncate,
195+
truncation_direction,
196+
true,
197+
&start_time,
198+
permit,
199+
)
183200
.await?;
184201

185202
let InferResult::PooledEmbedding(response) = results else {
@@ -213,6 +230,7 @@ impl Infer {
213230
&self,
214231
inputs: I,
215232
truncate: bool,
233+
truncation_direction: TruncationDirection,
216234
normalize: bool,
217235
permit: OwnedSemaphorePermit,
218236
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
@@ -228,7 +246,14 @@ impl Infer {
228246
}
229247

230248
let results = self
231-
.embed(inputs, truncate, true, &start_time, permit)
249+
.embed(
250+
inputs,
251+
truncate,
252+
truncation_direction,
253+
true,
254+
&start_time,
255+
permit,
256+
)
232257
.await?;
233258

234259
let InferResult::PooledEmbedding(mut response) = results else {
@@ -278,6 +303,7 @@ impl Infer {
278303
&self,
279304
inputs: I,
280305
truncate: bool,
306+
truncation_direction: TruncationDirection,
281307
pooling: bool,
282308
start_time: &Instant,
283309
_permit: OwnedSemaphorePermit,
@@ -296,7 +322,7 @@ impl Infer {
296322
// Tokenization
297323
let encoding = self
298324
.tokenization
299-
.encode(inputs.into(), truncate)
325+
.encode(inputs.into(), truncate, truncation_direction)
300326
.await
301327
.map_err(|err| {
302328
metrics::increment_counter!("te_request_failure", "err" => "tokenization");
@@ -340,6 +366,7 @@ impl Infer {
340366
&self,
341367
inputs: I,
342368
truncate: bool,
369+
truncation_direction: TruncationDirection,
343370
raw_scores: bool,
344371
_permit: OwnedSemaphorePermit,
345372
) -> Result<ClassificationInferResponse, TextEmbeddingsError> {
@@ -357,7 +384,7 @@ impl Infer {
357384
// Tokenization
358385
let encoding = self
359386
.tokenization
360-
.encode(inputs.into(), truncate)
387+
.encode(inputs.into(), truncate, truncation_direction)
361388
.await
362389
.map_err(|err| {
363390
metrics::increment_counter!("te_request_failure", "err" => "tokenization");

core/src/tokenization.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ impl Tokenization {
6464
&self,
6565
inputs: EncodingInput,
6666
truncate: bool,
67+
truncation_direction: TruncationDirection,
6768
) -> Result<ValidEncoding, TextEmbeddingsError> {
6869
// Check if inputs is empty
6970
if inputs.is_empty() {
@@ -80,6 +81,7 @@ impl Tokenization {
8081
.send(TokenizerRequest::Encode(
8182
inputs,
8283
truncate,
84+
truncation_direction,
8385
response_sender,
8486
Span::current(),
8587
))
@@ -163,14 +165,21 @@ fn tokenizer_worker(
163165
// Loop over requests
164166
while let Some(request) = receiver.blocking_recv() {
165167
match request {
166-
TokenizerRequest::Encode(inputs, truncate, response_tx, parent_span) => {
168+
TokenizerRequest::Encode(
169+
inputs,
170+
truncate,
171+
truncation_direction,
172+
response_tx,
173+
parent_span,
174+
) => {
167175
parent_span.in_scope(|| {
168176
if !response_tx.is_closed() {
169177
// It's possible that the user dropped its request resulting in a send error.
170178
// We just discard the error
171179
let _ = response_tx.send(encode_input(
172180
inputs,
173181
truncate,
182+
truncation_direction,
174183
max_input_length,
175184
position_offset,
176185
&mut tokenizer,
@@ -247,13 +256,14 @@ fn tokenize_input(
247256
fn encode_input(
248257
inputs: EncodingInput,
249258
truncate: bool,
259+
truncation_direction: TruncationDirection,
250260
max_input_length: usize,
251261
position_offset: usize,
252262
tokenizer: &mut Tokenizer,
253263
) -> Result<ValidEncoding, TextEmbeddingsError> {
254264
// Default truncation params
255265
let truncate_params = truncate.then_some(TruncationParams {
256-
direction: TruncationDirection::Right,
266+
direction: truncation_direction,
257267
max_length: max_input_length,
258268
strategy: TruncationStrategy::LongestFirst,
259269
stride: 0,
@@ -316,6 +326,7 @@ enum TokenizerRequest {
316326
Encode(
317327
EncodingInput,
318328
bool,
329+
TruncationDirection,
319330
oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
320331
Span,
321332
),

proto/tei.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,16 @@ message Metadata {
6969
uint64 inference_time_ns = 6;
7070
}
7171

72+
enum TruncationDirection {
73+
TRUNCATION_DIRECTION_RIGHT = 0;
74+
TRUNCATION_DIRECTION_LEFT = 1;
75+
}
76+
7277
message EmbedRequest {
7378
string inputs = 1;
7479
bool truncate = 2;
7580
bool normalize = 3;
81+
TruncationDirection truncation_direction = 4;
7682
}
7783

7884
message EmbedResponse {
@@ -83,6 +89,7 @@ message EmbedResponse {
8389
message EmbedSparseRequest {
8490
string inputs = 1;
8591
bool truncate = 2;
92+
TruncationDirection truncation_direction = 3;
8693
}
8794

8895
message SparseValue {
@@ -98,6 +105,7 @@ message EmbedSparseResponse {
98105
message EmbedAllRequest {
99106
string inputs = 1;
100107
bool truncate = 2;
108+
TruncationDirection truncation_direction = 3;
101109
}
102110

103111
message TokenEmbedding {
@@ -113,12 +121,14 @@ message PredictRequest {
113121
string inputs = 1;
114122
bool truncate = 2;
115123
bool raw_scores = 3;
124+
TruncationDirection truncation_direction = 4;
116125
}
117126

118127
message PredictPairRequest {
119128
repeated string inputs = 1;
120129
bool truncate = 2;
121130
bool raw_scores = 3;
131+
TruncationDirection truncation_direction = 4;
122132
}
123133

124134
message Prediction {
@@ -137,6 +147,7 @@ message RerankRequest {
137147
bool truncate = 3;
138148
bool raw_scores = 4;
139149
bool return_text = 5;
150+
TruncationDirection truncation_direction = 6;
140151
}
141152

142153
message RerankStreamRequest{
@@ -147,6 +158,7 @@ message RerankStreamRequest{
147158
bool raw_scores = 4;
148159
// The server will only consider the first value
149160
bool return_text = 5;
161+
TruncationDirection truncation_direction = 6;
150162
}
151163

152164
message Rank {

0 commit comments

Comments
 (0)