diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml
index 717abacd18..7fed032ec4 100644
--- a/recipes/configs/llama2/70B_lora.yaml
+++ b/recipes/configs/llama2/70B_lora.yaml
@@ -31,7 +31,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-2-70b-hf
   checkpoint_files:
     filename_format: pytorch_model-{}-of-{}.bin
-    max_filename: 00015
+    max_filename: "00015"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA2
diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml
index 5a380d3d0e..b140624fc2 100644
--- a/recipes/configs/llama2/70B_qlora.yaml
+++ b/recipes/configs/llama2/70B_qlora.yaml
@@ -36,7 +36,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-2-70b-hf
   checkpoint_files:
     filename_format: pytorch_model-{}-of-{}.bin
-    max_filename: 00015
+    max_filename: "00015"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA2
diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml
index f08019bdab..5491ae093d 100644
--- a/recipes/configs/llama3/70B_full.yaml
+++ b/recipes/configs/llama3/70B_full.yaml
@@ -41,7 +41,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml
index 23151a7193..da82f156f6 100644
--- a/recipes/configs/llama3/70B_lora.yaml
+++ b/recipes/configs/llama3/70B_lora.yaml
@@ -31,7 +31,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml
index 4a15d8b25f..044f47b48e 100644
--- a/recipes/configs/llama3_1/405B_qlora.yaml
+++ b/recipes/configs/llama3_1/405B_qlora.yaml
@@ -34,7 +34,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00191
+    max_filename: "00191"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml
index 1cd06413a2..1ecf130e1a 100644
--- a/recipes/configs/llama3_1/70B_full.yaml
+++ b/recipes/configs/llama3_1/70B_full.yaml
@@ -40,7 +40,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml
index ed0a917025..81ef5f6875 100644
--- a/recipes/configs/llama3_1/70B_lora.yaml
+++ b/recipes/configs/llama3_1/70B_lora.yaml
@@ -30,7 +30,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml
index fc9621631b..f7ec013c15 100644
--- a/recipes/configs/llama3_3/70B_full.yaml
+++ b/recipes/configs/llama3_3/70B_full.yaml
@@ -40,7 +40,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3_3/70B_lora.yaml b/recipes/configs/llama3_3/70B_lora.yaml
index 5c09749abb..06c2924f5c 100644
--- a/recipes/configs/llama3_3/70B_lora.yaml
+++ b/recipes/configs/llama3_3/70B_lora.yaml
@@ -30,7 +30,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/llama3_3/70B_qlora.yaml b/recipes/configs/llama3_3/70B_qlora.yaml
index ebc18e9b01..53c4a8c3b5 100644
--- a/recipes/configs/llama3_3/70B_qlora.yaml
+++ b/recipes/configs/llama3_3/70B_qlora.yaml
@@ -30,7 +30,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
diff --git a/recipes/configs/qwen2_5/14B_lora_single_device.yaml b/recipes/configs/qwen2_5/14B_lora_single_device.yaml
index 35c85cdb6e..e918b8de09 100644
--- a/recipes/configs/qwen2_5/14B_lora_single_device.yaml
+++ b/recipes/configs/qwen2_5/14B_lora_single_device.yaml
@@ -39,7 +39,7 @@ checkpointer:
   checkpoint_dir: /tmp/Qwen2_5-14B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00008
+    max_filename: "00008"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: QWEN2
diff --git a/recipes/configs/qwen2_5/32B_lora.yaml b/recipes/configs/qwen2_5/32B_lora.yaml
index f8d2f6850e..1633d59c3f 100644
--- a/recipes/configs/qwen2_5/32B_lora.yaml
+++ b/recipes/configs/qwen2_5/32B_lora.yaml
@@ -37,7 +37,7 @@ checkpointer:
   checkpoint_dir: /tmp/Qwen2_5-32B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00017
+    max_filename: "00017"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: QWEN2
diff --git a/recipes/configs/qwen2_5/72B_lora.yaml b/recipes/configs/qwen2_5/72B_lora.yaml
index 86b36340fc..6eabf7eca9 100644
--- a/recipes/configs/qwen2_5/72B_lora.yaml
+++ b/recipes/configs/qwen2_5/72B_lora.yaml
@@ -37,7 +37,7 @@ checkpointer:
   checkpoint_dir: /tmp/Qwen2_5-72B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00037
+    max_filename: "00037"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: QWEN2
diff --git a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py
index c9579fbfbc..d73bb6fc03 100644
--- a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py
+++ b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py
@@ -196,11 +196,19 @@ def expected_filenames(self):
             "model_0012_of_0012.pt",
         ]

-    def test_invalid_to_dict(self):
+    def test_invalid_from_dict_no_filename_format(self):
         invalid_dict = {"bad_key": "model_{}_of_{}.pt", "max_filename": "0005"}
         with pytest.raises(ValueError, match="Must pass 'filename_format'"):
             _ = FormattedCheckpointFiles.from_dict(invalid_dict)

+    def test_invalid_from_dict_int_max_filename(self):
+        # 0o00025 is an octal literal. We use this unusual value because YAML
+        # parses numbers with a leading 0 as octal, so it is a realistic example
+        # of `from_dict` being called with an invalid (unquoted) config value.
+        invalid_dict = {"filename_format": "model_{}_of_{}.pt", "max_filename": 0o00025}
+        with pytest.raises(ValueError, match="`max_filename` must be a string"):
+            _ = FormattedCheckpointFiles.from_dict(invalid_dict)
+
     def test_invalid_filename_format(self):
         formatted_string = "invalid_format_{}.pt"
         formatted_file_dict = {
diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py
index 1d8a63daab..f8dc55452b 100644
--- a/torchtune/training/checkpointing/_utils.py
+++ b/torchtune/training/checkpointing/_utils.py
@@ -142,6 +142,10 @@ def from_dict(cls, d: dict) -> "FormattedCheckpointFiles":
             raise ValueError(
                 "Must pass 'filename_format' and 'max_filename' keys to generate checkpoint filenames"
             )
+        if not isinstance(d["max_filename"], str):
+            raise ValueError(
+                f"`max_filename` must be a string, but found {type(d['max_filename'])} instead."
+            )
         return cls(
             filename_format=d["filename_format"],
             max_filename=d["max_filename"],
@@ -527,7 +531,7 @@ validate_checkpoint_files(
     # e.g.
     # checkpoint_files:
     #   filename_format: model-{}-of-{}.safetensors
-    #   max_filename: 00191
+    #   max_filename: "00191"
     # becomes checkpoint_files = [model-00001-of-00191.safetensors, model-00002-of-00191,..]
     if not isinstance(checkpoint_files, List):
         # TODO: this can be a function instead of a class
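Why the quoting fix matters: a YAML 1.1 loader such as PyYAML (which backs OmegaConf, used by torchtune for these configs) resolves an unquoted leading-zero scalar made up only of octal digits as a base-8 integer, silently corrupting the value. A minimal sketch of the pitfall using PyYAML directly (behavior through OmegaConf may differ in detail, but the quoted form is unambiguous either way):

```python
import yaml

# Unquoted and all digits are 0-7: YAML 1.1 resolves this as octal,
# so the checkpointer would receive the int 13 instead of "00015".
print(yaml.safe_load("max_filename: 00015"))    # {'max_filename': 13}

# Quoted: the zero-padded string survives intact.
print(yaml.safe_load('max_filename: "00015"'))  # {'max_filename': '00015'}

# A value containing an 8 or 9 matches neither the octal nor the decimal
# pattern and happens to load as a string anyway, which made the bug
# inconsistent across configs and easy to miss.
print(yaml.safe_load("max_filename: 00191"))    # {'max_filename': '00191'}
```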
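With the new guard in `FormattedCheckpointFiles.from_dict`, a config value that slipped through unquoted now fails fast with a clear error instead of producing malformed checkpoint filenames downstream. A usage sketch, importing from the private module touched above (only `from_dict` and the `ValueError` text come from this diff; the surrounding calls are illustrative):

```python
from torchtune.training.checkpointing._utils import FormattedCheckpointFiles

# Quoted in YAML, so max_filename arrives as the zero-padded string
# the filename formatter expects.
ok = FormattedCheckpointFiles.from_dict(
    {"filename_format": "model-{}-of-{}.safetensors", "max_filename": "00191"}
)

# Unquoted in YAML, so a 1.1 loader may hand us an int; from_dict now
# rejects it immediately.
try:
    FormattedCheckpointFiles.from_dict(
        {"filename_format": "model-{}-of-{}.safetensors", "max_filename": 13}
    )
except ValueError as err:
    print(err)  # `max_filename` must be a string, but found <class 'int'> instead.
```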