Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion recipes/configs/llama2/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ checkpointer:
checkpoint_dir: /tmp/Llama-2-70b-hf
checkpoint_files:
filename_format: pytorch_model-{}-of-{}.bin
max_filename: 00015
max_filename: "00015"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA2
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/70B_qlora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ checkpointer:
checkpoint_dir: /tmp/Llama-2-70b-hf
checkpoint_files:
filename_format: pytorch_model-{}-of-{}.bin
max_filename: 00015
max_filename: "00015"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA2
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/70B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ checkpointer:
checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ checkpointer:
checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/405B_qlora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ checkpointer:
checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00191
max_filename: "00191"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/70B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ checkpointer:
checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ checkpointer:
checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_3/70B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ checkpointer:
checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_3/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ checkpointer:
checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_3/70B_qlora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ checkpointer:
checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00030
max_filename: "00030"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/qwen2_5/14B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2_5-14B-Instruct
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00008
max_filename: "00008"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: QWEN2
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/qwen2_5/32B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2_5-32B-Instruct
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00017
max_filename: "00017"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: QWEN2
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/qwen2_5/72B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2_5-72B-Instruct
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00037
max_filename: "00037"
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: QWEN2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,19 @@ def expected_filenames(self):
"model_0012_of_0012.pt",
]

def test_invalid_to_dict(self):
def test_invalid_from_dict_no_filename_format(self):
    """from_dict should reject a dict that lacks the 'filename_format' key."""
    # 'bad_key' stands in for the required 'filename_format' key.
    missing_format = {"bad_key": "model_{}_of_{}.pt", "max_filename": "0005"}
    with pytest.raises(ValueError, match="Must pass 'filename_format'"):
        FormattedCheckpointFiles.from_dict(missing_format)

def test_invalid_from_dict_int_max_filename(self):
    """from_dict should raise when `max_filename` arrives as an int, not a string."""
    # The 0o00025 literal is an octal integer (decimal 21). We use this odd
    # value because YAML parses numbers with a leading 0 as octal, so an
    # unquoted `max_filename: 00025` in a config would reach `from_dict` as
    # exactly this kind of invalid int — a realistic bad-config scenario.
    invalid_dict = {"filename_format": "model_{}_of_{}.pt", "max_filename": 0o00025}
    with pytest.raises(ValueError, match="`max_filename` must be a string"):
        _ = FormattedCheckpointFiles.from_dict(invalid_dict)

def test_invalid_filename_format(self):
formatted_string = "invalid_format_{}.pt"
formatted_file_dict = {
Expand Down
6 changes: 5 additions & 1 deletion torchtune/training/checkpointing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def from_dict(cls, d: dict) -> "FormattedCheckpointFiles":
raise ValueError(
"Must pass 'filename_format' and 'max_filename' keys to generate checkpoint filenames"
)
if not isinstance(d["max_filename"], str):
raise ValueError(
f"`max_filename` must be a string, but found {type(d['max_filename'])} instead."
)
return cls(
filename_format=d["filename_format"],
max_filename=d["max_filename"],
Expand Down Expand Up @@ -527,7 +531,7 @@ def validate_checkpoint_files(
# e.g.
# checkpoint_files:
# filename_format: model-{}-of-{}.safetensors
# max_filename: 00191
# max_filename: "00191"
# becomes checkpoint_files = [model-00001-of-00191.safetensors, model-00002-of-00191,..]
if not isinstance(checkpoint_files, List):
# TODO: this can be a function instead of a class
Expand Down
Loading