Skip to content

Commit a1dd76d

Browse files
Input Types Compatibility with OpenAI's API (#112) (#214)
Co-authored-by: Numan Laanait <nlaanait@gmail.com>
1 parent 90ea664 commit a1dd76d

File tree

4 files changed

+97
-29
lines changed

4 files changed

+97
-29
lines changed

core/src/tokenization.rs

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
use crate::TextEmbeddingsError;
33
use tokenizers::tokenizer::Tokenizer;
44
pub use tokenizers::Encoding as RawEncoding;
5-
use tokenizers::{EncodeInput, TruncationDirection, TruncationParams, TruncationStrategy};
5+
use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
66
use tokio::sync::{mpsc, oneshot};
77
use tracing::{instrument, Span};
88

@@ -222,14 +222,25 @@ fn tokenize_input(
222222
truncate_params: Option<TruncationParams>,
223223
tokenizer: &mut Tokenizer,
224224
) -> Result<RawEncoding, TextEmbeddingsError> {
225-
let inputs: EncodeInput = match inputs {
226-
EncodingInput::Single(s) => s.into(),
227-
EncodingInput::Dual(s1, s2) => (s1, s2).into(),
225+
let encoding = match inputs {
226+
// encode input
227+
EncodingInput::Single(s) => tokenizer
228+
.with_truncation(truncate_params)?
229+
.encode::<String>(s, add_special_tokens)?,
230+
EncodingInput::Dual(s1, s2) => {
231+
tokenizer
232+
.with_truncation(truncate_params)?
233+
.encode::<(String, String)>((s1, s2), add_special_tokens)?
234+
}
235+
// input is encoded -> convert to tokenizers Encoding
236+
EncodingInput::Ids(ids) => {
237+
let text = tokenizer.decode(&ids, false)?;
238+
tokenizer
239+
.with_truncation(truncate_params)?
240+
.encode::<String>(text, false)?
241+
}
228242
};
229-
230-
Ok(tokenizer
231-
.with_truncation(truncate_params)?
232-
.encode(inputs, add_special_tokens)?)
243+
Ok(encoding)
233244
}
234245

235246
/// Get input length and optionally truncate it
@@ -256,9 +267,7 @@ fn encode_input(
256267
"`inputs` must have less than {max_input_length} tokens. Given: {seq_len}"
257268
)));
258269
}
259-
260270
metrics::histogram!("te_request_input_length", seq_len as f64);
261-
262271
Ok(ValidEncoding {
263272
input_ids: encoding.get_ids().to_vec(),
264273
token_type_ids: encoding.get_type_ids().to_vec(),
@@ -278,13 +287,15 @@ pub struct ValidEncoding {
278287
pub enum EncodingInput {
279288
Single(String),
280289
Dual(String, String),
290+
Ids(Vec<u32>),
281291
}
282292

283293
impl EncodingInput {
284294
fn is_empty(&self) -> bool {
285295
match self {
286296
EncodingInput::Single(s) => s.is_empty(),
287297
EncodingInput::Dual(s1, s2) => s1.is_empty() && s2.is_empty(),
298+
EncodingInput::Ids(v) => v.is_empty(),
288299
}
289300
}
290301
}

router/src/http/server.rs

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
/// HTTP Server logic
22
use crate::http::types::{
33
DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
4-
EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, OpenAICompatEmbedding,
4+
EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, InputType, OpenAICompatEmbedding,
55
OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
66
PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
7-
Sequence, SimpleToken, SparseValue, TokenizeRequest, TokenizeResponse, VertexPrediction,
8-
VertexRequest, VertexResponse,
7+
Sequence, SimpleToken, SparseValue, TokenizeInput, TokenizeRequest, TokenizeResponse,
8+
VertexPrediction, VertexRequest, VertexResponse,
99
};
1010
use crate::{
1111
shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -474,7 +474,7 @@ async fn embed(
474474
Input::Single(input) => {
475475
metrics::increment_counter!("te_request_count", "method" => "single");
476476

477-
let compute_chars = input.chars().count();
477+
let compute_chars = input.count_chars();
478478

479479
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
480480
let response = infer
@@ -529,7 +529,7 @@ async fn embed(
529529
let mut compute_chars = 0;
530530

531531
for input in inputs {
532-
compute_chars += input.chars().count();
532+
compute_chars += input.count_chars();
533533

534534
let local_infer = infer.clone();
535535
futures.push(async move {
@@ -630,7 +630,7 @@ async fn embed_sparse(
630630
Input::Single(input) => {
631631
metrics::increment_counter!("te_request_count", "method" => "single");
632632

633-
let compute_chars = input.chars().count();
633+
let compute_chars = input.count_chars();
634634

635635
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
636636
let response = infer
@@ -685,7 +685,7 @@ async fn embed_sparse(
685685
let mut compute_chars = 0;
686686

687687
for input in inputs {
688-
compute_chars += input.chars().count();
688+
compute_chars += input.count_chars();
689689

690690
let local_infer = infer.clone();
691691
futures.push(async move {
@@ -778,7 +778,7 @@ async fn embed_all(
778778
Input::Single(input) => {
779779
metrics::increment_counter!("te_request_count", "method" => "single");
780780

781-
let compute_chars = input.chars().count();
781+
let compute_chars = input.count_chars();
782782

783783
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
784784
let response = infer
@@ -833,7 +833,7 @@ async fn embed_all(
833833
let mut compute_chars = 0;
834834

835835
for input in inputs {
836-
compute_chars += input.chars().count();
836+
compute_chars += input.count_chars();
837837

838838
let local_infer = infer.clone();
839839
futures.push(async move {
@@ -892,7 +892,7 @@ async fn embed_all(
892892
#[utoipa::path(
893893
post,
894894
tag = "Text Embeddings Inference",
895-
path = "/embeddings",
895+
path = "/v1/embeddings",
896896
request_body = OpenAICompatRequest,
897897
responses(
898898
(status = 200, description = "Embeddings", body = OpenAICompatResponse),
@@ -923,7 +923,7 @@ async fn openai_embed(
923923
Input::Single(input) => {
924924
metrics::increment_counter!("te_request_count", "method" => "single");
925925

926-
let compute_chars = input.chars().count();
926+
let compute_chars = input.count_chars();
927927

928928
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
929929
let response = infer
@@ -982,7 +982,7 @@ async fn openai_embed(
982982
let mut compute_chars = 0;
983983

984984
for input in inputs {
985-
compute_chars += input.chars().count();
985+
compute_chars += input.count_chars();
986986

987987
let local_infer = infer.clone();
988988
futures.push(async move {
@@ -1107,8 +1107,10 @@ async fn tokenize(
11071107
};
11081108

11091109
let tokens = match req.inputs {
1110-
Input::Single(input) => vec![tokenize_inner(input, req.add_special_tokens, infer.0).await?],
1111-
Input::Batch(inputs) => {
1110+
TokenizeInput::Single(input) => {
1111+
vec![tokenize_inner(input, req.add_special_tokens, infer.0).await?]
1112+
}
1113+
TokenizeInput::Batch(inputs) => {
11121114
if inputs.is_empty() {
11131115
let message = "`inputs` cannot be empty".to_string();
11141116
tracing::error!("{message}");
@@ -1369,9 +1371,11 @@ pub async fn run(
13691371
EmbedResponse,
13701372
ErrorResponse,
13711373
OpenAICompatErrorResponse,
1374+
TokenizeInput,
13721375
TokenizeRequest,
13731376
TokenizeResponse,
13741377
SimpleToken,
1378+
InputType,
13751379
InputIds,
13761380
DecodeRequest,
13771381
DecodeResponse,
@@ -1448,6 +1452,7 @@ pub async fn run(
14481452
.route("/decode", post(decode))
14491453
// OpenAI compat route
14501454
.route("/embeddings", post(openai_embed))
1455+
.route("/v1/embeddings", post(openai_embed))
14511456
// Vertex compat route
14521457
.route("/vertex", post(vertex_compatibility))
14531458
// Base Health route

router/src/http/types.rs

Lines changed: 32 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -250,11 +250,33 @@ pub(crate) struct Rank {
250250
#[derive(Serialize, ToSchema)]
251251
pub(crate) struct RerankResponse(pub Vec<Rank>);
252252

253+
#[derive(Deserialize, ToSchema, Debug)]
254+
#[serde(untagged)]
255+
pub(crate) enum InputType {
256+
String(String),
257+
Ids(Vec<u32>),
258+
}
259+
impl InputType {
260+
pub(crate) fn count_chars(&self) -> usize {
261+
match self {
262+
InputType::String(s) => s.chars().count(),
263+
InputType::Ids(v) => v.len(),
264+
}
265+
}
266+
}
267+
impl From<InputType> for EncodingInput {
268+
fn from(value: InputType) -> Self {
269+
match value {
270+
InputType::String(s) => Self::Single(s),
271+
InputType::Ids(v) => Self::Ids(v),
272+
}
273+
}
274+
}
253275
#[derive(Deserialize, ToSchema)]
254276
#[serde(untagged)]
255277
pub(crate) enum Input {
256-
Single(String),
257-
Batch(Vec<String>),
278+
Single(InputType),
279+
Batch(Vec<InputType>),
258280
}
259281

260282
#[derive(Deserialize, ToSchema)]
@@ -352,9 +374,16 @@ pub(crate) struct OpenAICompatErrorResponse {
352374
pub error_type: ErrorType,
353375
}
354376

377+
#[derive(Deserialize, ToSchema)]
378+
#[serde(untagged)]
379+
pub(crate) enum TokenizeInput {
380+
Single(String),
381+
Batch(Vec<String>),
382+
}
383+
355384
#[derive(Deserialize, ToSchema)]
356385
pub(crate) struct TokenizeRequest {
357-
pub inputs: Input,
386+
pub inputs: TokenizeInput,
358387
#[serde(default = "default_add_special_tokens")]
359388
#[schema(default = "true", example = "true")]
360389
pub add_special_tokens: bool,

router/tests/test_http_embed.rs

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ async fn test_embeddings() -> Result<()> {
1919
let request = json!({
2020
"inputs": "test"
2121
});
22-
2322
let client = reqwest::Client::new();
2423
let res = client
2524
.post("http://0.0.0.0:8090/embed")
@@ -31,6 +30,18 @@ async fn test_embeddings() -> Result<()> {
3130
let matcher = YamlMatcher::<Vec<Vec<Score>>>::new();
3231
insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher);
3332

33+
let test_tokens = vec![[101, 3231, 102]]; // tokenized "test"
34+
let request = json!({"inputs": &test_tokens});
35+
let res = client
36+
.post("http://0.0.0.0:8090/embed")
37+
.json(&request)
38+
.send()
39+
.await?;
40+
41+
let embeddings_single = res.json::<Vec<Vec<Score>>>().await?;
42+
let matcher = YamlMatcher::<Vec<Vec<Score>>>::new();
43+
insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher);
44+
3445
let request = json!({
3546
"inputs": vec!["test", "test", "test", "test", "test"],
3647
});
@@ -41,10 +52,22 @@ async fn test_embeddings() -> Result<()> {
4152
.json(&request)
4253
.send()
4354
.await?;
44-
4555
let embeddings_batch = res.json::<Vec<Vec<Score>>>().await?;
4656
insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher);
57+
for embeddings in &embeddings_batch {
58+
assert_eq!(embeddings, &embeddings_single[0]);
59+
}
4760

61+
let request =
62+
json!({"inputs": &test_tokens.repeat(request["inputs"].as_array().unwrap().len())});
63+
let res = client
64+
.post("http://0.0.0.0:8090/embed")
65+
.json(&request)
66+
.send()
67+
.await?;
68+
69+
let embeddings_batch = res.json::<Vec<Vec<Score>>>().await?;
70+
insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher);
4871
for embeddings in &embeddings_batch {
4972
assert_eq!(embeddings, &embeddings_single[0]);
5073
}

0 commit comments

Comments
 (0)