@@ -7,6 +7,8 @@ use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
7
7
use tokio:: sync:: oneshot;
8
8
use tracing:: { instrument, Span } ;
9
9
10
+ static MAX_CHAR_MULTIPLIER : usize = 250 ;
11
+
10
12
/// Validation
11
13
#[ derive( Debug , Clone ) ]
12
14
pub struct Tokenization {
@@ -215,6 +217,7 @@ fn tokenizer_worker(
215
217
let _ = response_tx. send ( tokenize_input (
216
218
inputs,
217
219
add_special_tokens,
220
+ max_input_length,
218
221
None ,
219
222
default_prompt_clone,
220
223
prompt_name,
@@ -269,9 +272,11 @@ fn prepare_pre_prompt(
269
272
Ok ( pre_prompt)
270
273
}
271
274
275
+ #[ allow( clippy:: too_many_arguments) ]
272
276
fn tokenize_input (
273
277
inputs : EncodingInput ,
274
278
add_special_tokens : bool ,
279
+ max_input_length : usize ,
275
280
truncate_params : Option < TruncationParams > ,
276
281
default_prompt : Option < String > ,
277
282
prompt_name : Option < String > ,
@@ -280,6 +285,14 @@ fn tokenize_input(
280
285
) -> Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > {
281
286
let pre_prompt = prepare_pre_prompt ( default_prompt, prompt_name, prompts) ?;
282
287
288
+ let input_chars = inputs. count_chars ( ) ;
289
+ let limit = max_input_length * MAX_CHAR_MULTIPLIER ;
290
+ if input_chars > limit {
291
+ return Err ( TextEmbeddingsError :: Validation ( format ! (
292
+ "`inputs` must have less than {limit} characters. Given: {input_chars}"
293
+ ) ) ) ;
294
+ }
295
+
283
296
let encoding = match inputs {
284
297
// encode input
285
298
EncodingInput :: Single ( s) => {
@@ -359,6 +372,7 @@ fn encode_input(
359
372
let ( _, encoding) = tokenize_input (
360
373
inputs,
361
374
true ,
375
+ max_input_length,
362
376
truncate_params,
363
377
default_prompt,
364
378
prompt_name,
@@ -404,6 +418,14 @@ impl EncodingInput {
404
418
EncodingInput :: Ids ( v) => v. is_empty ( ) ,
405
419
}
406
420
}
421
+
422
+ fn count_chars ( & self ) -> usize {
423
+ match self {
424
+ EncodingInput :: Single ( s) => s. chars ( ) . count ( ) ,
425
+ EncodingInput :: Dual ( s1, s2) => s1. chars ( ) . count ( ) + s2. chars ( ) . count ( ) ,
426
+ EncodingInput :: Ids ( v) => v. len ( ) ,
427
+ }
428
+ }
407
429
}
408
430
409
431
impl From < String > for EncodingInput {
0 commit comments