Add support for saving HF format tensors with DCP #1351
Conversation
@Saiteja64 This will conflict with your PR.
Overall the logic LGTM, please address comments and ensure that this PR doesn't conflict with the PR from @Saiteja64. Please also add a test result -- save an HF checkpoint, load one back, and check the accuracy.
torchtitan/components/checkpoint.py
Outdated
if hf_safetensors_format:
    storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
    if is_async:
        return dcp.async_save(
            state_dict, storage_writer=storage_writer, process_group=pg
        )
    else:
        return dcp.save(state_dict, storage_writer=storage_writer)
else:
    if is_async:
        return dcp.async_save(
            state_dict, checkpoint_id=checkpoint_id, process_group=pg
        )
    else:
        return dcp.save(state_dict, checkpoint_id=checkpoint_id)
We should simplify the function as follows:
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) if hf_safetensors_format else None
checkpoint_id = checkpoint_id if not hf_safetensors_format else None
if is_async:
    return dcp.async_save(
        state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
    )
else:
    return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)
torchtitan/config_manager.py
Outdated
enable_hf_safetensors_format: bool = False
"""
Enable the use of safetensors format for checkpointing. This will save checkpoints
in safetensors format instead of the default DCP format. The default value is False.
Can we also mention the possible performance penalty? It's not cost free, right?
Overall LGTM! Thanks for working on this! So from the logging, saving a Llama 3 8B model checkpoint in HF format takes ~200s, and loading an HF checkpoint takes ~30s, is this correct?
Overall, LGTM, please fix the remaining comments.
torchtitan/components/checkpoint.py
Outdated
if checkpoint_type == CheckpointType.SAFETENSORS:
    model_only = True
We should assert if model_only is not True, rather than silently change the model_only value.
How should I change the logic to allow for this? Right now, if os.path.exists(self.folder), there is no way for model_only to be True other than when step == 0. self.initial_load_model_weights_only isn't used in this code path either. It's also already silently being changed on line 516, which is why I did it this way, but I'm happy to change it in whatever way you think is best.
This depends on your answer to my question about the loading folder structure. If I'm right in that question -- to load the HF safetensors format, the user has to either use initial_load_path or put the checkpoint in a step-x folder -- then I think it makes sense to assert model_only == True here, unless you think it makes sense to put the HF checkpoint in some step-10 / 20 / 50 folder (why? I can't think of such use cases). If that's the case, we can modify the model_only logic on line 516 to reflect that.
OK, if there is no use case I will follow your advice; I wasn't sure.
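For reference, a minimal sketch of the assert-based version being discussed (illustrative only, not the PR's exact code; the surrounding dcp_load logic is not reproduced here):

if checkpoint_type == CheckpointType.SAFETENSORS:
    # HF safetensors checkpoints contain model weights only, so fail loudly
    # instead of silently overriding model_only.
    assert model_only, (
        "Loading an HF safetensors checkpoint is only supported with model_only=True"
    )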
torchtitan/components/checkpoint.py
Outdated
self.dcp_load(
    self.ft_states,
    checkpoint_id=checkpoint_id,
    checkpoint_type=CheckpointType.DCP,  # FT checkpoints are always DCP
Please add one more line, "because FT checkpoint currently only save/load dataloader.".
Save is actually faster than that. This run was ~140 seconds, but that was before I added the num_threads argument, so it should be faster now.
Thanks for the PR!
I left some suggestions, and a question on the next step of integrating the torchtitan <> HF model conversion.
torchtitan/config_manager.py
Outdated
Enable the use of safetensors format for checkpointing. This will save the final checkpoints
in safetensors format instead of the default DCP format. There will be a performance
cost in using this as we need to consolidate the sharded tensors to full tensors as
a separate step. Last_save_model_weights must be true because safetensors doesn't
Suggested change:
a separate step. last_save_model_weights_only must be true because safetensors doesn't
torchtitan/config_manager.py
Outdated
@@ -467,6 +467,17 @@ class Checkpoint:
    for many steps or checkpointing too frequently. The default value is False.
    """

    enable_save_safetensors_format: bool = False
Nit in naming:
enable_save_safetensors_format: bool = False
should be
last_save_in_safetensors_format: bool = False
torchtitan/components/checkpoint.py
Outdated
path=checkpoint_id,
save_distributed=True,
fqn_to_index_mapping=fqn_to_index_mapping,
enable_consolidation=is_last_step,
Seems more readable if
enable_consolidation=is_last_step,
is changed to
enable_consolidation=True,
)
if match and os.path.isfile(dcp_metadata_probe):
    step_counts.append(int(match.group(1)))
elif match and os.path.isfile(safetensors_metadata_probe):
Oh, for safetensors do we also require the step-10, step-20, etc. type of checkpoint structure? If users want to load an existing checkpoint from HF, is the workflow:
- download the safetensors into a local folder
- either put it in a step-0 subfolder, or use the initial_load_path config?
From your other comment, it seems like you think the HF checkpoint in step 10, 20, etc. isn't a valid use case, so I will just follow your suggestion.
torchtitan/components/checkpoint.py
Outdated
checkpoint_id=self._create_checkpoint_id(curr_step),
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
is_last_step=True,
I suggest we don't pass is_last_step in, since it's only used when the HF format matters. I suggest we pass in save_in_safetensors_format = self.last_save_in_safetensors_format (after the config name change). At the beginning of this function, we need to assert last_save_model_weights_only == True if last_save_in_safetensors_format is set.
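A rough sketch of the suggested shape, using a hypothetical helper (the real checkpoint.py code and signatures may differ):

def _validate_safetensors_config(
    last_save_in_safetensors_format: bool,
    last_save_model_weights_only: bool,
) -> None:
    # HF safetensors output is weights-only, so the weights-only flag must be
    # enabled whenever the safetensors format is requested.
    if last_save_in_safetensors_format:
        assert last_save_model_weights_only, (
            "--checkpoint.last_save_in_safetensors_format requires "
            "--checkpoint.last_save_model_weights_only"
        )

The call site would then pass save_in_safetensors_format=self.last_save_in_safetensors_format instead of is_last_step.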
tests/integration_tests.py
Outdated
OverrideDefinitions(
    [
        [
            "--checkpoint.enable_checkpoint",
            "--checkpoint.enable_save_safetensors_format",
            "--checkpoint.last_save_model_weights_only",
        ],
    ],
    "Checkpoint Integration Test - Save Load Full Checkpoint",
    "full_checkpoint_hf_safetensors",
),
We need to test both saving an HF checkpoint and loading from an HF checkpoint in this test. Since we only support loading a weights-only initial checkpoint in HF format, I suggest the following for testing purposes only. In the first run we save a step-10 checkpoint and pretend it is the initial checkpoint to load; then in the second run we use initial_load_path to locate the one we saved. Something like the following, but you may have to tweak the paths a bit.
OverrideDefinitions(
    [
        [
            "--checkpoint.enable_checkpoint",
            "--checkpoint.folder hf_checkpoint",
            "--checkpoint.enable_save_safetensors_format",
            "--checkpoint.last_save_model_weights_only",
        ],
        [
            "--checkpoint.enable_checkpoint",
            "--checkpoint.initial_load_path outputs/hf_checkpoint/step-10",
        ],
    ],
    "Checkpoint Integration Test - save load full checkpoint in HF safetensors format",
    "full_checkpoint_hf_safetensors",
),
if checkpoint_type == CheckpointType.SAFETENSORS:
    storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
    dcp.load(state_dict, storage_reader=storage_reader)
I have a question on how to combine this PR with the torchtitan model <-> HF model definition conversion mappings, so that we can load / save with HF models and train with torchtitan.
Let's say we have two mappings, torchtitan_to_hf and hf_to_torchtitan, both taking a state dict and converting it to another, e.g. very similar to https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama4/_convert_weights.py
I could imagine how we do this for save -- when the HF format is used, we just don't call the current ModelWrapper.state_dict; we call the conversion map torchtitan_to_hf on top of it.
How are we supposed to do load in the HF definition? ModelWrapper.load_state_dict is only in the fault-tolerant path, and it doesn't seem to be used by dcp.load, so how can we do things like fuse / unfuse of tensors after load? I think we need to manually add a function in dcp_load to do the reverse of torchtitan_to_hf in place, similar to what ModelWrapper.load_state_dict does in set_model_state_dict. This also means hf_to_torchtitan itself is not useful.
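To make the question concrete, here is a hedged sketch of what such a conversion-aware save/load could look like. torchtitan_to_hf and the key renaming are hypothetical placeholders, the import path may differ by PyTorch version, and this is not code from the PR:

import torch.distributed.checkpoint as dcp
# Import path may vary with the PyTorch version.
from torch.distributed.checkpoint import HuggingFaceStorageReader, HuggingFaceStorageWriter


def torchtitan_to_hf(state_dict: dict) -> dict:
    # Hypothetical mapping: rename torchtitan FQNs to HF FQNs. For a pure
    # renaming the tensors are shared, not copied; fusing/unfusing would need
    # real tensor work here.
    return {
        fqn.replace("tok_embeddings.weight", "model.embed_tokens.weight"): tensor
        for fqn, tensor in state_dict.items()
    }


def save_in_hf_format(state_dict: dict, checkpoint_id: str) -> None:
    # Save path: convert keys to HF names, then write with the HF storage writer.
    hf_sd = torchtitan_to_hf(state_dict)
    writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
    dcp.save(hf_sd, storage_writer=writer)


def load_from_hf_format(state_dict: dict, checkpoint_id: str) -> None:
    # Load path: rename keys to HF names so they match the checkpoint. Because
    # the renamed dict shares tensor storage with state_dict, dcp.load fills the
    # model's tensors in place; a fuse/unfuse step would need an explicit
    # copy-back (the reverse of torchtitan_to_hf) after this call.
    hf_sd = torchtitan_to_hf(state_dict)
    dcp.load(hf_sd, storage_reader=HuggingFaceStorageReader(path=checkpoint_id))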
Looks great, had one more comment.
Also, would you please add a small section "How to load / save checkpoint in HF safetensor format" in https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md?
For save, users need to set --checkpoint.last_save_in_safetensors_format and --checkpoint.last_save_model_weights_only, and it only saves the last checkpoint in HF format (intermediate ones are in DCP format).
For load, users need to either put it in a step-0 folder if using --checkpoint.folder, or specify --checkpoint.initial_load_path.
torchtitan/components/checkpoint.py
Outdated
checkpoint_save_id = (
    checkpoint_id if not self.last_save_in_safetensors_format else None
)
Could you elaborate what this field is for? Can we use save_in_safetensors_format instead of self.last_save_in_safetensors_format? If so, let's declare it together with storage_writer, and set it to checkpoint_id in the else branch.
Currently, if self.last_save_in_safetensors_format == True then checkpoint_save_id will always be None, regardless of whether it's the last step or not.
This may not be covered by the test, because with 10 steps we always save in one format but not both.
For the non-HF case we need to pass in a checkpoint_id because we don't pass in a storage writer; we just use the default storage writer and instantiate it with checkpoint_id. In the HF case we pass the path in as an arg to HuggingFaceStorageWriter and it is then handled in that class, so we need to pass an empty checkpoint_id to save.
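In other words, the intent seems to be roughly the following (a sketch of the reply above, not the exact PR code):

# Exactly one of storage_writer / checkpoint_id ends up being passed to dcp.save.
if save_in_safetensors_format:
    # The HuggingFaceStorageWriter receives the path directly, so dcp.save is
    # called with checkpoint_id=None.
    storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_distributed=True)
    checkpoint_save_id = None
else:
    # The default DCP storage writer is created internally from checkpoint_id.
    storage_writer = None
    checkpoint_save_id = checkpoint_id

dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id)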
LGTM. Thanks a lot for adding this feature!
Please address the last nit comment on the .md tutorial.
Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
If checkpoint.enable_save_safetensors_format is set, the final checkpoint is saved with DCP's HF components as .safetensors files instead of the regular DCP format. On load, we can decide which type of load to do based on the checkpoint type.
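A hedged sketch of the load-side dispatch described above; the detection probe (looking for an HF index file) is an assumption for illustration, not necessarily how the PR determines the checkpoint type, and the import path may differ by PyTorch version:

import os

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import HuggingFaceStorageReader


def dcp_load(state_dict: dict, checkpoint_id: str) -> None:
    # Decide which loader to use based on what the checkpoint folder contains.
    hf_index = os.path.join(checkpoint_id, "model.safetensors.index.json")
    if os.path.isfile(hf_index):
        # HF safetensors checkpoint (weights only): read through the HF reader.
        dcp.load(state_dict, storage_reader=HuggingFaceStorageReader(path=checkpoint_id))
    else:
        # Regular DCP checkpoint.
        dcp.load(state_dict, checkpoint_id=checkpoint_id)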
Successful save:
Successful load: