1
1
/// Text Embedding Inference Webserver
2
2
pub mod server;
3
3
4
- use serde:: { Deserialize , Serialize } ;
4
+ use serde:: de:: { SeqAccess , Visitor } ;
5
+ use serde:: { de, Deserialize , Deserializer , Serialize } ;
6
+ use serde_json:: json;
5
7
use std:: collections:: HashMap ;
8
+ use std:: fmt:: Formatter ;
6
9
use text_embeddings_core:: tokenization:: EncodingInput ;
10
+ use utoipa:: openapi:: { RefOr , Schema } ;
7
11
use utoipa:: ToSchema ;
8
12
9
13
#[ derive( Clone , Debug , Serialize , ToSchema ) ]
@@ -59,8 +63,7 @@ pub struct Info {
59
63
pub docker_label : Option < & ' static str > ,
60
64
}
61
65
62
- #[ derive( Deserialize , ToSchema , Debug ) ]
63
- #[ serde( untagged) ]
66
+ #[ derive( Debug ) ]
64
67
pub ( crate ) enum Sequence {
65
68
Single ( String ) ,
66
69
Pair ( String , String ) ,
@@ -84,9 +87,171 @@ impl From<Sequence> for EncodingInput {
84
87
}
85
88
}
86
89
90
+ #[ derive( Debug ) ]
91
+ pub ( crate ) enum PredictInput {
92
+ Single ( Sequence ) ,
93
+ Batch ( Vec < Sequence > ) ,
94
+ }
95
+
96
+ impl < ' de > Deserialize < ' de > for PredictInput {
97
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
98
+ where
99
+ D : Deserializer < ' de > ,
100
+ {
101
+ #[ derive( Deserialize ) ]
102
+ #[ serde( untagged) ]
103
+ enum Internal {
104
+ Single ( String ) ,
105
+ Multiple ( Vec < String > ) ,
106
+ }
107
+
108
+ struct PredictInputVisitor ;
109
+
110
+ impl < ' de > Visitor < ' de > for PredictInputVisitor {
111
+ type Value = PredictInput ;
112
+
113
+ fn expecting ( & self , formatter : & mut Formatter ) -> std:: fmt:: Result {
114
+ formatter. write_str (
115
+ "a string, \
116
+ a pair of strings [string, string] \
117
+ or a batch of mixed strings and pairs [[string], [string, string], ...]",
118
+ )
119
+ }
120
+
121
+ fn visit_str < E > ( self , v : & str ) -> Result < Self :: Value , E >
122
+ where
123
+ E : de:: Error ,
124
+ {
125
+ Ok ( PredictInput :: Single ( Sequence :: Single ( v. to_string ( ) ) ) )
126
+ }
127
+
128
+ fn visit_seq < A > ( self , mut seq : A ) -> Result < Self :: Value , A :: Error >
129
+ where
130
+ A : SeqAccess < ' de > ,
131
+ {
132
+ let sequence_from_vec = |mut value : Vec < String > | {
133
+ // Validate that value is correct
134
+ match value. len ( ) {
135
+ 1 => Ok ( Sequence :: Single ( value. pop ( ) . unwrap ( ) ) ) ,
136
+ 2 => {
137
+ // Second element is last
138
+ let second = value. pop ( ) . unwrap ( ) ;
139
+ let first = value. pop ( ) . unwrap ( ) ;
140
+ Ok ( Sequence :: Pair ( first, second) )
141
+ }
142
+ // Sequence can only be a single string or a pair of strings
143
+ _ => Err ( de:: Error :: invalid_length ( value. len ( ) , & self ) ) ,
144
+ }
145
+ } ;
146
+
147
+ // Get first element
148
+ // This will determine if input is a batch or not
149
+ let s = match seq
150
+ . next_element :: < Internal > ( ) ?
151
+ . ok_or_else ( || de:: Error :: invalid_length ( 0 , & self ) ) ?
152
+ {
153
+ // Input is not a batch
154
+ // Return early
155
+ Internal :: Single ( value) => {
156
+ // Option get second element
157
+ let second = seq. next_element ( ) ?;
158
+
159
+ if seq. next_element :: < String > ( ) ?. is_some ( ) {
160
+ // Error as we do not accept > 2 elements
161
+ return Err ( de:: Error :: invalid_length ( 3 , & self ) ) ;
162
+ }
163
+
164
+ if let Some ( second) = second {
165
+ // Second element exists
166
+ // This is a pair
167
+ return Ok ( PredictInput :: Single ( Sequence :: Pair ( value, second) ) ) ;
168
+ } else {
169
+ // Second element does not exist
170
+ return Ok ( PredictInput :: Single ( Sequence :: Single ( value) ) ) ;
171
+ }
172
+ }
173
+ // Input is a batch
174
+ Internal :: Multiple ( value) => sequence_from_vec ( value) ,
175
+ } ?;
176
+
177
+ let mut batch = Vec :: with_capacity ( 32 ) ;
178
+ // Push first sequence
179
+ batch. push ( s) ;
180
+
181
+ // Iterate on all sequences
182
+ while let Some ( value) = seq. next_element :: < Vec < String > > ( ) ? {
183
+ // Validate sequence
184
+ let s = sequence_from_vec ( value) ?;
185
+ // Push to batch
186
+ batch. push ( s) ;
187
+ }
188
+ Ok ( PredictInput :: Batch ( batch) )
189
+ }
190
+ }
191
+
192
+ deserializer. deserialize_any ( PredictInputVisitor )
193
+ }
194
+ }
195
+
196
+ impl < ' __s > ToSchema < ' __s > for PredictInput {
197
+ fn schema ( ) -> ( & ' __s str , RefOr < Schema > ) {
198
+ (
199
+ "PredictInput" ,
200
+ utoipa:: openapi:: OneOfBuilder :: new ( )
201
+ . item (
202
+ utoipa:: openapi:: ObjectBuilder :: new ( )
203
+ . schema_type ( utoipa:: openapi:: SchemaType :: String )
204
+ . description ( Some ( "A single string" ) ) ,
205
+ )
206
+ . item (
207
+ utoipa:: openapi:: ArrayBuilder :: new ( )
208
+ . items (
209
+ utoipa:: openapi:: ObjectBuilder :: new ( )
210
+ . schema_type ( utoipa:: openapi:: SchemaType :: String ) ,
211
+ )
212
+ . description ( Some ( "A pair of strings" ) )
213
+ . min_items ( Some ( 2 ) )
214
+ . max_items ( Some ( 2 ) ) ,
215
+ )
216
+ . item (
217
+ utoipa:: openapi:: ArrayBuilder :: new ( ) . items (
218
+ utoipa:: openapi:: OneOfBuilder :: new ( )
219
+ . item (
220
+ utoipa:: openapi:: ArrayBuilder :: new ( )
221
+ . items (
222
+ utoipa:: openapi:: ObjectBuilder :: new ( )
223
+ . schema_type ( utoipa:: openapi:: SchemaType :: String ) ,
224
+ )
225
+ . description ( Some ( "A single string" ) )
226
+ . min_items ( Some ( 1 ) )
227
+ . max_items ( Some ( 1 ) ) ,
228
+ )
229
+ . item (
230
+ utoipa:: openapi:: ArrayBuilder :: new ( )
231
+ . items (
232
+ utoipa:: openapi:: ObjectBuilder :: new ( )
233
+ . schema_type ( utoipa:: openapi:: SchemaType :: String ) ,
234
+ )
235
+ . description ( Some ( "A pair of strings" ) )
236
+ . min_items ( Some ( 2 ) )
237
+ . max_items ( Some ( 2 ) ) ,
238
+ )
239
+ ) . description ( Some ( "A batch" ) ) ,
240
+ )
241
+ . description ( Some (
242
+ "Model input. \
243
+ Can be either a single string, a pair of strings or a batch of mixed single and pairs \
244
+ of strings.",
245
+ ) )
246
+ . example ( Some ( json ! ( "What is Deep Learning?" ) ) )
247
+ . into ( ) ,
248
+ )
249
+ }
250
+ }
251
+
87
252
#[ derive( Deserialize , ToSchema ) ]
88
253
pub ( crate ) struct PredictRequest {
89
- pub inputs : Sequence ,
254
+ pub inputs : PredictInput ,
90
255
#[ serde( default ) ]
91
256
#[ schema( default = "false" , example = "false" ) ]
92
257
pub truncate : bool ,
@@ -104,7 +269,11 @@ pub(crate) struct Prediction {
104
269
}
105
270
106
271
#[ derive( Serialize , ToSchema ) ]
107
- pub ( crate ) struct PredictResponse ( Vec < Prediction > ) ;
272
+ #[ serde( untagged) ]
273
+ pub ( crate ) enum PredictResponse {
274
+ Single ( Vec < Prediction > ) ,
275
+ Batch ( Vec < Vec < Prediction > > ) ,
276
+ }
108
277
109
278
#[ derive( Deserialize , ToSchema ) ]
110
279
#[ serde( untagged) ]
0 commit comments