Commit b692e9c

[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config (#19660)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
1 parent 367871a commit b692e9c
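
Why this was a bug: `ModelConfig.get_and_verify_max_len` previously validated the user-provided max model length against the HF-config-derived limit first, and only afterwards clamped the result by the tokenizer's `model_max_length`, so an over-limit value was silently reduced rather than rejected. A minimal standalone sketch of the two orderings (illustrative limits only, not vLLM's actual code; the helper names are hypothetical):

```python
# Standalone sketch of the ordering bug (illustrative numbers, not vLLM code).
# Assume: HF-config-derived limit 514, tokenizer model_max_length 512,
# user requests max_model_len=513.

def old_order(user_len: int, hf_limit: int, tokenizer_limit: int) -> int:
    # Old behavior: validate against the HF-derived limit only ...
    if user_len > hf_limit:
        raise ValueError(f"max_model_len {user_len} > limit {hf_limit}")
    # ... then silently clamp by the tokenizer limit afterwards.
    return min(user_len, tokenizer_limit)

def new_order(user_len: int, hf_limit: int, tokenizer_limit: int) -> int:
    # New behavior: fold the tokenizer limit into the derived limit first,
    # so the validation covers it too.
    derived = min(hf_limit, tokenizer_limit)
    if user_len > derived:
        raise ValueError(f"max_model_len {user_len} > limit {derived}")
    return user_len

print(old_order(513, 514, 512))  # 512: silently clamped, validation skipped
try:
    new_order(513, 514, 512)
except ValueError as err:
    print(f"rejected: {err}")  # now raises, matching the new 513 test case
```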

File tree

tests/test_config.py
vllm/config.py

2 files changed: +43 -13 lines changed
tests/test_config.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -438,3 +438,31 @@ def test_load_config_pt_load_map_location(pt_load_map_location):
     config = VllmConfig(load_config=load_config)
 
     assert config.load_config.pt_load_map_location == pt_load_map_location
+
+
+@pytest.mark.parametrize(
+    ("model_id", "max_model_len", "expected_max_len", "should_raise"), [
+        ("BAAI/bge-reranker-base", None, 512, False),
+        ("BAAI/bge-reranker-base", 256, 256, False),
+        ("BAAI/bge-reranker-base", 513, 512, True),
+    ])
+def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
+                                should_raise):
+    """Test get_and_verify_max_len with different configurations."""
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    if should_raise:
+        with pytest.raises(ValueError):
+            model_config.get_and_verify_max_len(max_model_len)
+    else:
+        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
+        assert actual_max_len == expected_max_len
```
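
For reference, a hedged usage sketch of the fixed API, mirroring the test's setup (assumes a vLLM development checkout; the whole test can be run with `pytest tests/test_config.py -k test_get_and_verify_max_len`):

```python
from vllm.config import ModelConfig

# Same setup as the new test cases for BAAI/bge-reranker-base.
config = ModelConfig(
    "BAAI/bge-reranker-base",
    task="auto",
    tokenizer="BAAI/bge-reranker-base",
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="float16",
    revision=None,
)

print(config.get_and_verify_max_len(None))  # 512, capped by tokenizer_config
print(config.get_and_verify_max_len(256))   # 256, user value within the cap
# config.get_and_verify_max_len(513) now raises ValueError instead of
# silently returning 512.
```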

vllm/config.py

Lines changed: 15 additions & 13 deletions
```diff
@@ -1429,25 +1429,19 @@ def matryoshka_dimensions(self):
         return getattr(self.hf_config, "matryoshka_dimensions", None)
 
     def get_and_verify_max_len(self, max_model_len: int):
+        tokenizer_config = try_get_tokenizer_config(
+            self.tokenizer,
+            trust_remote_code=self.trust_remote_code,
+            revision=self.tokenizer_revision)
         max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
+            tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
-
-        tokenizer_config = try_get_tokenizer_config(
-            self.tokenizer,
-            trust_remote_code=self.trust_remote_code,
-            revision=self.tokenizer_revision)
-
-        if tokenizer_config is None:
-            return max_model_len
-
-        model_max_length = tokenizer_config.get("model_max_length",
-                                                max_model_len)
-        max_model_len = min(max_model_len, model_max_length)
+        logger.info("Using max model len %s", max_model_len)
         return max_model_len
 
 
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
 
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
+    tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, list[Optional[int]]]],
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
-    # Choose the smallest "max_length" from the possible keys.
+    # Choose the smallest "max_length" from the possible keys
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
@@ -3332,6 +3327,13 @@ def _get_and_verify_max_len(
         derived_max_model_len = min(derived_max_model_len,
                                     sliding_window_len_min)
 
+    # Consider model_max_length in tokenizer_config
+    if tokenizer_config:
+        tokenizer_model_max_length = tokenizer_config.get(
+            "model_max_length", derived_max_model_len)
+        derived_max_model_len = min(derived_max_model_len,
+                                    tokenizer_model_max_length)
+
     # If none of the keys were found in the config, use a default and
     # log a warning.
     if derived_max_model_len == float("inf"):
```
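
The added block in the last hunk is where validation is restored: the tokenizer's `model_max_length` now lowers `derived_max_model_len` before the user-supplied value is checked against it. A minimal sketch of just that fold, with a plain dict standing in for the parsed tokenizer config (`fold_tokenizer_limit` is a hypothetical name for illustration):

```python
from typing import Optional

def fold_tokenizer_limit(derived_max_model_len: float,
                         tokenizer_config: Optional[dict]) -> float:
    # Same logic as the added block: when a tokenizer config is available,
    # take the smaller of the config-derived limit and model_max_length.
    if tokenizer_config:
        tokenizer_model_max_length = tokenizer_config.get(
            "model_max_length", derived_max_model_len)
        derived_max_model_len = min(derived_max_model_len,
                                    tokenizer_model_max_length)
    return derived_max_model_len

assert fold_tokenizer_limit(514, {"model_max_length": 512}) == 512
assert fold_tokenizer_limit(514, None) == 514  # no tokenizer config: unchanged
```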
