From c0c44489323f5b4db5d9c535ae9fa65ba29492ea Mon Sep 17 00:00:00 2001 From: Ankita George Date: Fri, 27 Jun 2025 11:07:06 -0700 Subject: [PATCH 01/28] add hf support --- tests/integration_tests.py | 15 ++++ torchtitan/components/checkpoint.py | 112 +++++++++++++++++++++++++--- torchtitan/config_manager.py | 6 ++ 3 files changed, 121 insertions(+), 12 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 3ccbc1890..dca2610be 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -118,6 +118,21 @@ def build_test_list(): "Checkpoint Integration Test - Save Load Full Checkpoint", "full_checkpoint", ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.enable_hf_safetensors_format", + ], + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.enable_hf_safetensors_format", + "--training.steps 20", + ], + ], + "Checkpoint Integration Test - Save Load Full Checkpoint", + "full_checkpoint_hf_safetensors", + ), OverrideDefinitions( [ [ diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index d6b8d45c3..386bea3d1 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -12,7 +12,8 @@ import shutil import threading import time -from typing import Any +from concurrent.futures import Future +from typing import Any, Optional import torch import torch.distributed as dist @@ -20,6 +21,10 @@ import torch.multiprocessing as mp import torch.nn as nn from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint import ( + HuggingFaceStorageReader, + HuggingFaceStorageWriter, +) from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, set_model_state_dict, @@ -93,8 +98,64 @@ class SaveDone: @torch.no_grad() -def save_with_gc(state, checkpoint_id): - dcp.save(state, checkpoint_id=checkpoint_id) +def dcp_save( + state_dict: dict[str, Any], + checkpoint_id: str, + is_async: bool, + hf_safetensors_format: bool, + pg: Optional[dist.ProcessGroup] = None, +) -> Optional[Future]: + """Save the checkpoint with dcp. + Args: + state_dict (dict): The state dict to save. + checkpoint_id (str): The checkpoint id to save. + is_async (bool): Whether the checkpoint is async. + hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. + pg (Optional[dist.ProcessGroup]): The process group to use. + """ + if hf_safetensors_format: + storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) + if is_async: + return dcp.async_save( + state_dict, storage_writer=storage_writer, process_group=pg + ) + else: + return dcp.save(state_dict, storage_writer=storage_writer) + else: + if is_async: + return dcp.async_save( + state_dict, checkpoint_id=checkpoint_id, process_group=pg + ) + else: + return dcp.save(state_dict, checkpoint_id=checkpoint_id) + + +def dcp_load( + state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool +) -> None: + """Load the checkpoint with dcp. + Args: + state_dict (dict): The state dict to load. + checkpoint_id (str): The checkpoint id to load. + hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. + """ + if hf_safetensors_format: + storage_reader = HuggingFaceStorageReader(path=checkpoint_id) + dcp.load(state_dict, storage_writer=storage_reader) + else: + dcp.load(state_dict, checkpoint_id=checkpoint_id) + + +@torch.no_grad() +def save_with_gc( + state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool +) -> None: + dcp_save( + state, + checkpoint_id=checkpoint_id, + is_async=False, + hf_safetensors_format=hf_safetensors_format, + ) GarbageCollection.collect("GC collection invoked by checkpointer.") @@ -125,7 +186,9 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue): assert isinstance(obj, tuple) begin = time.monotonic() state, checkpoint_id = obj - save_with_gc(state, checkpoint_id=checkpoint_id) + save_with_gc( + state, checkpoint_id=checkpoint_id, hf_safetensors_format=False + ) logger.info( "Finish saving the checkpoint in the background process in %.2f seconds.", time.monotonic() - begin, @@ -227,6 +290,7 @@ def __init__( ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint + self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format self.ft_manager = ft_manager.manager if ft_manager.enabled else None if self.ft_manager: @@ -391,12 +455,20 @@ def save(self, curr_step: int, last_step: bool = False) -> None: self._async_with_pinned_memory(checkpoint_id) elif self.async_mode == AsyncMode.ASYNC: GarbageCollection.collect("GC collection invoked by checkpointer.") - self.async_future = dcp.async_save( - self.states, checkpoint_id=checkpoint_id, process_group=self.pg + self.async_future = dcp_save( + self.states, + checkpoint_id=checkpoint_id, + is_async=True, + hf_safetensors_format=self.enable_hf_safetensors_format, + pg=self.pg, ) GarbageCollection.collect("GC collection invoked by checkpointer.") else: - save_with_gc(self.states, checkpoint_id=checkpoint_id) + save_with_gc( + self.states, + checkpoint_id=checkpoint_id, + hf_safetensors_format=self.enable_hf_safetensors_format, + ) self._purge_stale_checkpoints() logger.info( @@ -461,7 +533,11 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) - dcp.load(states, checkpoint_id=checkpoint_id) + dcp_load( + states, + checkpoint_id=checkpoint_id, + hf_safetensors_format=self.enable_hf_safetensors_format, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -540,8 +616,12 @@ def _ft_save(self, step: int) -> None: begin = time.monotonic() self._async_wait() checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - self.async_future = dcp.async_save( - self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg + self.async_future = dcp_save( + self.ft_states, + checkpoint_id=checkpoint_id, + is_async=True, + hf_safetensors_format=self.enable_hf_safetensors_format, + pg=self.pg, ) logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.") @@ -553,7 +633,11 @@ def _ft_load(self) -> None: begin = time.monotonic() logger.info(f"Loading the FT checkpoint at step {step}.") checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - dcp.load(self.ft_states, checkpoint_id=checkpoint_id) + dcp_load( + self.ft_states, + checkpoint_id=checkpoint_id, + hf_safetensors_format=self.enable_hf_safetensors_format, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -614,7 +698,11 @@ def _save_last_step(self, curr_step: int) -> None: else: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") - save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) + save_with_gc( + self.states, + checkpoint_id=self._create_checkpoint_id(curr_step), + hf_safetensors_format=self.enable_hf_safetensors_format, + ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: if not self.enable_checkpoint: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 200acc36c..093052646 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -467,6 +467,12 @@ class Checkpoint: for many steps or checkpointing too frequently. The default value is False. """ + enable_hf_safetensors_format: bool = False + """ + Enable the use of safetensors format for checkpointing. This will save checkpoints + in safetensors format instead of the default DCP format. The default value is False. + """ + @dataclass class ActivationCheckpoint: From 4342e31d50d6223a305534bd05087b7cc93864a0 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Jul 2025 10:56:51 -0700 Subject: [PATCH 02/28] address comments --- torchtitan/components/checkpoint.py | 41 +++++++++++++++++------------ torchtitan/config_manager.py | 4 ++- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 386bea3d1..05742965e 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -13,7 +13,7 @@ import threading import time from concurrent.futures import Future -from typing import Any, Optional +from typing import Any import torch import torch.distributed as dist @@ -103,8 +103,9 @@ def dcp_save( checkpoint_id: str, is_async: bool, hf_safetensors_format: bool, - pg: Optional[dist.ProcessGroup] = None, -) -> Optional[Future]: + pg: dist.ProcessGroup | None = None, +) -> Future | None: + """Save the checkpoint with dcp. Args: state_dict (dict): The state dict to save. @@ -112,27 +113,33 @@ def dcp_save( is_async (bool): Whether the checkpoint is async. hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. pg (Optional[dist.ProcessGroup]): The process group to use. + + Returns: + Future: The future object if the checkpoint is async, otherwise None. """ - if hf_safetensors_format: - storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) - if is_async: - return dcp.async_save( - state_dict, storage_writer=storage_writer, process_group=pg - ) - else: - return dcp.save(state_dict, storage_writer=storage_writer) + storage_writer = ( + HuggingFaceStorageWriter( + path=checkpoint_id, save_distributed=True, enable_consolidation=True + ) + if hf_safetensors_format + else None + ) + id = checkpoint_id if not hf_safetensors_format else None + if is_async: + return dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=id, + process_group=pg, + ) else: - if is_async: - return dcp.async_save( - state_dict, checkpoint_id=checkpoint_id, process_group=pg - ) - else: - return dcp.save(state_dict, checkpoint_id=checkpoint_id) + return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool ) -> None: + """Load the checkpoint with dcp. Args: state_dict (dict): The state dict to load. diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 093052646..fdfa2fa62 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -470,7 +470,9 @@ class Checkpoint: enable_hf_safetensors_format: bool = False """ Enable the use of safetensors format for checkpointing. This will save checkpoints - in safetensors format instead of the default DCP format. The default value is False. + in safetensors format instead of the default DCP format. There will be a performance + cost in using this as we need to consolidate the sharded tensors to full tensors as + a separate step. The default value is False. """ From d64c307a1cc28ec956063761c5b36e365dcf11f0 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Jul 2025 11:29:37 -0700 Subject: [PATCH 03/28] restructure --- torchtitan/components/checkpoint.py | 163 +++++++++++++--------------- 1 file changed, 77 insertions(+), 86 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 1615d0ad2..75c5218ec 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -19,15 +19,11 @@ import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn -<<<<<<< HEAD -from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict from torch.distributed.checkpoint import ( HuggingFaceStorageReader, HuggingFaceStorageWriter, ) -======= from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions ->>>>>>> main from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, set_model_state_dict, @@ -101,75 +97,6 @@ class SaveDone: pass -@torch.no_grad() -def dcp_save( - state_dict: dict[str, Any], - checkpoint_id: str, - is_async: bool, - hf_safetensors_format: bool, - pg: dist.ProcessGroup | None = None, -) -> Future | None: - - """Save the checkpoint with dcp. - Args: - state_dict (dict): The state dict to save. - checkpoint_id (str): The checkpoint id to save. - is_async (bool): Whether the checkpoint is async. - hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. - pg (Optional[dist.ProcessGroup]): The process group to use. - - Returns: - Future: The future object if the checkpoint is async, otherwise None. - """ - storage_writer = ( - HuggingFaceStorageWriter( - path=checkpoint_id, save_distributed=True, enable_consolidation=True - ) - if hf_safetensors_format - else None - ) - id = checkpoint_id if not hf_safetensors_format else None - if is_async: - return dcp.async_save( - state_dict, - storage_writer=storage_writer, - checkpoint_id=id, - process_group=pg, - ) - else: - return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) - - -def dcp_load( - state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool -) -> None: - - """Load the checkpoint with dcp. - Args: - state_dict (dict): The state dict to load. - checkpoint_id (str): The checkpoint id to load. - hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. - """ - if hf_safetensors_format: - storage_reader = HuggingFaceStorageReader(path=checkpoint_id) - dcp.load(state_dict, storage_writer=storage_reader) - else: - dcp.load(state_dict, checkpoint_id=checkpoint_id) - - -@torch.no_grad() -def save_with_gc( - state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool -) -> None: - dcp_save( - state, - checkpoint_id=checkpoint_id, - is_async=False, - hf_safetensors_format=hf_safetensors_format, - ) - GarbageCollection.collect("GC collection invoked by checkpointer.") - - def purge_thread(purge_queue: queue.Queue): """Thread to purge the old checkpoints. @@ -385,6 +312,74 @@ def close(self): if self.stager is not None: self.stager.close() + @torch.no_grad() + def dcp_save( + self, + state_dict: dict[str, Any], + checkpoint_id: str, + async_mode: AsyncMode, + garbage_collection: bool = False, + ) -> Future | None: + """Save the checkpoint with dcp. + Args: + state_dict (dict): The state dict to save. + checkpoint_id (str): The checkpoint id to save. + is_async (bool): Whether the checkpoint is async. + + Returns: + Future: The future object if the checkpoint is async, otherwise None. + """ + ret : Future | None = None + + storage_writer = ( + HuggingFaceStorageWriter( + path=checkpoint_id, save_distributed=True, enable_consolidation=True + ) + if self.enable_hf_safetensors_format + else None + ) + id = checkpoint_id if not self.enable_hf_safetensors_format else None + if async_mode == AsyncMode.ASYNC: + ret = dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=id, + process_group=self.pg, + ) + elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: + ret = dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=id, + process_group=self.pg, + async_checkpointer_type=AsyncCheckpointerType.PROCESS, + async_stager=self.stager, + ) + else: + ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) + + if garbage_collection: + GarbageCollection.collect("GC collection invoked by checkpointer.") + + return ret + + + def dcp_load( + self, state_dict: dict[str, Any], checkpoint_id: str + ) -> None: + """Load the checkpoint with dcp. + Args: + state_dict (dict): The state dict to load. + checkpoint_id (str): The checkpoint id to load. + hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. + """ + + if self.enable_hf_safetensors_format: + storage_reader = HuggingFaceStorageReader(path=checkpoint_id) + dcp.load(state_dict, storage_writer=storage_reader) + else: + dcp.load(state_dict, checkpoint_id=checkpoint_id) + @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: """Save the checkpoint for the current step. @@ -425,23 +420,21 @@ def save(self, curr_step: int, last_step: bool = False) -> None: GarbageCollection.collect("GC collection invoked by checkpointer.") if self.stager is None: self.stager = DefaultStager(StagingOptions(True, True, True, True)) - result = dcp.async_save( + result = self.dcp_save( states, checkpoint_id=checkpoint_id, - process_group=self.pg, - async_checkpointer_type=AsyncCheckpointerType.PROCESS, - async_stager=self.stager, + async_mode=self.async_mode, ) self.save_future = result.upload_completion self.staging_future = result.staging_completion elif self.async_mode == AsyncMode.ASYNC: GarbageCollection.collect("GC collection invoked by checkpointer.") - self.save_future = dcp.async_save( - states, checkpoint_id=checkpoint_id, process_group=self.pg + self.save_future = self.dcp_save( + states, checkpoint_id=checkpoint_id, async_mode=self.async_mode ) GarbageCollection.collect("GC collection invoked by checkpointer.") else: - save_with_gc(states, checkpoint_id=checkpoint_id) + self.dcp_save(states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.DISABLED, garbage_collection=True) self._purge_stale_checkpoints() logger.info( @@ -506,10 +499,9 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) - dcp_load( + self.dcp_load( states, checkpoint_id=checkpoint_id, - hf_safetensors_format=self.enable_hf_safetensors_format, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -563,8 +555,8 @@ def _ft_save(self, step: int) -> None: begin = time.monotonic() self._async_wait() checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - self.save_future = dcp.async_save( - self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg + self.save_future = self.dcp_save( + self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC ) logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.") @@ -576,10 +568,9 @@ def _ft_load(self) -> None: begin = time.monotonic() logger.info(f"Loading the FT checkpoint at step {step}.") checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - dcp_load( + self.dcp_load( self.ft_states, checkpoint_id=checkpoint_id, - hf_safetensors_format=self.enable_hf_safetensors_format, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( From ba64000e891e45f3f544bbf22b3da00d202c2983 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Jul 2025 12:53:41 -0700 Subject: [PATCH 04/28] fix test --- torchtitan/components/checkpoint.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 75c5218ec..b994347e8 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -318,7 +318,7 @@ def dcp_save( state_dict: dict[str, Any], checkpoint_id: str, async_mode: AsyncMode, - garbage_collection: bool = False, + enable_garbage_collection: bool = False, ) -> Future | None: """Save the checkpoint with dcp. Args: @@ -329,7 +329,7 @@ def dcp_save( Returns: Future: The future object if the checkpoint is async, otherwise None. """ - ret : Future | None = None + ret: Future | None = None storage_writer = ( HuggingFaceStorageWriter( @@ -358,15 +358,12 @@ def dcp_save( else: ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) - if garbage_collection: + if enable_garbage_collection: GarbageCollection.collect("GC collection invoked by checkpointer.") return ret - - def dcp_load( - self, state_dict: dict[str, Any], checkpoint_id: str - ) -> None: + def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str) -> None: """Load the checkpoint with dcp. Args: state_dict (dict): The state dict to load. @@ -434,7 +431,12 @@ def save(self, curr_step: int, last_step: bool = False) -> None: ) GarbageCollection.collect("GC collection invoked by checkpointer.") else: - self.dcp_save(states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.DISABLED, garbage_collection=True) + self.dcp_save( + states, + checkpoint_id=checkpoint_id, + async_mode=AsyncMode.DISABLED, + enable_garbage_collection=True, + ) self._purge_stale_checkpoints() logger.info( @@ -640,7 +642,12 @@ def _save_last_step(self, curr_step: int) -> None: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") states = self._flattened_model_states_sd() - save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step)) + self.dcp_save( + states, + checkpoint_id=self._create_checkpoint_id(curr_step), + async_mode=AsyncMode.DISABLED, + enable_garbage_collection=True, + ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: if not self.enable_checkpoint: From ee1d695c9eb5093b28158eb32f318db6030d509e Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Jul 2025 13:15:26 -0700 Subject: [PATCH 05/28] fix tests --- tests/unit_tests/test_checkpoint.py | 4 ++-- torchtitan/components/checkpoint.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index 3317a51fe..2f8127bfd 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -144,7 +144,7 @@ def tearDown(self): shutil.rmtree(self.base_temp_dir) time.sleep(0.1) - def fake_save(self, state_dict: dict, checkpoint_id: str): + def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None): os.makedirs(checkpoint_id, exist_ok=True) sd_to_save = {} for key, val in state_dict.items(): @@ -584,7 +584,7 @@ def __init__(self): @mock.patch("torchtitan.components.checkpoint.dcp.load") @mock.patch("torchtitan.components.checkpoint.dcp.save") def test_verify_prefix(self, mock_save, mock_load, mock_rank): - def fake_save(state_dict: dict, checkpoint_id: str): + def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None): self.assertIn("bias", state_dict) self.assertIn("weight", state_dict) # No model prefix diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index b994347e8..559a20aa8 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -373,7 +373,7 @@ def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str) -> None: if self.enable_hf_safetensors_format: storage_reader = HuggingFaceStorageReader(path=checkpoint_id) - dcp.load(state_dict, storage_writer=storage_reader) + dcp.load(state_dict, storage_reader=storage_reader) else: dcp.load(state_dict, checkpoint_id=checkpoint_id) From 6f76b28b358eebedf8c8daed464f2270290f9823 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Jul 2025 14:01:46 -0700 Subject: [PATCH 06/28] fix lint and last step --- torchtitan/components/checkpoint.py | 4 +++- torchtitan/config_manager.py | 2 +- torchtitan/models/llama3/train_configs/llama3_8b.toml | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 559a20aa8..4b56d16d4 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -319,6 +319,7 @@ def dcp_save( checkpoint_id: str, async_mode: AsyncMode, enable_garbage_collection: bool = False, + is_last_step: bool = False ) -> Future | None: """Save the checkpoint with dcp. Args: @@ -333,7 +334,7 @@ def dcp_save( storage_writer = ( HuggingFaceStorageWriter( - path=checkpoint_id, save_distributed=True, enable_consolidation=True + path=checkpoint_id, save_distributed=True, enable_consolidation=is_last_step, ) if self.enable_hf_safetensors_format else None @@ -647,6 +648,7 @@ def _save_last_step(self, curr_step: int) -> None: checkpoint_id=self._create_checkpoint_id(curr_step), async_mode=AsyncMode.DISABLED, enable_garbage_collection=True, + is_last_step=True, ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index c0fd45715..d567e987b 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -471,7 +471,7 @@ class Checkpoint: """ Enable the use of safetensors format for checkpointing. This will save checkpoints in safetensors format instead of the default DCP format. There will be a performance - cost in using this as we need to consolidate the sharded tensors to full tensors as + cost in using this as we need to consolidate the sharded tensors to full tensors as a separate step. The default value is False. """ diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 63b4ce6da..d2c4aaa2b 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" -tokenizer_path = "./assets/tokenizer/original/tokenizer.model" +tokenizer_path = "./assets/tokenizer/Meta-Llama-3.1-8B/original/tokenizer.model" # converters = ["float8"] [optimizer] From 3d184770a0f73ba3e553f0c22bf060550cac294a Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Jul 2025 14:03:22 -0700 Subject: [PATCH 07/28] remove config change --- torchtitan/models/llama3/train_configs/llama3_8b.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index d2c4aaa2b..63b4ce6da 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" -tokenizer_path = "./assets/tokenizer/Meta-Llama-3.1-8B/original/tokenizer.model" +tokenizer_path = "./assets/tokenizer/original/tokenizer.model" # converters = ["float8"] [optimizer] From cecaea9af7fb0697343099c6c2376aa36d90c1f8 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 11:13:08 -0700 Subject: [PATCH 08/28] continue testing --- .ci/docker/requirements.txt | 1 + torchtitan/components/checkpoint.py | 53 +++++++++++++++++-- torchtitan/config_manager.py | 6 +++ .../llama3/train_configs/llama3_8b.toml | 10 ++-- torchtitan/train.py | 3 +- 5 files changed, 63 insertions(+), 10 deletions(-) diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 11eae863f..4d2a62c1b 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -9,3 +9,4 @@ wandb fsspec tyro tokenizers >= 0.15.0 +safetensors diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 4b56d16d4..3109651a0 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -6,6 +6,7 @@ import enum import functools +import json import os import queue import re @@ -289,6 +290,23 @@ def load_state_dict(state_dict): else: raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") + self.hf_fqn_index_map = None + if ( + self.enable_hf_safetensors_format + and ckpt_config.safetensors_json is not None + ): + self.hf_fqn_index_map = {} + with open(ckpt_config.safetensors_json, "r") as f: + data = json.load(f) + weight_map = data["weight_map"] + for k, v in weight_map.items(): + # expect the value to be in the format of "model-00000n-of-00000m.safetensors" + try: + self.hf_fqn_index_map[k] = int(v.split("-")[1]) + except Exception: + self.hf_fqn_index_map = None + break + logger.info( f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" ) @@ -319,7 +337,7 @@ def dcp_save( checkpoint_id: str, async_mode: AsyncMode, enable_garbage_collection: bool = False, - is_last_step: bool = False + is_last_step: bool = False, ) -> Future | None: """Save the checkpoint with dcp. Args: @@ -332,9 +350,32 @@ def dcp_save( """ ret: Future | None = None + state_dict_to_save: dict[str, Any] = ( + {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} + if self.enable_hf_safetensors_format + else state_dict + ) + logger.info( + "Num keys before parsing %d, after %d", + len(state_dict), + len(state_dict_to_save), + ) + for k, v in state_dict_to_save.items(): + if isinstance(v, torch.Tensor): + logger.info("key %s, shape %s", k, v.shape) + break + + fqn_to_index_mapping = {} + for i, key in enumerate(state_dict_to_save.keys()): + group_num = (i // 30) + 1 + fqn_to_index_mapping[key] = group_num + storage_writer = ( HuggingFaceStorageWriter( - path=checkpoint_id, save_distributed=True, enable_consolidation=is_last_step, + path=checkpoint_id, + save_distributed=True, + fqn_to_index_mapping=fqn_to_index_mapping, + enable_consolidation=is_last_step, ) if self.enable_hf_safetensors_format else None @@ -342,14 +383,14 @@ def dcp_save( id = checkpoint_id if not self.enable_hf_safetensors_format else None if async_mode == AsyncMode.ASYNC: ret = dcp.async_save( - state_dict, + state_dict_to_save, storage_writer=storage_writer, checkpoint_id=id, process_group=self.pg, ) elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: ret = dcp.async_save( - state_dict, + state_dict_to_save, storage_writer=storage_writer, checkpoint_id=id, process_group=self.pg, @@ -357,7 +398,9 @@ def dcp_save( async_stager=self.stager, ) else: - ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) + ret = dcp.save( + state_dict_to_save, storage_writer=storage_writer, checkpoint_id=id + ) if enable_garbage_collection: GarbageCollection.collect("GC collection invoked by checkpointer.") diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d567e987b..50cb63910 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -475,6 +475,12 @@ class Checkpoint: a separate step. The default value is False. """ + safetensors_json: str | None = None + """ + Path to the safetensors json file. This is only used when --checkpoint.enable_hf_safetensors_format + is set. The default value is None. + """ + @dataclass class ActivationCheckpoint: diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 63b4ce6da..75f90532b 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" -tokenizer_path = "./assets/tokenizer/original/tokenizer.model" +tokenizer_path = "./assets/tokenizer/Meta-Llama-3.1-8B/original/tokenizer.model" # converters = ["float8"] [optimizer] @@ -33,7 +33,7 @@ warmup_steps = 200 # lr scheduler warm up local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping -steps = 1000 +steps = 3 compile = false dataset = "c4" @@ -45,12 +45,14 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable_checkpoint = true folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_weights_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] +enable_hf_safetensors_format = true +safetensors_json = "./assets/safetensors/Meta-Llama-3.1-8B/model.index.safetensors.json" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] diff --git a/torchtitan/train.py b/torchtitan/train.py index ca1480f2e..cfb78c365 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,10 +11,10 @@ from typing import Any, Generator, Iterable, Optional import torch -from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss @@ -498,6 +498,7 @@ def train(self): except DataloaderStopIteration: logger.warning("Ran out of data; last step was canceled.") break + logger.info("Calling checkpoint save after step %d", self.step) self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) From 0928b8f38ab27e625d7326c9ea3f5723a8277c3e Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 13:45:00 -0700 Subject: [PATCH 09/28] more testing --- torchtitan/components/checkpoint.py | 1 + torchtitan/models/llama3/train_configs/llama3_8b.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 3109651a0..3c5fe5582 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -544,6 +544,7 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() + model_only = True states = self._states_to_load(model_only) self.dcp_load( states, diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 75f90532b..0b4de978b 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -53,6 +53,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] enable_hf_safetensors_format = true safetensors_json = "./assets/safetensors/Meta-Llama-3.1-8B/model.index.safetensors.json" +# load_step = 3 [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] From d06742880ac1d24f7f38374e70680fa7ec345106 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 19:37:40 -0700 Subject: [PATCH 10/28] clean up --- run_train.sh | 2 +- torchtitan/components/checkpoint.py | 40 ++++++------------- torchtitan/config_manager.py | 10 ++--- .../llama3/train_configs/llama3_8b.toml | 11 ++--- torchtitan/train.py | 36 +++++++++++++++-- 5 files changed, 53 insertions(+), 46 deletions(-) diff --git a/run_train.sh b/run_train.sh index fbed394eb..131da107a 100755 --- a/run_train.sh +++ b/run_train.sh @@ -11,7 +11,7 @@ set -ex # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh NGPU=${NGPU:-"8"} -export LOG_RANK=${LOG_RANK:-0} +export LOG_RANK=${LOG_RANK:-0,1,2} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} overrides="" diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 3c5fe5582..6654311b7 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -290,23 +290,6 @@ def load_state_dict(state_dict): else: raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") - self.hf_fqn_index_map = None - if ( - self.enable_hf_safetensors_format - and ckpt_config.safetensors_json is not None - ): - self.hf_fqn_index_map = {} - with open(ckpt_config.safetensors_json, "r") as f: - data = json.load(f) - weight_map = data["weight_map"] - for k, v in weight_map.items(): - # expect the value to be in the format of "model-00000n-of-00000m.safetensors" - try: - self.hf_fqn_index_map[k] = int(v.split("-")[1]) - except Exception: - self.hf_fqn_index_map = None - break - logger.info( f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" ) @@ -365,21 +348,23 @@ def dcp_save( logger.info("key %s, shape %s", k, v.shape) break - fqn_to_index_mapping = {} - for i, key in enumerate(state_dict_to_save.keys()): - group_num = (i // 30) + 1 - fqn_to_index_mapping[key] = group_num - - storage_writer = ( - HuggingFaceStorageWriter( + storage_writer: HuggingFaceStorageWriter | None = None + if self.enable_hf_safetensors_format and is_last_step: + fqn_to_index_mapping = {} + num_fqns_per_file = 30 + # the use of 30 is just a heuristic for now. + # Once these fqns map to HF ones, we can use the fqn mapping + # from the model.safetensors.index.json file + for i, key in enumerate(state_dict_to_save.keys()): + group_num = (i // num_fqns_per_file) + 1 + fqn_to_index_mapping[key] = group_num + + storage_writer = HuggingFaceStorageWriter( path=checkpoint_id, save_distributed=True, fqn_to_index_mapping=fqn_to_index_mapping, enable_consolidation=is_last_step, ) - if self.enable_hf_safetensors_format - else None - ) id = checkpoint_id if not self.enable_hf_safetensors_format else None if async_mode == AsyncMode.ASYNC: ret = dcp.async_save( @@ -544,7 +529,6 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() - model_only = True states = self._states_to_load(model_only) self.dcp_load( states, diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 50cb63910..2ea2b81f1 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -472,13 +472,9 @@ class Checkpoint: Enable the use of safetensors format for checkpointing. This will save checkpoints in safetensors format instead of the default DCP format. There will be a performance cost in using this as we need to consolidate the sharded tensors to full tensors as - a separate step. The default value is False. - """ - - safetensors_json: str | None = None - """ - Path to the safetensors json file. This is only used when --checkpoint.enable_hf_safetensors_format - is set. The default value is None. + a separate step. Last_save_model_weights and initial_load_model_weights_only + must be true because safetensors doesn't support saving non tensors. + The default value is False. """ diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 0b4de978b..63b4ce6da 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" -tokenizer_path = "./assets/tokenizer/Meta-Llama-3.1-8B/original/tokenizer.model" +tokenizer_path = "./assets/tokenizer/original/tokenizer.model" # converters = ["float8"] [optimizer] @@ -33,7 +33,7 @@ warmup_steps = 200 # lr scheduler warm up local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping -steps = 3 +steps = 1000 compile = false dataset = "c4" @@ -45,15 +45,12 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = true +enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = true +last_save_model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] -enable_hf_safetensors_format = true -safetensors_json = "./assets/safetensors/Meta-Llama-3.1-8B/model.index.safetensors.json" -# load_step = 3 [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] diff --git a/torchtitan/train.py b/torchtitan/train.py index cfb78c365..3dc8a61b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,10 +11,10 @@ from typing import Any, Generator, Iterable, Optional import torch +from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module -from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss @@ -52,6 +52,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): optimizers: train_spec_module.OptimizersContainer lr_schedulers: train_spec_module.LRSchedulersContainer + validator: train_spec_module.BaseValidator | None + pp_has_first_stage: bool pp_has_last_stage: bool @@ -89,6 +91,7 @@ def __init__(self, job_config: JobConfig): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) @@ -280,7 +283,7 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, self.ft_manager + self.model_parts, job_config, parallel_dims, world_mesh, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( self.optimizers, job_config @@ -319,6 +322,25 @@ def __init__(self, job_config: JobConfig): device_type, ) + # Build validator if validation is configured + if job_config.validation.enabled: + assert self.train_spec.build_validator_fn is not None + assert ( + not parallel_dims.pp_enabled + ), "pp is enabled but validation doesn't support pipeline parallelism yet" + + self.validator = self.train_spec.build_validator_fn( + job_config=job_config, + dp_world_size=dp_degree, + dp_rank=dp_rank, + tokenizer=tokenizer, + parallel_dims=parallel_dims, + world_mesh=world_mesh, + loss_fn=self.train_spec.build_loss_fn(job_config), + validation_context=self.train_context, + maybe_enable_amp=self.maybe_enable_amp, + ) + logger.info( "Trainer is initialized with " f"local batch size {job_config.training.local_batch_size}, " @@ -436,6 +458,7 @@ def train_step( self.job_config.training.max_norm, foreach=True, pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, + parallel_dims=parallel_dims, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() @@ -498,7 +521,14 @@ def train(self): except DataloaderStopIteration: logger.warning("Ran out of data; last step was canceled.") break - logger.info("Calling checkpoint save after step %d", self.step) + + # Run validation if validator is available + if ( + self.job_config.validation.enabled + and self.validator.should_validate(self.step) + ): + self.validator.validate(self.model_parts) + self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) From 994082deb6b3979ee36284f891f5feddd1eba186 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 19:43:22 -0700 Subject: [PATCH 11/28] more clean up --- run_train.sh | 2 +- torchtitan/train.py | 97 ++------------------------------------------- 2 files changed, 5 insertions(+), 94 deletions(-) diff --git a/run_train.sh b/run_train.sh index 131da107a..fbed394eb 100755 --- a/run_train.sh +++ b/run_train.sh @@ -11,7 +11,7 @@ set -ex # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh NGPU=${NGPU:-"8"} -export LOG_RANK=${LOG_RANK:-0,1,2} +export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} overrides="" diff --git a/torchtitan/train.py b/torchtitan/train.py index 3dc8a61b2..86f3ede11 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import importlib import os import time @@ -11,10 +10,9 @@ from typing import Any, Generator, Iterable, Optional import torch -from torch.distributed.elastic.multiprocessing.errors import record - import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss @@ -36,29 +34,23 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): job_config: JobConfig gc_handler: utils.GarbageCollection - parallel_dims: ParallelDims train_spec: train_spec_module.TrainSpec world_mesh: torch.distributed.DeviceMesh gradient_accumulation_steps: int - dataloader: train_spec_module.BaseDataLoader metrics_processor: train_spec_module.MetricsProcessor checkpointer: CheckpointManager train_context: Generator[None, None, None] - model_parts: list[torch.nn.Module] loss_fn: train_spec_module.LossFunction optimizers: train_spec_module.OptimizersContainer lr_schedulers: train_spec_module.LRSchedulersContainer - validator: train_spec_module.BaseValidator | None - pp_has_first_stage: bool pp_has_last_stage: bool device: torch.device - # states step: int @@ -66,22 +58,16 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): @record def __init__(self, job_config: JobConfig): torch._C._log_api_usage_once("torchtitan.train") - self.job_config = job_config - logger.info(f"Starting job: {job_config.job.description}") - if job_config.experimental.custom_import: importlib.import_module(job_config.experimental.custom_import) - if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - device_module, device_type = utils.device_module, utils.device_type self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) - # init distributed world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism @@ -91,12 +77,11 @@ def __init__(self, job_config: JobConfig): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, - ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) - dist_utils.init_distributed(job_config) + dist_utils.init_distributed(job_config) # build meshes self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type) if parallel_dims.dp_enabled: @@ -104,18 +89,15 @@ def __init__(self, job_config: JobConfig): dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 - self.ft_manager = ft.init_ft_manager(job_config) # If TorchFT is enabled, the dp_rank and dp_degree, which are used for # dataloader must be changed. if self.ft_manager.enabled: dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) - # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) - # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -125,37 +107,31 @@ def __init__(self, job_config: JobConfig): job_config.training.deterministic, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - # build dataloader tokenizer = ( self.train_spec.build_tokenizer_fn(job_config) if self.train_spec.build_tokenizer_fn is not None else None ) - self.dataloader = self.train_spec.build_dataloader_fn( dp_world_size=dp_degree, dp_rank=dp_rank, tokenizer=tokenizer, job_config=job_config, ) - # build model (using meta init) model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) - logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): model = model_cls(model_args) - # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) - # metrics logging build_metrics_processor_fn = ( build_metrics_processor @@ -166,18 +142,15 @@ def __init__(self, job_config: JobConfig): job_config, parallel_dims, model_args ) color = self.metrics_processor.color - # calculate model size and flops per token ( model_param_count, self.metrics_processor.num_flops_per_token, ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) - logger.info( f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) - # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: init_device = "cpu" @@ -188,9 +161,7 @@ def __init__(self, job_config: JobConfig): else: init_device = device_type buffer_device = None - self.loss_fn = self.train_spec.build_loss_fn(job_config) - # verify batch sizes global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: @@ -205,7 +176,6 @@ def __init__(self, job_config: JobConfig): f"data-parallel degree ({global_batch_size} " f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" ) - # calculate gradient accumulation steps self.gradient_accumulation_steps = global_batch_size // ( job_config.training.local_batch_size * dp_degree @@ -214,7 +184,6 @@ def __init__(self, job_config: JobConfig): self.loss_fn = rescale_accumulated_loss( self.loss_fn, self.gradient_accumulation_steps ) - # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -222,7 +191,6 @@ def __init__(self, job_config: JobConfig): f"Pipeline Parallel is enabled but {self.train_spec.name} " f"does not support pipelining" ) - # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques ( self.pp_schedule, @@ -242,13 +210,11 @@ def __init__(self, job_config: JobConfig): # when PP is enabled, `model` obj is no longer used after this point, # model_parts is used instead del model - for m in self.model_parts: m.to_empty(device=init_device) with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() - # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: @@ -256,20 +222,16 @@ def __init__(self, job_config: JobConfig): model = self.train_spec.parallelize_fn( model, world_mesh, parallel_dims, job_config ) - model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) model.train() - self.model_parts = [model] - if ( self.ft_manager.enabled and job_config.fault_tolerance.semi_sync_method is None ): self.ft_manager.set_all_reduce_hook(self.model_parts) - # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = self.metrics_processor.device_memory_monitor gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) @@ -283,7 +245,7 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, parallel_dims, world_mesh, self.ft_manager + self.model_parts, job_config, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( self.optimizers, job_config @@ -297,11 +259,9 @@ def __init__(self, job_config: JobConfig): ) ) self.metrics_processor.optimizers = self.optimizers - # Initialize trainer states that will be saved in checkpoint. # These attributes must be initialized before checkpoint loading. self.step = 0 - self.checkpointer = CheckpointManager( dataloader=self.dataloader, model_parts=self.model_parts, @@ -311,7 +271,6 @@ def __init__(self, job_config: JobConfig): job_config=job_config, ft_manager=self.ft_manager, ) - self.train_context = dist_utils.get_train_context( parallel_dims.loss_parallel_enabled, parallelism_config.enable_compiled_autograd, @@ -322,25 +281,6 @@ def __init__(self, job_config: JobConfig): device_type, ) - # Build validator if validation is configured - if job_config.validation.enabled: - assert self.train_spec.build_validator_fn is not None - assert ( - not parallel_dims.pp_enabled - ), "pp is enabled but validation doesn't support pipeline parallelism yet" - - self.validator = self.train_spec.build_validator_fn( - job_config=job_config, - dp_world_size=dp_degree, - dp_rank=dp_rank, - tokenizer=tokenizer, - parallel_dims=parallel_dims, - world_mesh=world_mesh, - loss_fn=self.train_spec.build_loss_fn(job_config), - validation_context=self.train_context, - maybe_enable_amp=self.maybe_enable_amp, - ) - logger.info( "Trainer is initialized with " f"local batch size {job_config.training.local_batch_size}, " @@ -357,7 +297,6 @@ def batch_generator( """Returns an iterator that processes batches from the data iterator.""" device_type = utils.device_type data_iterator = iter(data_iterable) - while True: try: batch = next(data_iterator) @@ -371,13 +310,11 @@ def batch_generator( self.metrics_processor.data_loading_times.append( time.perf_counter() - data_load_start ) - # Move tensors to the appropriate device for k, v in input_dict.items(): if isinstance(v, torch.Tensor): input_dict[k] = v.to(device_type) labels = labels.to(device_type) - yield input_dict, labels def forward_backward_step( @@ -385,7 +322,6 @@ def forward_backward_step( ) -> torch.Tensor: model_parts = self.model_parts parallel_dims = self.parallel_dims - # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] @@ -400,7 +336,6 @@ def forward_backward_step( if parallel_dims.cp_enabled else None ) - if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call with self.train_context(optional_context_parallel_ctx): @@ -415,7 +350,6 @@ def forward_backward_step( self.pp_schedule.step( target=targets, losses=losses, input_batch=inputs ) - # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU loss = ( @@ -433,18 +367,15 @@ def forward_backward_step( # need to free to before bwd to avoid peaking memory del pred loss.backward() - return loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): self.optimizers.zero_grad() - # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. parallel_dims = self.parallel_dims - accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. @@ -452,25 +383,21 @@ def train_step( input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) - grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, - parallel_dims=parallel_dims, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() - self.lr_schedulers.step() + self.lr_schedulers.step() # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) - # log metrics if not self.metrics_processor.should_log(self.step): return - if parallel_dims.dp_cp_enabled or self.ft_manager.enabled: loss = loss.detach() # Skip ft manager communication when using semi sync training @@ -485,7 +412,6 @@ def train_step( ) else: global_avg_loss = global_max_loss = loss.detach().item() - self.metrics_processor.log( self.step, global_avg_loss, @@ -496,10 +422,8 @@ def train_step( @record def train(self): job_config = self.job_config - self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}.") - with ( maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, maybe_enable_memory_snapshot( @@ -521,14 +445,6 @@ def train(self): except DataloaderStopIteration: logger.warning("Ran out of data; last step was canceled.") break - - # Run validation if validator is available - if ( - self.job_config.validation.enabled - and self.validator.should_validate(self.step) - ): - self.validator.validate(self.model_parts) - self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) @@ -538,7 +454,6 @@ def train(self): torch_profiler.step() if memory_profiler: memory_profiler.step() - # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if self.step == 1: @@ -548,11 +463,9 @@ def train(self): ), world_mesh=self.world_mesh, ) - if torch.distributed.get_rank() == 0: logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - self.metrics_processor.close() logger.info("Training completed") @@ -572,10 +485,8 @@ def close(self) -> None: config_manager = ConfigManager() config = config_manager.parse_args() trainer: Optional[Trainer] = None - try: trainer = Trainer(config) - if config.checkpoint.create_seed_checkpoint: assert ( int(os.environ["WORLD_SIZE"]) == 1 From 9d1f9efdb67fb239a1486cab1fc98499eb52b4f7 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 19:51:10 -0700 Subject: [PATCH 12/28] undo train changes --- torchtitan/train.py | 62 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 86f3ede11..b217de5a5 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import importlib import os import time @@ -10,6 +11,7 @@ from typing import Any, Generator, Iterable, Optional import torch + import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torch.distributed.elastic.multiprocessing.errors import record @@ -34,14 +36,17 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): job_config: JobConfig gc_handler: utils.GarbageCollection + parallel_dims: ParallelDims train_spec: train_spec_module.TrainSpec world_mesh: torch.distributed.DeviceMesh gradient_accumulation_steps: int + dataloader: train_spec_module.BaseDataLoader metrics_processor: train_spec_module.MetricsProcessor checkpointer: CheckpointManager train_context: Generator[None, None, None] + model_parts: list[torch.nn.Module] loss_fn: train_spec_module.LossFunction optimizers: train_spec_module.OptimizersContainer @@ -51,6 +56,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): pp_has_last_stage: bool device: torch.device + # states step: int @@ -58,16 +64,22 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): @record def __init__(self, job_config: JobConfig): torch._C._log_api_usage_once("torchtitan.train") + self.job_config = job_config + logger.info(f"Starting job: {job_config.job.description}") + if job_config.experimental.custom_import: importlib.import_module(job_config.experimental.custom_import) + if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") + device_module, device_type = utils.device_module, utils.device_type self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) + # init distributed world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism @@ -80,8 +92,8 @@ def __init__(self, job_config: JobConfig): world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) - dist_utils.init_distributed(job_config) + # build meshes self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type) if parallel_dims.dp_enabled: @@ -89,15 +101,18 @@ def __init__(self, job_config: JobConfig): dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + self.ft_manager = ft.init_ft_manager(job_config) # If TorchFT is enabled, the dp_rank and dp_degree, which are used for # dataloader must be changed. if self.ft_manager.enabled: dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -107,31 +122,37 @@ def __init__(self, job_config: JobConfig): job_config.training.deterministic, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) + # build dataloader tokenizer = ( self.train_spec.build_tokenizer_fn(job_config) if self.train_spec.build_tokenizer_fn is not None else None ) + self.dataloader = self.train_spec.build_dataloader_fn( dp_world_size=dp_degree, dp_rank=dp_rank, tokenizer=tokenizer, job_config=job_config, ) + # build model (using meta init) model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) + logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): model = model_cls(model_args) + # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) + # metrics logging build_metrics_processor_fn = ( build_metrics_processor @@ -142,15 +163,18 @@ def __init__(self, job_config: JobConfig): job_config, parallel_dims, model_args ) color = self.metrics_processor.color + # calculate model size and flops per token ( model_param_count, self.metrics_processor.num_flops_per_token, ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + logger.info( f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) + # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: init_device = "cpu" @@ -161,7 +185,9 @@ def __init__(self, job_config: JobConfig): else: init_device = device_type buffer_device = None + self.loss_fn = self.train_spec.build_loss_fn(job_config) + # verify batch sizes global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: @@ -176,6 +202,7 @@ def __init__(self, job_config: JobConfig): f"data-parallel degree ({global_batch_size} " f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" ) + # calculate gradient accumulation steps self.gradient_accumulation_steps = global_batch_size // ( job_config.training.local_batch_size * dp_degree @@ -184,6 +211,7 @@ def __init__(self, job_config: JobConfig): self.loss_fn = rescale_accumulated_loss( self.loss_fn, self.gradient_accumulation_steps ) + # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -191,6 +219,7 @@ def __init__(self, job_config: JobConfig): f"Pipeline Parallel is enabled but {self.train_spec.name} " f"does not support pipelining" ) + # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques ( self.pp_schedule, @@ -210,11 +239,13 @@ def __init__(self, job_config: JobConfig): # when PP is enabled, `model` obj is no longer used after this point, # model_parts is used instead del model + for m in self.model_parts: m.to_empty(device=init_device) with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() + # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: @@ -222,16 +253,20 @@ def __init__(self, job_config: JobConfig): model = self.train_spec.parallelize_fn( model, world_mesh, parallel_dims, job_config ) + model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) model.train() + self.model_parts = [model] + if ( self.ft_manager.enabled and job_config.fault_tolerance.semi_sync_method is None ): self.ft_manager.set_all_reduce_hook(self.model_parts) + # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = self.metrics_processor.device_memory_monitor gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) @@ -259,9 +294,11 @@ def __init__(self, job_config: JobConfig): ) ) self.metrics_processor.optimizers = self.optimizers + # Initialize trainer states that will be saved in checkpoint. # These attributes must be initialized before checkpoint loading. self.step = 0 + self.checkpointer = CheckpointManager( dataloader=self.dataloader, model_parts=self.model_parts, @@ -271,6 +308,7 @@ def __init__(self, job_config: JobConfig): job_config=job_config, ft_manager=self.ft_manager, ) + self.train_context = dist_utils.get_train_context( parallel_dims.loss_parallel_enabled, parallelism_config.enable_compiled_autograd, @@ -297,6 +335,7 @@ def batch_generator( """Returns an iterator that processes batches from the data iterator.""" device_type = utils.device_type data_iterator = iter(data_iterable) + while True: try: batch = next(data_iterator) @@ -310,11 +349,13 @@ def batch_generator( self.metrics_processor.data_loading_times.append( time.perf_counter() - data_load_start ) + # Move tensors to the appropriate device for k, v in input_dict.items(): if isinstance(v, torch.Tensor): input_dict[k] = v.to(device_type) labels = labels.to(device_type) + yield input_dict, labels def forward_backward_step( @@ -322,6 +363,7 @@ def forward_backward_step( ) -> torch.Tensor: model_parts = self.model_parts parallel_dims = self.parallel_dims + # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] @@ -336,6 +378,7 @@ def forward_backward_step( if parallel_dims.cp_enabled else None ) + if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call with self.train_context(optional_context_parallel_ctx): @@ -350,6 +393,7 @@ def forward_backward_step( self.pp_schedule.step( target=targets, losses=losses, input_batch=inputs ) + # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU loss = ( @@ -367,15 +411,18 @@ def forward_backward_step( # need to free to before bwd to avoid peaking memory del pred loss.backward() + return loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): self.optimizers.zero_grad() + # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. @@ -383,6 +430,7 @@ def train_step( input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, @@ -391,13 +439,15 @@ def train_step( ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() - self.lr_schedulers.step() + # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) + # log metrics if not self.metrics_processor.should_log(self.step): return + if parallel_dims.dp_cp_enabled or self.ft_manager.enabled: loss = loss.detach() # Skip ft manager communication when using semi sync training @@ -412,6 +462,7 @@ def train_step( ) else: global_avg_loss = global_max_loss = loss.detach().item() + self.metrics_processor.log( self.step, global_avg_loss, @@ -422,8 +473,10 @@ def train_step( @record def train(self): job_config = self.job_config + self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}.") + with ( maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, maybe_enable_memory_snapshot( @@ -454,6 +507,7 @@ def train(self): torch_profiler.step() if memory_profiler: memory_profiler.step() + # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if self.step == 1: @@ -463,9 +517,11 @@ def train(self): ), world_mesh=self.world_mesh, ) + if torch.distributed.get_rank() == 0: logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) + self.metrics_processor.close() logger.info("Training completed") @@ -485,8 +541,10 @@ def close(self) -> None: config_manager = ConfigManager() config = config_manager.parse_args() trainer: Optional[Trainer] = None + try: trainer = Trainer(config) + if config.checkpoint.create_seed_checkpoint: assert ( int(os.environ["WORLD_SIZE"]) == 1 From 42bd47a260c32c595db436c5198b1f67a42b3840 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 19:52:44 -0700 Subject: [PATCH 13/28] undo train changes 2 --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index b217de5a5..ca1480f2e 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,10 +11,10 @@ from typing import Any, Generator, Iterable, Optional import torch +from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module -from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss From b13b332c39bed6ce4c9ac9fd768d605f8fcb8c95 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 19:56:32 -0700 Subject: [PATCH 14/28] lint --- torchtitan/components/checkpoint.py | 2 +- torchtitan/config_manager.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 6654311b7..b4c2d19cb 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -6,7 +6,6 @@ import enum import functools -import json import os import queue import re @@ -331,6 +330,7 @@ def dcp_save( Returns: Future: The future object if the checkpoint is async, otherwise None. """ + ret: Future | None = None state_dict_to_save: dict[str, Any] = ( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2ea2b81f1..aec6a6091 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -469,11 +469,11 @@ class Checkpoint: enable_hf_safetensors_format: bool = False """ - Enable the use of safetensors format for checkpointing. This will save checkpoints + Enable the use of safetensors format for checkpointing. This will save the final checkpoints in safetensors format instead of the default DCP format. There will be a performance cost in using this as we need to consolidate the sharded tensors to full tensors as a separate step. Last_save_model_weights and initial_load_model_weights_only - must be true because safetensors doesn't support saving non tensors. + must be true because safetensors doesn't support saving non tensors. The default value is False. """ From 2a3fb81baec08e3a521478e84b9939458a15ff03 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 20:10:53 -0700 Subject: [PATCH 15/28] look for safetensors metadata --- torchtitan/components/checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index b4c2d19cb..403da24dd 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -568,8 +568,9 @@ def _find_load_step(self, folder: str = "") -> int: for filename in os.listdir(folder): match = re.search(pattern, filename) - metadata_probe = os.path.join(folder, filename, ".metadata") - if match and os.path.isfile(metadata_probe): + dcp_metadata_probe = os.path.join(folder, filename, ".metadata") + safetensors_metadata_probe = os.path.join(folder, filename, "model.safetensors.index.json") + if match and (os.path.isfile(dcp_metadata_probe) or os.path.isfile(safetensors_metadata_probe)): step_counts.append(int(match.group(1))) if not step_counts: return -1 From 0f223ccb151ac90aecf2cb65041ac9881d9605db Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 20:58:20 -0700 Subject: [PATCH 16/28] some fixes --- tests/integration_tests.py | 5 ---- torchtitan/components/checkpoint.py | 38 ++++++++++++++++++++++++++--- torchtitan/config_manager.py | 5 ++-- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index dca2610be..7f45f4176 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -124,11 +124,6 @@ def build_test_list(): "--checkpoint.enable_checkpoint", "--checkpoint.enable_hf_safetensors_format", ], - [ - "--checkpoint.enable_checkpoint", - "--checkpoint.enable_hf_safetensors_format", - "--training.steps 20", - ], ], "Checkpoint Integration Test - Save Load Full Checkpoint", "full_checkpoint_hf_safetensors", diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 403da24dd..bfbdb6fd9 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -53,6 +53,10 @@ class AsyncMode(str, enum.Enum): ASYNC = "async" ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" +class CheckpointType(str, enum.Enum): + DCP = "DCP" + SAFETENSORS = "safetensors" + # For now, we will manually pop the freqs_cis buffer, as we made this permanent # temporarily and we don't want to include it in the exported state_dict. @@ -392,7 +396,7 @@ def dcp_save( return ret - def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str) -> None: + def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str, checkpoint_type: CheckpointType) -> None: """Load the checkpoint with dcp. Args: state_dict (dict): The state dict to load. @@ -400,7 +404,7 @@ def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str) -> None: hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. """ - if self.enable_hf_safetensors_format: + if checkpoint_type == CheckpointType.SAFETENSORS: storage_reader = HuggingFaceStorageReader(path=checkpoint_id) dcp.load(state_dict, storage_reader=storage_reader) else: @@ -526,13 +530,17 @@ def load(self, step: int = -1) -> bool: raise FileNotFoundError( f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found." ) - + + checkpoint_type = self._find_checkpoint_type(checkpoint_id) + if checkpoint_type == CheckpointType.SAFETENSORS: + model_only = True logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) self.dcp_load( states, checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -562,6 +570,7 @@ def _find_load_step(self, folder: str = "") -> int: folder = folder if folder else self.folder pattern = r"step-(\d+)" step_counts = [] + checkpoint_type_map = {} if not os.path.isdir(folder): return -1 @@ -570,12 +579,32 @@ def _find_load_step(self, folder: str = "") -> int: match = re.search(pattern, filename) dcp_metadata_probe = os.path.join(folder, filename, ".metadata") safetensors_metadata_probe = os.path.join(folder, filename, "model.safetensors.index.json") - if match and (os.path.isfile(dcp_metadata_probe) or os.path.isfile(safetensors_metadata_probe)): + if match and os.path.isfile(dcp_metadata_probe): + step_counts.append(int(match.group(1))) + checkpoint_type_map[int(match.group(1))] = CheckpointType.DCP + elif match and os.path.isfile(safetensors_metadata_probe): step_counts.append(int(match.group(1))) + checkpoint_type_map[int(match.group(1))] = CheckpointType.SAFETENSORS if not step_counts: return -1 return max(step_counts) + def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: + """Find the checkpoint type for the given id. + + Args: + checkpoint_id (str): The folder to find the checkpoint type for. + + Returns: + CheckpointType: The checkpoint type for the given folder. + """ + + for filename in os.listdir(checkpoint_id): + if filename == "model.safetensors.index.json": + return CheckpointType.SAFETENSORS + return CheckpointType.DCP + + def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") @@ -603,6 +632,7 @@ def _ft_load(self) -> None: self.dcp_load( self.ft_states, checkpoint_id=checkpoint_id, + checkpoint_type=CheckpointType.DCP, # FT checkpoints are always DCP ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index aec6a6091..62183f5b2 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -472,8 +472,9 @@ class Checkpoint: Enable the use of safetensors format for checkpointing. This will save the final checkpoints in safetensors format instead of the default DCP format. There will be a performance cost in using this as we need to consolidate the sharded tensors to full tensors as - a separate step. Last_save_model_weights and initial_load_model_weights_only - must be true because safetensors doesn't support saving non tensors. + a separate step. Last_save_model_weights must be true because safetensors doesn't + support saving non tensors. On load, this argument isn't needed as we will detect + whether the loaded checkpoint is in safetensors format or not. The default value is False. """ From 139a1f5df19b1803b920e23a4be9d28efa0ce544 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 21:03:31 -0700 Subject: [PATCH 17/28] change arg name --- tests/integration_tests.py | 2 +- torchtitan/components/checkpoint.py | 32 +++++++---------------------- torchtitan/config_manager.py | 2 +- 3 files changed, 9 insertions(+), 27 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 7f45f4176..ea87584ce 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -122,7 +122,7 @@ def build_test_list(): [ [ "--checkpoint.enable_checkpoint", - "--checkpoint.enable_hf_safetensors_format", + "--checkpoint.enable_save_safetensors_format", ], ], "Checkpoint Integration Test - Save Load Full Checkpoint", diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index bfbdb6fd9..39e1fc098 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -193,7 +193,7 @@ def __init__( ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format + self.enable_save_safetensors_format = ckpt_config.enable_save_safetensors_format self.ft_manager = ft_manager.manager if ft_manager.enabled else None if self.ft_manager: @@ -337,29 +337,14 @@ def dcp_save( ret: Future | None = None - state_dict_to_save: dict[str, Any] = ( - {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} - if self.enable_hf_safetensors_format - else state_dict - ) - logger.info( - "Num keys before parsing %d, after %d", - len(state_dict), - len(state_dict_to_save), - ) - for k, v in state_dict_to_save.items(): - if isinstance(v, torch.Tensor): - logger.info("key %s, shape %s", k, v.shape) - break - storage_writer: HuggingFaceStorageWriter | None = None - if self.enable_hf_safetensors_format and is_last_step: + if self.enable_save_safetensors_format and is_last_step: fqn_to_index_mapping = {} num_fqns_per_file = 30 # the use of 30 is just a heuristic for now. # Once these fqns map to HF ones, we can use the fqn mapping # from the model.safetensors.index.json file - for i, key in enumerate(state_dict_to_save.keys()): + for i, key in enumerate(state_dict.keys()): group_num = (i // num_fqns_per_file) + 1 fqn_to_index_mapping[key] = group_num @@ -369,17 +354,17 @@ def dcp_save( fqn_to_index_mapping=fqn_to_index_mapping, enable_consolidation=is_last_step, ) - id = checkpoint_id if not self.enable_hf_safetensors_format else None + id = checkpoint_id if not self.enable_save_safetensors_format else None if async_mode == AsyncMode.ASYNC: ret = dcp.async_save( - state_dict_to_save, + state_dict, storage_writer=storage_writer, checkpoint_id=id, process_group=self.pg, ) elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: ret = dcp.async_save( - state_dict_to_save, + state_dict, storage_writer=storage_writer, checkpoint_id=id, process_group=self.pg, @@ -388,7 +373,7 @@ def dcp_save( ) else: ret = dcp.save( - state_dict_to_save, storage_writer=storage_writer, checkpoint_id=id + state_dict, storage_writer=storage_writer, checkpoint_id=id ) if enable_garbage_collection: @@ -570,7 +555,6 @@ def _find_load_step(self, folder: str = "") -> int: folder = folder if folder else self.folder pattern = r"step-(\d+)" step_counts = [] - checkpoint_type_map = {} if not os.path.isdir(folder): return -1 @@ -581,10 +565,8 @@ def _find_load_step(self, folder: str = "") -> int: safetensors_metadata_probe = os.path.join(folder, filename, "model.safetensors.index.json") if match and os.path.isfile(dcp_metadata_probe): step_counts.append(int(match.group(1))) - checkpoint_type_map[int(match.group(1))] = CheckpointType.DCP elif match and os.path.isfile(safetensors_metadata_probe): step_counts.append(int(match.group(1))) - checkpoint_type_map[int(match.group(1))] = CheckpointType.SAFETENSORS if not step_counts: return -1 return max(step_counts) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 62183f5b2..2a8c7f4c8 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -467,7 +467,7 @@ class Checkpoint: for many steps or checkpointing too frequently. The default value is False. """ - enable_hf_safetensors_format: bool = False + enable_save_safetensors_format: bool = False """ Enable the use of safetensors format for checkpointing. This will save the final checkpoints in safetensors format instead of the default DCP format. There will be a performance From 99725a9e1f5f44a86c9294c300ac76f43a2627dc Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Thu, 10 Jul 2025 21:08:40 -0700 Subject: [PATCH 18/28] lint --- torchtitan/components/checkpoint.py | 6 +++--- torchtitan/config_manager.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 39e1fc098..430e28fb5 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -515,7 +515,7 @@ def load(self, step: int = -1) -> bool: raise FileNotFoundError( f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found." ) - + checkpoint_type = self._find_checkpoint_type(checkpoint_id) if checkpoint_type == CheckpointType.SAFETENSORS: model_only = True @@ -585,7 +585,7 @@ def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: if filename == "model.safetensors.index.json": return CheckpointType.SAFETENSORS return CheckpointType.DCP - + def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") @@ -614,7 +614,7 @@ def _ft_load(self) -> None: self.dcp_load( self.ft_states, checkpoint_id=checkpoint_id, - checkpoint_type=CheckpointType.DCP, # FT checkpoints are always DCP + checkpoint_type=CheckpointType.DCP, # FT checkpoints are always DCP ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2a8c7f4c8..04bb7a57c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -472,7 +472,7 @@ class Checkpoint: Enable the use of safetensors format for checkpointing. This will save the final checkpoints in safetensors format instead of the default DCP format. There will be a performance cost in using this as we need to consolidate the sharded tensors to full tensors as - a separate step. Last_save_model_weights must be true because safetensors doesn't + a separate step. Last_save_model_weights must be true because safetensors doesn't support saving non tensors. On load, this argument isn't needed as we will detect whether the loaded checkpoint is in safetensors format or not. The default value is False. From 1f3ecc0057fbed7a01a6cd301baaca50f4642368 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Fri, 11 Jul 2025 04:54:22 -0700 Subject: [PATCH 19/28] test pass --- tests/integration_tests.py | 1 + torchtitan/components/checkpoint.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index ea87584ce..e31343f8e 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -123,6 +123,7 @@ def build_test_list(): [ "--checkpoint.enable_checkpoint", "--checkpoint.enable_save_safetensors_format", + "--checkpoint.last_save_model_weights_only", ], ], "Checkpoint Integration Test - Save Load Full Checkpoint", diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 430e28fb5..fac698735 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -53,6 +53,7 @@ class AsyncMode(str, enum.Enum): ASYNC = "async" ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" + class CheckpointType(str, enum.Enum): DCP = "DCP" SAFETENSORS = "safetensors" @@ -372,16 +373,19 @@ def dcp_save( async_stager=self.stager, ) else: - ret = dcp.save( - state_dict, storage_writer=storage_writer, checkpoint_id=id - ) + ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) if enable_garbage_collection: GarbageCollection.collect("GC collection invoked by checkpointer.") return ret - def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str, checkpoint_type: CheckpointType) -> None: + def dcp_load( + self, + state_dict: dict[str, Any], + checkpoint_id: str, + checkpoint_type: CheckpointType, + ) -> None: """Load the checkpoint with dcp. Args: state_dict (dict): The state dict to load. @@ -562,7 +566,9 @@ def _find_load_step(self, folder: str = "") -> int: for filename in os.listdir(folder): match = re.search(pattern, filename) dcp_metadata_probe = os.path.join(folder, filename, ".metadata") - safetensors_metadata_probe = os.path.join(folder, filename, "model.safetensors.index.json") + safetensors_metadata_probe = os.path.join( + folder, filename, "model.safetensors.index.json" + ) if match and os.path.isfile(dcp_metadata_probe): step_counts.append(int(match.group(1))) elif match and os.path.isfile(safetensors_metadata_probe): @@ -586,7 +592,6 @@ def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: return CheckpointType.SAFETENSORS return CheckpointType.DCP - def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") From 8d5a8bcc8af80dd54b843eadf4f70f22d87f6bf1 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Fri, 11 Jul 2025 06:30:41 -0700 Subject: [PATCH 20/28] add thread count consolidation --- torchtitan/components/checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index fac698735..0cab8a0c4 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -354,6 +354,7 @@ def dcp_save( save_distributed=True, fqn_to_index_mapping=fqn_to_index_mapping, enable_consolidation=is_last_step, + thread_count_consolidation=5, ) id = checkpoint_id if not self.enable_save_safetensors_format else None if async_mode == AsyncMode.ASYNC: From 2b5793f2a9089efda097a6a061bc5b9e450feb7f Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Mon, 14 Jul 2025 08:02:19 -0700 Subject: [PATCH 21/28] made requested changes --- tests/integration_tests.py | 7 +++++- torchtitan/components/checkpoint.py | 39 +++++++++++++++++++++-------- torchtitan/config_manager.py | 4 +-- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index e31343f8e..072c16640 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -122,11 +122,16 @@ def build_test_list(): [ [ "--checkpoint.enable_checkpoint", + "--checkpoint.folder hf_checkpoint", "--checkpoint.enable_save_safetensors_format", "--checkpoint.last_save_model_weights_only", ], + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.initial_load_path outputs/hf_checkpoint/step-10", + ], ], - "Checkpoint Integration Test - Save Load Full Checkpoint", + "Checkpoint Integration Test - save load full checkpoint in HF safetensors format", "full_checkpoint_hf_safetensors", ), OverrideDefinitions( diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 0cab8a0c4..ddcb8c629 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -194,7 +194,9 @@ def __init__( ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.enable_save_safetensors_format = ckpt_config.enable_save_safetensors_format + self.last_save_in_safetensors_format = ( + ckpt_config.last_save_in_safetensors_format + ) self.ft_manager = ft_manager.manager if ft_manager.enabled else None if self.ft_manager: @@ -324,7 +326,7 @@ def dcp_save( checkpoint_id: str, async_mode: AsyncMode, enable_garbage_collection: bool = False, - is_last_step: bool = False, + save_in_safetensors_format: bool = False, ) -> Future | None: """Save the checkpoint with dcp. Args: @@ -339,7 +341,7 @@ def dcp_save( ret: Future | None = None storage_writer: HuggingFaceStorageWriter | None = None - if self.enable_save_safetensors_format and is_last_step: + if save_in_safetensors_format: fqn_to_index_mapping = {} num_fqns_per_file = 30 # the use of 30 is just a heuristic for now. @@ -353,28 +355,35 @@ def dcp_save( path=checkpoint_id, save_distributed=True, fqn_to_index_mapping=fqn_to_index_mapping, - enable_consolidation=is_last_step, + enable_consolidation=True, thread_count_consolidation=5, ) - id = checkpoint_id if not self.enable_save_safetensors_format else None + + checkpoint_save_id = ( + checkpoint_id if not self.last_save_in_safetensors_format else None + ) if async_mode == AsyncMode.ASYNC: ret = dcp.async_save( state_dict, storage_writer=storage_writer, - checkpoint_id=id, + checkpoint_id=checkpoint_save_id, process_group=self.pg, ) elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: ret = dcp.async_save( state_dict, storage_writer=storage_writer, - checkpoint_id=id, + checkpoint_id=checkpoint_save_id, process_group=self.pg, async_checkpointer_type=AsyncCheckpointerType.PROCESS, async_stager=self.stager, ) else: - ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) + ret = dcp.save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=checkpoint_save_id, + ) if enable_garbage_collection: GarbageCollection.collect("GC collection invoked by checkpointer.") @@ -523,7 +532,9 @@ def load(self, step: int = -1) -> bool: checkpoint_type = self._find_checkpoint_type(checkpoint_id) if checkpoint_type == CheckpointType.SAFETENSORS: - model_only = True + assert ( + model_only + ), "Only model weights can be loaded when loading from safetensors checkpoint." logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) @@ -620,7 +631,8 @@ def _ft_load(self) -> None: self.dcp_load( self.ft_states, checkpoint_id=checkpoint_id, - checkpoint_type=CheckpointType.DCP, # FT checkpoints are always DCP + # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. + checkpoint_type=CheckpointType.DCP, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -690,12 +702,17 @@ def _save_last_step(self, curr_step: int) -> None: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") states = self._flattened_model_states_sd() + if self.last_save_in_safetensors_format: + assert ( + self.last_save_model_weights_only + ), "Only model weights can be saved when saving in safetensors format." + self.dcp_save( states, checkpoint_id=self._create_checkpoint_id(curr_step), async_mode=AsyncMode.DISABLED, enable_garbage_collection=True, - is_last_step=True, + save_in_safetensors_format=self.last_save_in_safetensors_format, ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 04bb7a57c..18217c04b 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -467,12 +467,12 @@ class Checkpoint: for many steps or checkpointing too frequently. The default value is False. """ - enable_save_safetensors_format: bool = False + last_save_in_safetensors_format: bool = False """ Enable the use of safetensors format for checkpointing. This will save the final checkpoints in safetensors format instead of the default DCP format. There will be a performance cost in using this as we need to consolidate the sharded tensors to full tensors as - a separate step. Last_save_model_weights must be true because safetensors doesn't + a separate step. last_save_model_weights_only must be true because safetensors doesn't support saving non tensors. On load, this argument isn't needed as we will detect whether the loaded checkpoint is in safetensors format or not. The default value is False. From a7d0ea902e88918697592f0ae2e57ddbd2bc3467 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 14 Jul 2025 08:08:03 -0700 Subject: [PATCH 22/28] fix comment --- torchtitan/components/checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index ddcb8c629..315264714 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -332,7 +332,9 @@ def dcp_save( Args: state_dict (dict): The state dict to save. checkpoint_id (str): The checkpoint id to save. - is_async (bool): Whether the checkpoint is async. + async_mode (AsyncMode): Whether the checkpoint is async. + enable_garbage_collection (bool): Whether to enable garbage collection after save. + save_in_safetensors_format (bool): Whether to save in safetensors format. Returns: Future: The future object if the checkpoint is async, otherwise None. From 2dabc4122da3de3db109ffd4b8c9216c08d9c96d Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 14 Jul 2025 08:44:40 -0700 Subject: [PATCH 23/28] update test --- tests/integration_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 072c16640..db897a47c 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -123,7 +123,7 @@ def build_test_list(): [ "--checkpoint.enable_checkpoint", "--checkpoint.folder hf_checkpoint", - "--checkpoint.enable_save_safetensors_format", + "--checkpoint.last_save_in_safetensors_format", "--checkpoint.last_save_model_weights_only", ], [ From e91174be36aa3d17740bb8a8d580b749a7a6d9f2 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 14 Jul 2025 09:48:52 -0700 Subject: [PATCH 24/28] update test --- tests/integration_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index db897a47c..c3f384979 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -128,7 +128,7 @@ def build_test_list(): ], [ "--checkpoint.enable_checkpoint", - "--checkpoint.initial_load_path outputs/hf_checkpoint/step-10", + "--checkpoint.initial_load_path outputs/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/", ], ], "Checkpoint Integration Test - save load full checkpoint in HF safetensors format", From f1862446c7b520fc5b4bf18e0195ca5c705a9780 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 14 Jul 2025 10:32:15 -0700 Subject: [PATCH 25/28] add debug statement --- torchtitan/components/checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 315264714..fbe4f15a0 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -339,6 +339,7 @@ def dcp_save( Returns: Future: The future object if the checkpoint is async, otherwise None. """ + print("saving to checkpoint_id", checkpoint_id) ret: Future | None = None From 6ba5d1829703b2eaaa5eec927776e47e4fe7f88e Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 14 Jul 2025 10:53:05 -0700 Subject: [PATCH 26/28] test passes --- tests/integration_tests.py | 2 +- torchtitan/components/checkpoint.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index c3f384979..c93c08922 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -128,7 +128,7 @@ def build_test_list(): ], [ "--checkpoint.enable_checkpoint", - "--checkpoint.initial_load_path outputs/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/", + "--checkpoint.initial_load_path artifacts-to-be-uploaded/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/", ], ], "Checkpoint Integration Test - save load full checkpoint in HF safetensors format", diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index fbe4f15a0..315264714 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -339,7 +339,6 @@ def dcp_save( Returns: Future: The future object if the checkpoint is async, otherwise None. """ - print("saving to checkpoint_id", checkpoint_id) ret: Future | None = None From c36c8632cb63af30b71ce8f4a091d9b9120028b7 Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Mon, 14 Jul 2025 12:44:04 -0700 Subject: [PATCH 27/28] add docs --- docs/checkpoint.md | 5 +++++ torchtitan/components/checkpoint.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 0dad44e67..4ab25f373 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -85,3 +85,8 @@ e.g. ```bash NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 ``` + + +## "How to load / save a checkpoint in HF safetensors format +For save, users need to set --checkpoint.last_save_in_safetensors_format and --checkpoint.last_save_model_weights_only to save the last checkpoint in HF format (intermediate ones are always in DCP format). +For load, users need to either put the checkpoint in step-0 folder if using --checkpoint.folder, or specify --checkpoint.initial_load_path to load from a different folder. They also need to set --checkpoint.initial_load_model_weights_only to load the checkpoint in HF format. diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 315264714..c5f3bf235 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -343,6 +343,7 @@ def dcp_save( ret: Future | None = None storage_writer: HuggingFaceStorageWriter | None = None + checkpoint_save_id: str | None = None if save_in_safetensors_format: fqn_to_index_mapping = {} num_fqns_per_file = 30 @@ -360,10 +361,9 @@ def dcp_save( enable_consolidation=True, thread_count_consolidation=5, ) + else: + checkpoint_save_id = checkpoint_id - checkpoint_save_id = ( - checkpoint_id if not self.last_save_in_safetensors_format else None - ) if async_mode == AsyncMode.ASYNC: ret = dcp.async_save( state_dict, From a3725b292eecd8b191941b8e008c7947724a443b Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Mon, 14 Jul 2025 16:51:25 -0400 Subject: [PATCH 28/28] Update docs/checkpoint.md Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com> --- docs/checkpoint.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 4ab25f373..ec555b22e 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -87,6 +87,6 @@ NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoi ``` -## "How to load / save a checkpoint in HF safetensors format -For save, users need to set --checkpoint.last_save_in_safetensors_format and --checkpoint.last_save_model_weights_only to save the last checkpoint in HF format (intermediate ones are always in DCP format). -For load, users need to either put the checkpoint in step-0 folder if using --checkpoint.folder, or specify --checkpoint.initial_load_path to load from a different folder. They also need to set --checkpoint.initial_load_model_weights_only to load the checkpoint in HF format. +## How to load / save a checkpoint in HF safetensors format +For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_weights_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format). +For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_weights_only` to load the checkpoint in HF format.