Skip to content

Commit acbbb92

Browse files
authored
tokenizer max limit on input size (#324)
1 parent a0549e6 commit acbbb92

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

core/src/tokenization.rs

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@ use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
77
use tokio::sync::oneshot;
88
use tracing::{instrument, Span};
99

10+
/// Multiplier applied to `max_input_length` in `tokenize_input` to derive the
/// maximum size of `inputs` (as reported by `count_chars`) that is accepted
/// before a validation error is returned.
static MAX_CHAR_MULTIPLIER: usize = 250;
11+
1012
/// Validation
1113
#[derive(Debug, Clone)]
1214
pub struct Tokenization {
@@ -215,6 +217,7 @@ fn tokenizer_worker(
215217
let _ = response_tx.send(tokenize_input(
216218
inputs,
217219
add_special_tokens,
220+
max_input_length,
218221
None,
219222
default_prompt_clone,
220223
prompt_name,
@@ -269,9 +272,11 @@ fn prepare_pre_prompt(
269272
Ok(pre_prompt)
270273
}
271274

275+
#[allow(clippy::too_many_arguments)]
272276
fn tokenize_input(
273277
inputs: EncodingInput,
274278
add_special_tokens: bool,
279+
max_input_length: usize,
275280
truncate_params: Option<TruncationParams>,
276281
default_prompt: Option<String>,
277282
prompt_name: Option<String>,
@@ -280,6 +285,14 @@ fn tokenize_input(
280285
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
281286
let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;
282287

288+
let input_chars = inputs.count_chars();
289+
let limit = max_input_length * MAX_CHAR_MULTIPLIER;
290+
if input_chars > limit {
291+
return Err(TextEmbeddingsError::Validation(format!(
292+
"`inputs` must have less than {limit} characters. Given: {input_chars}"
293+
)));
294+
}
295+
283296
let encoding = match inputs {
284297
// encode input
285298
EncodingInput::Single(s) => {
@@ -359,6 +372,7 @@ fn encode_input(
359372
let (_, encoding) = tokenize_input(
360373
inputs,
361374
true,
375+
max_input_length,
362376
truncate_params,
363377
default_prompt,
364378
prompt_name,
@@ -404,6 +418,14 @@ impl EncodingInput {
404418
EncodingInput::Ids(v) => v.is_empty(),
405419
}
406420
}
421+
422+
/// Returns the size of this input as used by the input-size validation.
///
/// `Single` inputs are measured with `chars().count()`; `Dual` inputs are
/// the sum of that count over both parts; `Ids` inputs are measured by the
/// number of elements in the vector (not by characters).
fn count_chars(&self) -> usize {
423+
match self {
424+
EncodingInput::Single(s) => s.chars().count(),
425+
EncodingInput::Dual(s1, s2) => s1.chars().count() + s2.chars().count(),
426+
EncodingInput::Ids(v) => v.len(),
427+
}
428+
}
407429
}
408430

409431
impl From<String> for EncodingInput {

0 commit comments

Comments
 (0)