diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index c33bfe4d8..9bf30b502 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -7,3 +7,4 @@ wandb
 fsspec
 tyro
 tokenizers >= 0.15.0
+safetensors
diff --git a/docs/checkpoint.md b/docs/checkpoint.md
index 5275db1a2..0ffcafb02 100644
--- a/docs/checkpoint.md
+++ b/docs/checkpoint.md
@@ -85,3 +85,8 @@ e.g.
 ```bash
 NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```
+
+
+## How to load / save a checkpoint in HF safetensors format
+To save the last checkpoint in HF safetensors format, set `--checkpoint.last_save_in_safetensors_format` together with `--checkpoint.last_save_model_weights_only`; intermediate checkpoints are always saved in DCP format.
+To load a checkpoint in HF safetensors format, either place it in the `step-0` folder when using `--checkpoint.folder`, or point `--checkpoint.initial_load_path` at the folder containing it. Also set `--checkpoint.initial_load_model_weights_only`, since only model weights can be loaded from a safetensors checkpoint.
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 6218b5a5f..f3000eef7 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -118,6 +118,22 @@ def build_test_list():
             "Checkpoint Integration Test - Save Load Full Checkpoint",
             "full_checkpoint",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.folder hf_checkpoint",
+                    "--checkpoint.last_save_in_safetensors_format",
+                    "--checkpoint.last_save_model_weights_only",
+                ],
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.initial_load_path artifacts-to-be-uploaded/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/",
+                ],
+            ],
+            "Checkpoint Integration Test - save load full checkpoint in HF safetensors format",
+            "full_checkpoint_hf_safetensors",
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index 3317a51fe..2f8127bfd 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -144,7 +144,7 @@ def tearDown(self):
         shutil.rmtree(self.base_temp_dir)
         time.sleep(0.1)
 
-    def fake_save(self, state_dict: dict, checkpoint_id: str):
+    def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
         os.makedirs(checkpoint_id, exist_ok=True)
         sd_to_save = {}
         for key, val in state_dict.items():
@@ -584,7 +584,7 @@ def __init__(self):
     @mock.patch("torchtitan.components.checkpoint.dcp.load")
     @mock.patch("torchtitan.components.checkpoint.dcp.save")
     def test_verify_prefix(self, mock_save, mock_load, mock_rank):
-        def fake_save(state_dict: dict, checkpoint_id: str):
+        def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
            self.assertIn("bias", state_dict)
            self.assertIn("weight", state_dict)
            # No model prefix
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 1bc07f2f2..f71417de8 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -12,12 +12,17 @@
 import shutil
 import threading
 import time
+from concurrent.futures import Future
 from typing import Any
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
+)
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum):
     ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"
 
 
+class CheckpointType(str, enum.Enum):
+    DCP = "DCP"
+    SAFETENSORS = "safetensors"
+
+
 # For now, we will manually pop the freqs_cis buffer, as we made this permanent
 # temporarily and we don't want to include it in the exported state_dict.
 # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
@@ -92,12 +102,6 @@ class SaveDone:
     pass
 
 
-@torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-    dcp.save(state, checkpoint_id=checkpoint_id)
-    GarbageCollection.collect("GC collection invoked by checkpointer.")
-
-
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.
 
@@ -190,6 +194,9 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.last_save_in_safetensors_format = (
+            ckpt_config.last_save_in_safetensors_format
+        )
         self.ft_manager = (
             ft_manager.manager if ft_manager and ft_manager.enabled else None
         )
@@ -314,6 +321,98 @@ def close(self):
         if self.stager is not None:
             self.stager.close()
 
+    @torch.no_grad()
+    def dcp_save(
+        self,
+        state_dict: dict[str, Any],
+        checkpoint_id: str,
+        async_mode: AsyncMode,
+        enable_garbage_collection: bool = False,
+        save_in_safetensors_format: bool = False,
+    ) -> Future | None:
+        """Save the checkpoint with dcp.
+        Args:
+            state_dict (dict): The state dict to save.
+            checkpoint_id (str): The checkpoint id to save.
+            async_mode (AsyncMode): Which async mode to use when saving.
+            enable_garbage_collection (bool): Whether to run garbage collection after the save.
+            save_in_safetensors_format (bool): Whether to save in safetensors format.
+
+        Returns:
+            Future: The future object if the checkpoint is async, otherwise None.
+        """
+
+        ret: Future | None = None
+
+        storage_writer: HuggingFaceStorageWriter | None = None
+        checkpoint_save_id: str | None = None
+        if save_in_safetensors_format:
+            fqn_to_index_mapping = {}
+            num_fqns_per_file = 30
+            # The use of 30 is just a heuristic for now.
+            # Once these fqns map to HF ones, we can use the fqn mapping
+            # from the model.safetensors.index.json file.
+            for i, key in enumerate(state_dict.keys()):
+                group_num = (i // num_fqns_per_file) + 1
+                fqn_to_index_mapping[key] = group_num
+
+            storage_writer = HuggingFaceStorageWriter(
+                path=checkpoint_id,
+                save_distributed=True,
+                fqn_to_index_mapping=fqn_to_index_mapping,
+                enable_consolidation=True,
+                thread_count_consolidation=5,
+            )
+        else:
+            checkpoint_save_id = checkpoint_id
+
+        if async_mode == AsyncMode.ASYNC:
+            ret = dcp.async_save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=checkpoint_save_id,
+                process_group=self.pg,
+            )
+        elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            ret = dcp.async_save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=checkpoint_save_id,
+                process_group=self.pg,
+                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
+                async_stager=self.stager,
+            )
+        else:
+            ret = dcp.save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=checkpoint_save_id,
+            )
+
+        if enable_garbage_collection:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+        return ret
+
+    def dcp_load(
+        self,
+        state_dict: dict[str, Any],
+        checkpoint_id: str,
+        checkpoint_type: CheckpointType,
+    ) -> None:
+        """Load the checkpoint with dcp.
+        Args:
+            state_dict (dict): The state dict to load.
+            checkpoint_id (str): The checkpoint id to load.
+            checkpoint_type (CheckpointType): The format of the checkpoint to load (DCP or safetensors).
+        """
+
+        if checkpoint_type == CheckpointType.SAFETENSORS:
+            storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+            dcp.load(state_dict, storage_reader=storage_reader)
+        else:
+            dcp.load(state_dict, checkpoint_id=checkpoint_id)
+
     @torch.no_grad()
     def save(self, curr_step: int, last_step: bool = False) -> None:
         """Save the checkpoint for the current step.
@@ -354,23 +453,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             if self.stager is None:
                 self.stager = DefaultStager(StagingOptions(True, True, True, True))
-            result = dcp.async_save(
+            result = self.dcp_save(
                 states,
                 checkpoint_id=checkpoint_id,
-                process_group=self.pg,
-                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
-                async_stager=self.stager,
+                async_mode=self.async_mode,
             )
             self.save_future = result.upload_completion
             self.staging_future = result.staging_completion
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.save_future = dcp.async_save(
-                states, checkpoint_id=checkpoint_id, process_group=self.pg
+            self.save_future = self.dcp_save(
+                states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(states, checkpoint_id=checkpoint_id)
+            self.dcp_save(
+                states,
+                checkpoint_id=checkpoint_id,
+                async_mode=AsyncMode.DISABLED,
+                enable_garbage_collection=True,
+            )
         self._purge_stale_checkpoints()
 
         logger.info(
@@ -432,10 +534,19 @@ def load(self, step: int = -1) -> bool:
                 f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
             )
 
+        checkpoint_type = self._find_checkpoint_type(checkpoint_id)
+        if checkpoint_type == CheckpointType.SAFETENSORS:
+            assert (
+                model_only
+            ), "Only model weights can be loaded when loading from safetensors checkpoint."
logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) - dcp.load(states, checkpoint_id=checkpoint_id) + self.dcp_load( + states, + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -470,13 +581,33 @@ def _find_load_step(self, folder: str = "") -> int: for filename in os.listdir(folder): match = re.search(pattern, filename) - metadata_probe = os.path.join(folder, filename, ".metadata") - if match and os.path.isfile(metadata_probe): + dcp_metadata_probe = os.path.join(folder, filename, ".metadata") + safetensors_metadata_probe = os.path.join( + folder, filename, "model.safetensors.index.json" + ) + if match and os.path.isfile(dcp_metadata_probe): + step_counts.append(int(match.group(1))) + elif match and os.path.isfile(safetensors_metadata_probe): step_counts.append(int(match.group(1))) if not step_counts: return -1 return max(step_counts) + def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: + """Find the checkpoint type for the given id. + + Args: + checkpoint_id (str): The folder to find the checkpoint type for. + + Returns: + CheckpointType: The checkpoint type for the given folder. + """ + + for filename in os.listdir(checkpoint_id): + if filename == "model.safetensors.index.json": + return CheckpointType.SAFETENSORS + return CheckpointType.DCP + def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") @@ -488,8 +619,8 @@ def _ft_save(self, step: int) -> None: begin = time.monotonic() self._async_wait() checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - self.save_future = dcp.async_save( - self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg + self.save_future = self.dcp_save( + self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC ) logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.") @@ -501,7 +632,12 @@ def _ft_load(self) -> None: begin = time.monotonic() logger.info(f"Loading the FT checkpoint at step {step}.") checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - dcp.load(self.ft_states, checkpoint_id=checkpoint_id) + self.dcp_load( + self.ft_states, + checkpoint_id=checkpoint_id, + # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. + checkpoint_type=CheckpointType.DCP, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -570,7 +706,18 @@ def _save_last_step(self, curr_step: int) -> None: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") states = self._flattened_model_states_sd() - save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step)) + if self.last_save_in_safetensors_format: + assert ( + self.last_save_model_weights_only + ), "Only model weights can be saved when saving in safetensors format." 
+
+        self.dcp_save(
+            states,
+            checkpoint_id=self._create_checkpoint_id(curr_step),
+            async_mode=AsyncMode.DISABLED,
+            enable_garbage_collection=True,
+            save_in_safetensors_format=self.last_save_in_safetensors_format,
+        )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 5f1a1e8b7..07c92b6f9 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -475,6 +475,17 @@ class Checkpoint:
     for many steps or checkpointing too frequently. The default value is False.
     """
 
+    last_save_in_safetensors_format: bool = False
+    """
+    Enable the use of the safetensors format for checkpointing. This will save the final
+    checkpoint in safetensors format instead of the default DCP format. There is a performance
+    cost, as the sharded tensors need to be consolidated into full tensors as a separate
+    step. last_save_model_weights_only must be true because safetensors doesn't support
+    saving non-tensors. On load, this argument isn't needed, as we detect whether the
+    checkpoint being loaded is in safetensors format or not.
+    The default value is False.
+    """
+
 
 @dataclass
 class ActivationCheckpoint:
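
As a usage illustration of the docs/checkpoint.md addition above, here is a minimal sketch of the save-then-load flow. The flag names, the `hf_checkpoint` folder, and the `step-10` layout come from this diff (see the `full_checkpoint_hf_safetensors` integration test); the `NGPU` value, config path, and output prefix are placeholders, not values taken from the source.

```bash
# Save: write the final checkpoint of the run in HF safetensors format (model weights only).
NGPU=1 CONFIG=<path/to/config.toml> ./run_train.sh \
    --checkpoint.enable_checkpoint \
    --checkpoint.folder hf_checkpoint \
    --checkpoint.last_save_model_weights_only \
    --checkpoint.last_save_in_safetensors_format

# Load: point initial_load_path at the safetensors checkpoint folder and load weights only.
NGPU=1 CONFIG=<path/to/config.toml> ./run_train.sh \
    --checkpoint.enable_checkpoint \
    --checkpoint.initial_load_path <dump_folder>/hf_checkpoint/step-10/ \
    --checkpoint.initial_load_model_weights_only
```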
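The `fqn_to_index_mapping` built in `dcp_save` assigns every 30 consecutive FQNs to the same output shard; the values are 1-based file indices consumed by `HuggingFaceStorageWriter`. Below is a standalone sketch of that grouping heuristic; the helper name and toy FQNs are made up for illustration (torchtitan builds the mapping inline).

```python
# Sketch of the grouping heuristic used by dcp_save when saving in safetensors format.
def build_fqn_to_index_mapping(fqns: list[str], num_fqns_per_file: int = 30) -> dict[str, int]:
    # File indices are 1-based: FQNs 0-29 go to file 1, FQNs 30-59 to file 2, and so on.
    return {fqn: (i // num_fqns_per_file) + 1 for i, fqn in enumerate(fqns)}


fqns = [f"layers.{i}.attention.wq.weight" for i in range(65)]  # toy FQNs
mapping = build_fqn_to_index_mapping(fqns)
assert mapping[fqns[0]] == 1 and mapping[fqns[30]] == 2 and mapping[fqns[64]] == 3
```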