Skip to content

Commit cb1e594

Browse files
feat: auto limit string if truncate is set (#428)
1 parent 750898d commit cb1e594

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

core/src/tokenization.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ fn prepare_pre_prompt(
274274

275275
#[allow(clippy::too_many_arguments)]
276276
fn tokenize_input(
277-
inputs: EncodingInput,
277+
mut inputs: EncodingInput,
278278
add_special_tokens: bool,
279279
max_input_length: usize,
280280
truncate_params: Option<TruncationParams>,
@@ -288,9 +288,12 @@ fn tokenize_input(
288288
let input_chars = inputs.count_chars();
289289
let limit = max_input_length * MAX_CHAR_MULTIPLIER;
290290
if input_chars > limit {
291-
return Err(TextEmbeddingsError::Validation(format!(
292-
"`inputs` must have less than {limit} characters. Given: {input_chars}"
293-
)));
291+
if truncate_params.is_none() {
292+
return Err(TextEmbeddingsError::Validation(format!(
293+
"`inputs` must have less than {limit} characters. Given: {input_chars}"
294+
)));
295+
}
296+
inputs.apply_limit(limit);
294297
}
295298

296299
let encoding = match inputs {
@@ -426,6 +429,25 @@ impl EncodingInput {
426429
EncodingInput::Ids(v) => v.len(),
427430
}
428431
}
432+
433+
fn apply_limit(&mut self, limit: usize) {
434+
let truncate_string = |s: &mut String, limit: usize| {
435+
if s.is_char_boundary(limit) {
436+
s.truncate(limit)
437+
}
438+
};
439+
440+
match self {
441+
EncodingInput::Single(s) => {
442+
truncate_string(s, limit);
443+
}
444+
EncodingInput::Dual(s1, s2) => {
445+
truncate_string(s1, limit / 2);
446+
truncate_string(s2, limit / 2);
447+
}
448+
EncodingInput::Ids(_) => {}
449+
}
450+
}
429451
}
430452

431453
impl From<String> for EncodingInput {

0 commit comments

Comments
 (0)