Commit 36161c1

feat: accept batches in predict (#78)
1 parent b41601c commit 36161c1

File tree

4 files changed: +385 -82 lines changed


core/src/infer.rs

Lines changed: 26 additions & 1 deletion

@@ -180,6 +180,7 @@ impl Infer {
         &self,
         inputs: I,
         truncate: bool,
+        raw_scores: bool,
         _permit: OwnedSemaphorePermit,
     ) -> Result<InferResponse, TextEmbeddingsError> {
         if !self.is_classifier() {
@@ -222,7 +223,7 @@ impl Infer {
 
         self.notify_batching_task.notify_one();
 
-        let response = response_rx
+        let mut response = response_rx
             .await
             .expect(
             "Infer batching task dropped the sender without sending a response. This is a bug.",
@@ -233,6 +234,30 @@
                 err
             })?;
 
+        if !raw_scores {
+            // Softmax
+            if response.results.len() > 1 {
+                let max = *response
+                    .results
+                    .iter()
+                    .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap())
+                    .unwrap();
+
+                let mut den = 0.0;
+                for v in response.results.iter_mut() {
+                    *v = (*v - max).exp();
+                    den += *v;
+                }
+                for v in response.results.iter_mut() {
+                    *v /= den;
+                }
+            }
+            // Sigmoid
+            else {
+                response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp());
+            }
+        }
+
         // Timings
         let total_time = start_time.elapsed();

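For reference, the post-processing added above is a numerically stabilized softmax when the classifier returns more than one score, and a sigmoid when it returns exactly one; with raw_scores set to true the whole branch is skipped and the raw logits are returned unchanged. A minimal standalone sketch of the same normalization, assuming the scores are available as a plain Vec<f32> (the InferResponse plumbing is omitted, and a plain maximum is subtracted here; softmax is shift-invariant, so the subtracted constant only affects numerical stability):

// Sketch only: mirrors the softmax/sigmoid branch added above, outside the router.
fn normalize_scores(mut results: Vec<f32>) -> Vec<f32> {
    if results.len() > 1 {
        // Softmax: shift by the largest logit for stability, exponentiate, renormalize.
        let max = results.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut den = 0.0;
        for v in results.iter_mut() {
            *v = (*v - max).exp();
            den += *v;
        }
        for v in results.iter_mut() {
            *v /= den;
        }
    } else {
        // A single score goes through a sigmoid instead.
        results[0] = 1.0 / (1.0 + (-results[0]).exp());
    }
    results
}

fn main() {
    // Two logits collapse to probabilities summing to 1; a lone logit becomes a sigmoid score.
    println!("{:?}", normalize_scores(vec![2.0, 0.5]));
    println!("{:?}", normalize_scores(vec![1.2]));
}
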
docs/openapi.json

Lines changed: 63 additions & 21 deletions

@@ -695,14 +695,59 @@
           }
         }
       },
+      "PredictInput": {
+        "oneOf": [
+          {
+            "type": "string",
+            "description": "A single string"
+          },
+          {
+            "type": "array",
+            "items": {
+              "type": "string"
+            },
+            "description": "A pair of strings",
+            "maxItems": 2,
+            "minItems": 2
+          },
+          {
+            "type": "array",
+            "items": {
+              "oneOf": [
+                {
+                  "type": "array",
+                  "items": {
+                    "type": "string"
+                  },
+                  "description": "A single string",
+                  "maxItems": 1,
+                  "minItems": 1
+                },
+                {
+                  "type": "array",
+                  "items": {
+                    "type": "string"
+                  },
+                  "description": "A pair of strings",
+                  "maxItems": 2,
+                  "minItems": 2
+                }
+              ]
+            },
+            "description": "A batch"
+          }
+        ],
+        "description": "Model input. Can be either a single string, a pair of strings or a batch of mixed single and pairs of strings.",
+        "example": "What is Deep Learning?"
+      },
       "PredictRequest": {
         "type": "object",
         "required": [
           "inputs"
         ],
         "properties": {
           "inputs": {
-            "$ref": "#/components/schemas/Sequence"
+            "$ref": "#/components/schemas/PredictInput"
           },
           "raw_scores": {
             "type": "boolean",
@@ -717,10 +762,23 @@
         }
       },
       "PredictResponse": {
-        "type": "array",
-        "items": {
-          "$ref": "#/components/schemas/Prediction"
-        }
+        "oneOf": [
+          {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/Prediction"
+            }
+          },
+          {
+            "type": "array",
+            "items": {
+              "type": "array",
+              "items": {
+                "$ref": "#/components/schemas/Prediction"
+              }
+            }
+          }
+        ]
       },
       "Prediction": {
         "type": "object",
@@ -739,22 +797,6 @@
             "example": "0.5"
           }
         }
-      },
-      "Sequence": {
-        "oneOf": [
-          {
-            "type": "string"
-          },
-          {
-            "type": "array",
-            "items": {
-              "type": "string"
-            },
-            "description": "",
-            "maxItems": 2,
-            "minItems": 2
-          }
-        ]
       }
     }
   },

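Taken together with the PredictRequest and PredictResponse changes, the new PredictInput schema means the predict endpoint accepts three request shapes. A small illustrative sketch of the request bodies, built with serde_json (the field names come from the schemas above; the string values are arbitrary placeholders, not real model input or output):

// Sketch only: the three input shapes the PredictInput schema admits.
use serde_json::json;

fn main() {
    // 1. A single string -> the response is one array of Prediction objects.
    let single = json!({ "inputs": "What is Deep Learning?" });

    // 2. A pair of strings -> still one array of Prediction objects.
    let pair = json!({ "inputs": ["What is Deep Learning?", "A subfield of machine learning."] });

    // 3. A batch mixing single strings and pairs -> an array of arrays of
    //    Prediction objects, one inner array per batch entry.
    let batch = json!({
        "inputs": [
            ["What is Deep Learning?"],
            ["What is Deep Learning?", "A subfield of machine learning."]
        ],
        "raw_scores": false
    });

    for body in [single, pair, batch] {
        println!("{body}");
    }
}
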
router/src/lib.rs

Lines changed: 174 additions & 5 deletions

@@ -1,9 +1,13 @@
 /// Text Embedding Inference Webserver
 pub mod server;
 
-use serde::{Deserialize, Serialize};
+use serde::de::{SeqAccess, Visitor};
+use serde::{de, Deserialize, Deserializer, Serialize};
+use serde_json::json;
 use std::collections::HashMap;
+use std::fmt::Formatter;
 use text_embeddings_core::tokenization::EncodingInput;
+use utoipa::openapi::{RefOr, Schema};
 use utoipa::ToSchema;
 
 #[derive(Clone, Debug, Serialize, ToSchema)]
@@ -59,8 +63,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }
 
-#[derive(Deserialize, ToSchema, Debug)]
-#[serde(untagged)]
+#[derive(Debug)]
 pub(crate) enum Sequence {
     Single(String),
     Pair(String, String),
@@ -84,9 +87,171 @@ impl From<Sequence> for EncodingInput {
     }
 }
 
+#[derive(Debug)]
+pub(crate) enum PredictInput {
+    Single(Sequence),
+    Batch(Vec<Sequence>),
+}
+
+impl<'de> Deserialize<'de> for PredictInput {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        enum Internal {
+            Single(String),
+            Multiple(Vec<String>),
+        }
+
+        struct PredictInputVisitor;
+
+        impl<'de> Visitor<'de> for PredictInputVisitor {
+            type Value = PredictInput;
+
+            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
+                formatter.write_str(
+                    "a string, \
+                    a pair of strings [string, string] \
+                    or a batch of mixed strings and pairs [[string], [string, string], ...]",
+                )
+            }
+
+            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                Ok(PredictInput::Single(Sequence::Single(v.to_string())))
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: SeqAccess<'de>,
+            {
+                let sequence_from_vec = |mut value: Vec<String>| {
+                    // Validate that value is correct
+                    match value.len() {
+                        1 => Ok(Sequence::Single(value.pop().unwrap())),
+                        2 => {
+                            // Second element is last
+                            let second = value.pop().unwrap();
+                            let first = value.pop().unwrap();
+                            Ok(Sequence::Pair(first, second))
+                        }
+                        // Sequence can only be a single string or a pair of strings
+                        _ => Err(de::Error::invalid_length(value.len(), &self)),
+                    }
+                };
+
+                // Get first element
+                // This will determine if input is a batch or not
+                let s = match seq
+                    .next_element::<Internal>()?
+                    .ok_or_else(|| de::Error::invalid_length(0, &self))?
+                {
+                    // Input is not a batch
+                    // Return early
+                    Internal::Single(value) => {
+                        // Option get second element
+                        let second = seq.next_element()?;
+
+                        if seq.next_element::<String>()?.is_some() {
+                            // Error as we do not accept > 2 elements
+                            return Err(de::Error::invalid_length(3, &self));
+                        }
+
+                        if let Some(second) = second {
+                            // Second element exists
+                            // This is a pair
+                            return Ok(PredictInput::Single(Sequence::Pair(value, second)));
+                        } else {
+                            // Second element does not exist
+                            return Ok(PredictInput::Single(Sequence::Single(value)));
+                        }
+                    }
+                    // Input is a batch
+                    Internal::Multiple(value) => sequence_from_vec(value),
+                }?;
+
+                let mut batch = Vec::with_capacity(32);
+                // Push first sequence
+                batch.push(s);
+
+                // Iterate on all sequences
+                while let Some(value) = seq.next_element::<Vec<String>>()? {
+                    // Validate sequence
+                    let s = sequence_from_vec(value)?;
+                    // Push to batch
+                    batch.push(s);
+                }
+                Ok(PredictInput::Batch(batch))
+            }
+        }
+
+        deserializer.deserialize_any(PredictInputVisitor)
+    }
+}
+
+impl<'__s> ToSchema<'__s> for PredictInput {
+    fn schema() -> (&'__s str, RefOr<Schema>) {
+        (
+            "PredictInput",
+            utoipa::openapi::OneOfBuilder::new()
+                .item(
+                    utoipa::openapi::ObjectBuilder::new()
+                        .schema_type(utoipa::openapi::SchemaType::String)
+                        .description(Some("A single string")),
+                )
+                .item(
+                    utoipa::openapi::ArrayBuilder::new()
+                        .items(
+                            utoipa::openapi::ObjectBuilder::new()
+                                .schema_type(utoipa::openapi::SchemaType::String),
+                        )
+                        .description(Some("A pair of strings"))
+                        .min_items(Some(2))
+                        .max_items(Some(2)),
+                )
+                .item(
+                    utoipa::openapi::ArrayBuilder::new().items(
+                        utoipa::openapi::OneOfBuilder::new()
+                            .item(
+                                utoipa::openapi::ArrayBuilder::new()
+                                    .items(
+                                        utoipa::openapi::ObjectBuilder::new()
+                                            .schema_type(utoipa::openapi::SchemaType::String),
+                                    )
+                                    .description(Some("A single string"))
+                                    .min_items(Some(1))
+                                    .max_items(Some(1)),
+                            )
+                            .item(
+                                utoipa::openapi::ArrayBuilder::new()
+                                    .items(
+                                        utoipa::openapi::ObjectBuilder::new()
+                                            .schema_type(utoipa::openapi::SchemaType::String),
+                                    )
+                                    .description(Some("A pair of strings"))
+                                    .min_items(Some(2))
+                                    .max_items(Some(2)),
+                            )
+                    ).description(Some("A batch")),
+                )
+                .description(Some(
+                    "Model input. \
+                    Can be either a single string, a pair of strings or a batch of mixed single and pairs \
+                    of strings.",
+                ))
+                .example(Some(json!("What is Deep Learning?")))
+                .into(),
+        )
+    }
+}
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct PredictRequest {
-    pub inputs: Sequence,
+    pub inputs: PredictInput,
     #[serde(default)]
     #[schema(default = "false", example = "false")]
     pub truncate: bool,
@@ -104,7 +269,11 @@ pub(crate) struct Prediction {
 }
 
 #[derive(Serialize, ToSchema)]
-pub(crate) struct PredictResponse(Vec<Prediction>);
+#[serde(untagged)]
+pub(crate) enum PredictResponse {
+    Single(Vec<Prediction>),
+    Batch(Vec<Vec<Prediction>>),
+}
 
 #[derive(Deserialize, ToSchema)]
 #[serde(untagged)]

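The hand-written Deserialize impl above is what lets one untagged input type distinguish a single string, a pair, and a batch at the route boundary. A reduced, self-contained sketch of the same visitor idea, using local stand-ins for Sequence and PredictInput rather than the router's types, and without the flat [string, string] special case the real visitor handles:

// Sketch only: a simplified visitor that accepts a bare string or a batch of
// 1- or 2-element string arrays, and rejects anything longer.
use serde::de::{self, Deserializer, SeqAccess, Visitor};
use serde::Deserialize;
use std::fmt;

#[derive(Debug)]
enum Sequence {
    Single(String),
    Pair(String, String),
}

#[derive(Debug)]
enum PredictInput {
    Single(Sequence),
    Batch(Vec<Sequence>),
}

impl<'de> Deserialize<'de> for PredictInput {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        struct V;
        impl<'de> Visitor<'de> for V {
            type Value = PredictInput;
            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a string or [[string], [string, string], ...]")
            }
            // A bare JSON string is a single sequence.
            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
                Ok(PredictInput::Single(Sequence::Single(v.to_string())))
            }
            // Any JSON array is treated as a batch of 1- or 2-element string arrays.
            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
                let mut batch = Vec::new();
                while let Some(mut v) = seq.next_element::<Vec<String>>()? {
                    let s = match v.len() {
                        1 => Sequence::Single(v.pop().unwrap()),
                        2 => {
                            let second = v.pop().unwrap();
                            Sequence::Pair(v.pop().unwrap(), second)
                        }
                        n => return Err(de::Error::invalid_length(n, &self)),
                    };
                    batch.push(s);
                }
                Ok(PredictInput::Batch(batch))
            }
        }
        d.deserialize_any(V)
    }
}

fn main() {
    // Accepted: a single string, and a batch mixing singles and pairs.
    println!("{:?}", serde_json::from_str::<PredictInput>(r#""hello""#).unwrap());
    println!("{:?}", serde_json::from_str::<PredictInput>(r#"[["a"], ["b", "c"]]"#).unwrap());
    // Rejected: a 3-element inner array is neither a single nor a pair.
    assert!(serde_json::from_str::<PredictInput>(r#"[["a", "b", "c"]]"#).is_err());
}

The router's actual visitor additionally peeks at the first element so that a flat pair like ["a", "b"] is parsed as one pair rather than as a batch, as the diff above shows.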