Commit 68d63ed

feat: add auto-truncate arg (#224)
1 parent 53e28e0 commit 68d63ed
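
This commit adds a server-wide default for input truncation. The router gains an --auto-truncate argument (read from the command line or, via clap, from the environment; its doc comment notes it is unused for gRPC servers), threads it through run() into the shared Info struct, and relaxes the per-request truncate field on the predict, rerank, and embed routes from bool to Option<bool>. Each handler then resolves the effective value as req.truncate.unwrap_or(info.auto_truncate), so an explicit request value still overrides the server default.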

File tree

4 files changed: +38 -23 lines changed

router/src/http/server.rs

Lines changed: 17 additions & 11 deletions
@@ -159,7 +159,7 @@ async fn predict(
     let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
     let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner(
         inputs,
-        req.truncate,
+        req.truncate.unwrap_or(info.auto_truncate),
         req.raw_scores,
         infer.0,
         info.0,
@@ -208,7 +208,7 @@ async fn predict(
         let local_info = info.clone();
         futures.push(predict_inner(
             input,
-            req.truncate,
+            req.truncate.unwrap_or(info.auto_truncate),
             req.raw_scores,
             local_infer.0,
             local_info.0,
@@ -370,7 +370,7 @@ async fn rerank(
         futures.push(rerank_inner(
             req.query.clone(),
             text.clone(),
-            req.truncate,
+            req.truncate.unwrap_or(info.auto_truncate),
             req.raw_scores,
             local_infer.0,
         ))
@@ -478,7 +478,12 @@ async fn embed(

     let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
     let response = infer
-        .embed_pooled(input, req.truncate, req.normalize, permit)
+        .embed_pooled(
+            input,
+            req.truncate.unwrap_or(info.auto_truncate),
+            req.normalize,
+            permit,
+        )
         .await
         .map_err(ErrorResponse::from)?;

@@ -531,11 +536,12 @@ async fn embed(
     for input in inputs {
         compute_chars += input.count_chars();

+        let truncate = req.truncate.unwrap_or(info.auto_truncate);
         let local_infer = infer.clone();
         futures.push(async move {
             let permit = local_infer.acquire_permit().await;
             local_infer
-                .embed_pooled(input, req.truncate, req.normalize, permit)
+                .embed_pooled(input, truncate, req.normalize, permit)
                 .await
         })
     }
@@ -634,7 +640,7 @@ async fn embed_sparse(

     let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
     let response = infer
-        .embed_sparse(input, req.truncate, permit)
+        .embed_sparse(input, req.truncate.unwrap_or(info.auto_truncate), permit)
         .await
         .map_err(ErrorResponse::from)?;

@@ -687,12 +693,11 @@ async fn embed_sparse(
     for input in inputs {
         compute_chars += input.count_chars();

+        let truncate = req.truncate.unwrap_or(info.auto_truncate);
         let local_infer = infer.clone();
         futures.push(async move {
             let permit = local_infer.acquire_permit().await;
-            let response = local_infer
-                .embed_sparse(input, req.truncate, permit)
-                .await?;
+            let response = local_infer.embed_sparse(input, truncate, permit).await?;
             Ok((sparsify(response.results), response.metadata))
         })
     }
@@ -782,7 +787,7 @@ async fn embed_all(

     let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
     let response = infer
-        .embed_all(input, req.truncate, permit)
+        .embed_all(input, req.truncate.unwrap_or(info.auto_truncate), permit)
         .await
         .map_err(ErrorResponse::from)?;

@@ -835,10 +840,11 @@ async fn embed_all(
     for input in inputs {
         compute_chars += input.count_chars();

+        let truncate = req.truncate.unwrap_or(info.auto_truncate);
         let local_infer = infer.clone();
         futures.push(async move {
             let permit = local_infer.acquire_permit().await;
-            local_infer.embed_all(input, req.truncate, permit).await
+            local_infer.embed_all(input, truncate, permit).await
         })
     }
     let results = join_all(futures)
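
Every handler resolves the flag the same way: the request's optional value falls back to the server-wide default. A minimal standalone sketch of that resolution logic (the function name here is illustrative, not the router's API):

/// Resolve the effective truncation setting: an explicit per-request
/// value wins; otherwise the server-wide default applies.
fn effective_truncate(requested: Option<bool>, auto_truncate: bool) -> bool {
    requested.unwrap_or(auto_truncate)
}

fn main() {
    // Server started with --auto-truncate:
    assert!(effective_truncate(None, true));         // request omits the field
    assert!(!effective_truncate(Some(false), true)); // request opts out
    // Server started without it:
    assert!(effective_truncate(Some(true), false));  // request opts in
}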

router/src/http/types.rs

Lines changed: 10 additions & 11 deletions
@@ -196,9 +196,8 @@ impl<'__s> ToSchema<'__s> for PredictInput {
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct PredictRequest {
     pub inputs: PredictInput,
-    #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
     #[serde(default)]
     #[schema(default = "false", example = "false")]
     pub raw_scores: bool,
@@ -226,8 +225,8 @@ pub(crate) struct RerankRequest {
     #[schema(example = json!(["Deep Learning is ..."]))]
     pub texts: Vec<String>,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
     #[serde(default)]
     #[schema(default = "false", example = "false")]
     pub raw_scores: bool,
@@ -322,8 +321,8 @@ pub(crate) struct OpenAICompatResponse {
 pub(crate) struct EmbedRequest {
     pub inputs: Input,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
     #[serde(default = "default_normalize")]
     #[schema(default = "true", example = "true")]
     pub normalize: bool,
@@ -341,8 +340,8 @@ pub(crate) struct EmbedResponse(pub Vec<Vec<f32>>);
 pub(crate) struct EmbedSparseRequest {
     pub inputs: Input,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
 }

 #[derive(Serialize, ToSchema)]
@@ -358,8 +357,8 @@ pub(crate) struct EmbedSparseResponse(pub Vec<Vec<SparseValue>>);
 pub(crate) struct EmbedAllRequest {
     pub inputs: Input,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
 }

 #[derive(Serialize, ToSchema)]
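
Moving the field to Option<bool> is what makes "unset" distinguishable from "false": with serde's derive, an Option field deserializes to None when it is missing, and the nullable = true schema attribute documents the JSON null case. A small sketch of the deserialization behavior, assuming serde and serde_json as dependencies (the struct is a pared-down stand-in for the request types above):

use serde::Deserialize;

#[derive(Deserialize)]
struct Request {
    truncate: Option<bool>,
}

fn main() {
    // A missing field and an explicit null both become None, which the
    // handlers then replace with info.auto_truncate.
    let missing: Request = serde_json::from_str("{}").unwrap();
    let null: Request = serde_json::from_str(r#"{"truncate": null}"#).unwrap();
    let set: Request = serde_json::from_str(r#"{"truncate": false}"#).unwrap();
    assert_eq!(missing.truncate, None);
    assert_eq!(null.truncate, None);
    assert_eq!(set.truncate, Some(false));
}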

router/src/lib.rs

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,7 @@ pub async fn run(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
     max_client_batch_size: usize,
+    auto_truncate: bool,
     hf_api_token: Option<String>,
     hostname: Option<String>,
     port: u16,
@@ -236,6 +237,7 @@ pub async fn run(
         tokenization_workers,
         max_batch_requests,
         max_client_batch_size,
+        auto_truncate,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         docker_label: option_env!("DOCKER_LABEL"),
@@ -428,6 +430,7 @@ pub struct Info {
     pub max_batch_requests: Option<usize>,
     #[cfg_attr(feature = "http", schema(example = "32"))]
     pub max_client_batch_size: usize,
+    pub auto_truncate: bool,
     #[cfg_attr(feature = "http", schema(example = "4"))]
     pub tokenization_workers: usize,
     /// Router Info
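
Storing the resolved flag on Info is what lets the HTTP handlers above read it as info.auto_truncate; since Info is the same struct the router serializes as server metadata, the setting presumably also becomes visible to clients alongside the other server limits.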

router/src/main.rs

Lines changed: 8 additions & 1 deletion
@@ -73,6 +73,12 @@ struct Args {
     #[clap(default_value = "32", long, env)]
     max_client_batch_size: usize,

+    /// Automatically truncate inputs that are longer than the maximum supported size
+    ///
+    /// Unused for gRPC servers
+    #[clap(long, env)]
+    auto_truncate: bool,
+
     /// Your HuggingFace hub token
     #[clap(long, env)]
     #[redact(partial)]
@@ -117,7 +123,7 @@ struct Args {
     #[clap(long, env)]
     otlp_endpoint: Option<String>,

-    // Unused for gRPC servers
+    /// Unused for gRPC servers
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
 }
@@ -143,6 +149,7 @@ async fn main() -> Result<()> {
         args.max_batch_tokens,
         args.max_batch_requests,
         args.max_client_batch_size,
+        args.auto_truncate,
         args.hf_api_token,
         Some(args.hostname),
         args.port,
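
With clap's derive, #[clap(long, env)] on a bool field produces an off-by-default switch that can also be enabled from the environment. A minimal sketch of the parsing behavior, assuming clap with its derive and env features enabled (the struct is trimmed to the new argument):

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Automatically truncate inputs that are longer than the maximum supported size
    #[clap(long, env)]
    auto_truncate: bool,
}

fn main() {
    // Enabled by `--auto-truncate` on the command line or by setting
    // AUTO_TRUNCATE=true in the environment; false when both are absent.
    let args = Args::parse();
    println!("auto_truncate = {}", args.auto_truncate);
}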
