Add support for saving HF format tensors with DCP #1351
Changes from 21 commits
@@ -9,3 +9,4 @@ wandb
 fsspec
 tyro
 tokenizers >= 0.15.0
+safetensors
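For context on the new dependency: DCP's HuggingFace storage layer reads and writes .safetensors files, each of which holds a flat mapping from fully qualified tensor names (FQNs) to tensors. A minimal round-trip sketch with the standalone safetensors library (illustration only, not part of this diff; the file and tensor names are made up):

import torch
from safetensors.torch import load_file, save_file

# save a flat dict of tensors to one shard file, then read it back
tensors = {"model.embed_tokens.weight": torch.randn(8, 4)}
save_file(tensors, "model-00001-of-00001.safetensors")
loaded = load_file("model-00001-of-00001.safetensors")
assert torch.equal(loaded["model.embed_tokens.weight"], tensors["model.embed_tokens.weight"])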
@@ -118,6 +118,17 @@ def build_test_list():
         "Checkpoint Integration Test - Save Load Full Checkpoint",
         "full_checkpoint",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                "--checkpoint.enable_save_safetensors_format",
+                "--checkpoint.last_save_model_weights_only",
+            ],
+        ],
+        "Checkpoint Integration Test - Save Load Full Checkpoint",
+        "full_checkpoint_hf_safetensors",
+    ),
Review comment: We need to test both saving in HF format and loading from an HF checkpoint in this test. Since we only support loading a weights-only initial checkpoint in HF format, I suggest the following for testing purposes only: in the first run we save a …
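A sketch of what such a two-run override could look like, for illustration only: the first sub-test saves a weights-only HF safetensors checkpoint, and the second loads it back as the initial checkpoint. The second run's flags and the initial_load_path value are assumptions based on fields mentioned elsewhere in this review, not the reviewer's actual (truncated) suggestion.

    # Hypothetical two-run variant of the new integration test.
    OverrideDefinitions(
        [
            [
                "--checkpoint.enable_checkpoint",
                "--checkpoint.enable_save_safetensors_format",
                "--checkpoint.last_save_model_weights_only",
            ],
            [
                "--checkpoint.enable_checkpoint",
                # assumed flag names; the path is a placeholder for run 1's output
                "--checkpoint.initial_load_path=<run 1 output>/checkpoint/step-10",
                "--checkpoint.initial_load_model_weights_only",
            ],
        ],
        "Checkpoint Integration Test - Save Load Full Checkpoint (HF safetensors)",
        "full_checkpoint_hf_safetensors",
    ),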
    OverrideDefinitions(
        [
            [
@@ -12,12 +12,17 @@
 import shutil
 import threading
 import time
+from concurrent.futures import Future
 from typing import Any

 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
+)
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum):
     ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


+class CheckpointType(str, enum.Enum):
+    DCP = "DCP"
+    SAFETENSORS = "safetensors"
+
+
 # For now, we will manually pop the freqs_cis buffer, as we made this permanent
 # temporarily and we don't want to include it in the exported state_dict.
 # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
@@ -92,12 +102,6 @@ class SaveDone:
     pass


-@torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-    dcp.save(state, checkpoint_id=checkpoint_id)
-    GarbageCollection.collect("GC collection invoked by checkpointer.")
-
-
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.
@@ -190,6 +194,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.enable_save_safetensors_format = ckpt_config.enable_save_safetensors_format
         self.ft_manager = ft_manager.manager if ft_manager.enabled else None

         if self.ft_manager:
@@ -312,6 +317,89 @@ def close(self):
         if self.stager is not None:
             self.stager.close()

+    @torch.no_grad()
+    def dcp_save(
+        self,
+        state_dict: dict[str, Any],
+        checkpoint_id: str,
+        async_mode: AsyncMode,
+        enable_garbage_collection: bool = False,
+        is_last_step: bool = False,
+    ) -> Future | None:
+        """Save the checkpoint with DCP.
+
+        Args:
+            state_dict (dict): The state dict to save.
+            checkpoint_id (str): The checkpoint id to save.
+            async_mode (AsyncMode): Whether, and how, to save the checkpoint asynchronously.
+            enable_garbage_collection (bool): Whether to run garbage collection after the save.
+            is_last_step (bool): Whether this save is for the last step.
+
+        Returns:
+            Future: The future object if the checkpoint is async, otherwise None.
+        """
+
+        ret: Future | None = None
+
+        storage_writer: HuggingFaceStorageWriter | None = None
+        if self.enable_save_safetensors_format and is_last_step:
+            # The use of 30 FQNs per file is just a heuristic for now.
+            # Once these FQNs map to HF ones, we can use the FQN mapping
+            # from the model.safetensors.index.json file.
+            num_fqns_per_file = 30
+            fqn_to_index_mapping = {}
+            for i, key in enumerate(state_dict.keys()):
+                group_num = (i // num_fqns_per_file) + 1
+                fqn_to_index_mapping[key] = group_num
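+                # e.g. with 65 keys and num_fqns_per_file = 30, keys 0-29 map
+                # to file 1, keys 30-59 to file 2, and keys 60-64 to file 3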
+
+            storage_writer = HuggingFaceStorageWriter(
+                path=checkpoint_id,
+                save_distributed=True,
+                fqn_to_index_mapping=fqn_to_index_mapping,
+                enable_consolidation=is_last_step,

Review comment: seems more readable if …
+                thread_count_consolidation=5,
+            )
+
+        id = checkpoint_id if not self.enable_save_safetensors_format else None
+        if async_mode == AsyncMode.ASYNC:
+            ret = dcp.async_save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=id,
+                process_group=self.pg,
+            )
+        elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            ret = dcp.async_save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=id,
+                process_group=self.pg,
+                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
+                async_stager=self.stager,
+            )
+        else:
+            ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id)
+
+        if enable_garbage_collection:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+        return ret
+
+    def dcp_load(
+        self,
+        state_dict: dict[str, Any],
+        checkpoint_id: str,
+        checkpoint_type: CheckpointType,
+    ) -> None:
+        """Load the checkpoint with DCP.
+
+        Args:
+            state_dict (dict): The state dict to load.
+            checkpoint_id (str): The checkpoint id to load.
+            checkpoint_type (CheckpointType): Whether the checkpoint to load is in
+                the HuggingFace safetensors format or in the native DCP format.
+        """
+
+        if checkpoint_type == CheckpointType.SAFETENSORS:
+            storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+            dcp.load(state_dict, storage_reader=storage_reader)

Review comment: I have a question on how to combine this PR with torchtitan model <-> HF model definition conversion mappings, so that we can load / save with HF models and train with torchtitan. Let's say we have the two mappings; I can imagine how we do this for save -- when HF format is used, we just don't call the current … How are we supposed to do load in HF definition?

+        else:
+            dcp.load(state_dict, checkpoint_id=checkpoint_id)

     @torch.no_grad()
     def save(self, curr_step: int, last_step: bool = False) -> None:
         """Save the checkpoint for the current step.
@@ -352,23 +440,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             if self.stager is None:
                 self.stager = DefaultStager(StagingOptions(True, True, True, True))
-            result = dcp.async_save(
+            result = self.dcp_save(
                 states,
                 checkpoint_id=checkpoint_id,
-                process_group=self.pg,
-                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
-                async_stager=self.stager,
+                async_mode=self.async_mode,
             )
             self.save_future = result.upload_completion
             self.staging_future = result.staging_completion
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.save_future = dcp.async_save(
-                states, checkpoint_id=checkpoint_id, process_group=self.pg
+            self.save_future = self.dcp_save(
+                states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(states, checkpoint_id=checkpoint_id)
+            self.dcp_save(
+                states,
+                checkpoint_id=checkpoint_id,
+                async_mode=AsyncMode.DISABLED,
+                enable_garbage_collection=True,
+            )
         self._purge_stale_checkpoints()

         logger.info(
@@ -430,10 +521,17 @@ def load(self, step: int = -1) -> bool:
                 f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
             )

+        checkpoint_type = self._find_checkpoint_type(checkpoint_id)
+        if checkpoint_type == CheckpointType.SAFETENSORS:
+            model_only = True

Review comment: We should assert if …

Reply: How should I change the logic to allow for this? Right now, if os.path.exists(self.folder) is true, there is no way for model_only to be True other than when step == 0. self.initial_load_model_weights_only isn't used in this code path either. It's also already silently being changed in line 516, which is why I did it this way, but I'm happy to change it in whatever way you think is best.

Reply: This depends on your answer to my question about the loading folder structure, unless you think it makes sense to put the HF checkpoint in some step-10 / 20 / 50 folder (why? I can't think of such use cases). If that's the case, we can modify the model_only logic on line 516 to reflect that.

Reply: OK, if there is no use case I will follow your advice; I wasn't sure.

         logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
         states = self._states_to_load(model_only)
-        dcp.load(states, checkpoint_id=checkpoint_id)
+        self.dcp_load(
+            states,
+            checkpoint_id=checkpoint_id,
+            checkpoint_type=checkpoint_type,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -468,13 +566,33 @@ def _find_load_step(self, folder: str = "") -> int:

         for filename in os.listdir(folder):
             match = re.search(pattern, filename)
-            metadata_probe = os.path.join(folder, filename, ".metadata")
-            if match and os.path.isfile(metadata_probe):
+            dcp_metadata_probe = os.path.join(folder, filename, ".metadata")
+            safetensors_metadata_probe = os.path.join(
+                folder, filename, "model.safetensors.index.json"
+            )
+            if match and os.path.isfile(dcp_metadata_probe):
                 step_counts.append(int(match.group(1)))
+            elif match and os.path.isfile(safetensors_metadata_probe):

Review comment: Oh, for safetensors do we also require …

Reply: From your other comment, it seems like you think the HF checkpoint in step 10, 20, etc. isn't a valid use case, so I will just follow your suggestion.

+                step_counts.append(int(match.group(1)))
         if not step_counts:
             return -1
         return max(step_counts)

+    def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:
+        """Find the checkpoint type for the given id.
+
+        Args:
+            checkpoint_id (str): The folder to find the checkpoint type for.
+
+        Returns:
+            CheckpointType: The checkpoint type for the given folder.
+        """
+
+        for filename in os.listdir(checkpoint_id):
+            if filename == "model.safetensors.index.json":
+                return CheckpointType.SAFETENSORS
+        return CheckpointType.DCP
+
     def _ft_folder(self) -> str:
         return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
@@ -486,8 +604,8 @@ def _ft_save(self, step: int) -> None:
         begin = time.monotonic()
         self._async_wait()
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.save_future = dcp.async_save(
-            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        self.save_future = self.dcp_save(
+            self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
         )
         logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
@@ -499,7 +617,11 @@ def _ft_load(self) -> None:
         begin = time.monotonic()
         logger.info(f"Loading the FT checkpoint at step {step}.")
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        self.dcp_load(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+            checkpoint_type=CheckpointType.DCP,  # FT checkpoints are always DCP

Review comment: Please add one more line, "because FT checkpoint currently only save/load dataloader."

+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -568,7 +690,13 @@ def _save_last_step(self, curr_step: int) -> None:
         logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
         states = self._flattened_model_states_sd()

-        save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        self.dcp_save(
+            states,
+            checkpoint_id=self._create_checkpoint_id(curr_step),
+            async_mode=AsyncMode.DISABLED,
+            enable_garbage_collection=True,
+            is_last_step=True,

Review comment: I suggest we don't pass is_last_step around. … In the beginning of this function, we need to assert …

+        )

     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
@@ -467,6 +467,17 @@ class Checkpoint:
     for many steps or checkpointing too frequently. The default value is False.
     """

+    enable_save_safetensors_format: bool = False

Review comment: nit in naming …

+    """
+    Enable the use of the safetensors format for checkpointing. This will save the
+    final checkpoint in safetensors format instead of the default DCP format. There
+    will be a performance cost in using this, as we need to consolidate the sharded
+    tensors into full tensors as a separate step. last_save_model_weights_only must
+    be true because safetensors doesn't support saving non-tensor objects. On load,
+    this argument isn't needed, as we will detect whether the loaded checkpoint is
+    in safetensors format or not. The default value is False.
+    """


 @dataclass
 class ActivationCheckpoint:
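To make the load side concrete: a minimal sketch (illustration only; load_any is a hypothetical helper, not part of this diff) of how a checkpoint folder written with this option is recognized and read back, mirroring _find_checkpoint_type and dcp_load in the diff above.

import os

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import HuggingFaceStorageReader


def load_any(state_dict: dict, checkpoint_id: str) -> None:
    # the index file marks a consolidated HF safetensors checkpoint
    if "model.safetensors.index.json" in os.listdir(checkpoint_id):
        reader = HuggingFaceStorageReader(path=checkpoint_id)
        dcp.load(state_dict, storage_reader=reader)
    else:
        dcp.load(state_dict, checkpoint_id=checkpoint_id)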