@@ -1429,25 +1429,19 @@ def matryoshka_dimensions(self):
1429
1429
return getattr (self .hf_config , "matryoshka_dimensions" , None )
1430
1430
1431
1431
def get_and_verify_max_len (self , max_model_len : int ):
1432
+ tokenizer_config = try_get_tokenizer_config (
1433
+ self .tokenizer ,
1434
+ trust_remote_code = self .trust_remote_code ,
1435
+ revision = self .tokenizer_revision )
1432
1436
max_model_len = _get_and_verify_max_len (
1433
1437
hf_config = self .hf_text_config ,
1438
+ tokenizer_config = tokenizer_config ,
1434
1439
max_model_len = max_model_len ,
1435
1440
disable_sliding_window = self .disable_sliding_window ,
1436
1441
sliding_window_len = self .get_hf_config_sliding_window (),
1437
1442
spec_target_max_model_len = self .spec_target_max_model_len ,
1438
1443
encoder_config = self .encoder_config )
1439
-
1440
- tokenizer_config = try_get_tokenizer_config (
1441
- self .tokenizer ,
1442
- trust_remote_code = self .trust_remote_code ,
1443
- revision = self .tokenizer_revision )
1444
-
1445
- if tokenizer_config is None :
1446
- return max_model_len
1447
-
1448
- model_max_length = tokenizer_config .get ("model_max_length" ,
1449
- max_model_len )
1450
- max_model_len = min (max_model_len , model_max_length )
1444
+ logger .info ("Using max model len %s" , max_model_len )
1451
1445
return max_model_len
1452
1446
1453
1447
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
3283
3277
3284
3278
def _get_and_verify_max_len (
3285
3279
hf_config : PretrainedConfig ,
3280
+ tokenizer_config : Optional [dict ],
3286
3281
max_model_len : Optional [int ],
3287
3282
disable_sliding_window : bool ,
3288
3283
sliding_window_len : Optional [Union [int , list [Optional [int ]]]],
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
3309
3304
"max_seq_length" ,
3310
3305
"seq_len" ,
3311
3306
]
3312
- # Choose the smallest "max_length" from the possible keys.
3307
+ # Choose the smallest "max_length" from the possible keys
3313
3308
max_len_key = None
3314
3309
for key in possible_keys :
3315
3310
max_len = getattr (hf_config , key , None )
@@ -3332,6 +3327,13 @@ def _get_and_verify_max_len(
3332
3327
derived_max_model_len = min (derived_max_model_len ,
3333
3328
sliding_window_len_min )
3334
3329
3330
+ # Consider model_max_length in tokenizer_config
3331
+ if tokenizer_config :
3332
+ tokenizer_model_max_length = tokenizer_config .get (
3333
+ "model_max_length" , derived_max_model_len )
3334
+ derived_max_model_len = min (derived_max_model_len ,
3335
+ tokenizer_model_max_length )
3336
+
3335
3337
# If none of the keys were found in the config, use a default and
3336
3338
# log a warning.
3337
3339
if derived_max_model_len == float ("inf" ):
0 commit comments