Skip to content

Commit eef2912

Browse files
fix: fix auto_truncate for openai (#228)
1 parent a556f43 commit eef2912

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

router/src/http/server.rs

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,8 @@ async fn predict(
151151
))
152152
};
153153

154+
let truncate = req.truncate.unwrap_or(info.auto_truncate);
155+
154156
let (response, metadata) = match req.inputs {
155157
PredictInput::Single(inputs) => {
156158
metrics::increment_counter!("te_request_count", "method" => "single");
@@ -159,7 +161,7 @@ async fn predict(
159161
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
160162
let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner(
161163
inputs,
162-
req.truncate.unwrap_or(info.auto_truncate),
164+
truncate,
163165
req.raw_scores,
164166
infer.0,
165167
info.0,
@@ -208,7 +210,7 @@ async fn predict(
208210
let local_info = info.clone();
209211
futures.push(predict_inner(
210212
input,
211-
req.truncate.unwrap_or(info.auto_truncate),
213+
truncate,
212214
req.raw_scores,
213215
local_infer.0,
214216
local_info.0,
@@ -342,6 +344,8 @@ async fn rerank(
342344
))
343345
};
344346

347+
let truncate = req.truncate.unwrap_or(info.auto_truncate);
348+
345349
let (response, metadata) = {
346350
metrics::increment_counter!("te_request_count", "method" => "batch");
347351

@@ -370,7 +374,7 @@ async fn rerank(
370374
futures.push(rerank_inner(
371375
req.query.clone(),
372376
text.clone(),
373-
req.truncate.unwrap_or(info.auto_truncate),
377+
truncate,
374378
req.raw_scores,
375379
local_infer.0,
376380
))
@@ -470,6 +474,8 @@ async fn embed(
470474
let span = tracing::Span::current();
471475
let start_time = Instant::now();
472476

477+
let truncate = req.truncate.unwrap_or(info.auto_truncate);
478+
473479
let (response, metadata) = match req.inputs {
474480
Input::Single(input) => {
475481
metrics::increment_counter!("te_request_count", "method" => "single");
@@ -478,12 +484,7 @@ async fn embed(
478484

479485
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
480486
let response = infer
481-
.embed_pooled(
482-
input,
483-
req.truncate.unwrap_or(info.auto_truncate),
484-
req.normalize,
485-
permit,
486-
)
487+
.embed_pooled(input, truncate, req.normalize, permit)
487488
.await
488489
.map_err(ErrorResponse::from)?;
489490

@@ -536,7 +537,6 @@ async fn embed(
536537
for input in inputs {
537538
compute_chars += input.count_chars();
538539

539-
let truncate = req.truncate.unwrap_or(info.auto_truncate);
540540
let local_infer = infer.clone();
541541
futures.push(async move {
542542
let permit = local_infer.acquire_permit().await;
@@ -631,6 +631,7 @@ async fn embed_sparse(
631631
}
632632
sparse_values
633633
};
634+
let truncate = req.truncate.unwrap_or(info.auto_truncate);
634635

635636
let (response, metadata) = match req.inputs {
636637
Input::Single(input) => {
@@ -640,7 +641,7 @@ async fn embed_sparse(
640641

641642
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
642643
let response = infer
643-
.embed_sparse(input, req.truncate.unwrap_or(info.auto_truncate), permit)
644+
.embed_sparse(input, truncate, permit)
644645
.await
645646
.map_err(ErrorResponse::from)?;
646647

@@ -693,7 +694,6 @@ async fn embed_sparse(
693694
for input in inputs {
694695
compute_chars += input.count_chars();
695696

696-
let truncate = req.truncate.unwrap_or(info.auto_truncate);
697697
let local_infer = infer.clone();
698698
futures.push(async move {
699699
let permit = local_infer.acquire_permit().await;
@@ -779,6 +779,8 @@ async fn embed_all(
779779
let span = tracing::Span::current();
780780
let start_time = Instant::now();
781781

782+
let truncate = req.truncate.unwrap_or(info.auto_truncate);
783+
782784
let (response, metadata) = match req.inputs {
783785
Input::Single(input) => {
784786
metrics::increment_counter!("te_request_count", "method" => "single");
@@ -787,7 +789,7 @@ async fn embed_all(
787789

788790
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
789791
let response = infer
790-
.embed_all(input, req.truncate.unwrap_or(info.auto_truncate), permit)
792+
.embed_all(input, truncate, permit)
791793
.await
792794
.map_err(ErrorResponse::from)?;
793795

@@ -840,7 +842,6 @@ async fn embed_all(
840842
for input in inputs {
841843
compute_chars += input.count_chars();
842844

843-
let truncate = req.truncate.unwrap_or(info.auto_truncate);
844845
let local_infer = infer.clone();
845846
futures.push(async move {
846847
let permit = local_infer.acquire_permit().await;
@@ -925,6 +926,8 @@ async fn openai_embed(
925926
let span = tracing::Span::current();
926927
let start_time = Instant::now();
927928

929+
let truncate = info.auto_truncate;
930+
928931
let (embeddings, metadata) = match req.input {
929932
Input::Single(input) => {
930933
metrics::increment_counter!("te_request_count", "method" => "single");
@@ -933,7 +936,7 @@ async fn openai_embed(
933936

934937
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
935938
let response = infer
936-
.embed_pooled(input, false, true, permit)
939+
.embed_pooled(input, truncate, true, permit)
937940
.await
938941
.map_err(ErrorResponse::from)?;
939942

@@ -993,7 +996,9 @@ async fn openai_embed(
993996
let local_infer = infer.clone();
994997
futures.push(async move {
995998
let permit = local_infer.acquire_permit().await;
996-
local_infer.embed_pooled(input, false, true, permit).await
999+
local_infer
1000+
.embed_pooled(input, truncate, true, permit)
1001+
.await
9971002
})
9981003
}
9991004
let results = join_all(futures)

0 commit comments

Comments
 (0)