Add support for saving HF format tensors with DCP #1351
Changes from all commits
@@ -7,3 +7,4 @@ wandb
fsspec
tyro
tokenizers >= 0.15.0
+safetensors
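
The new `safetensors` dependency backs the HuggingFace-format save path added below. As a rough illustration (the output path and shard file name are hypothetical, following the usual HF sharding convention), a consolidated shard written by that path can be inspected with the `safetensors` library directly:

```python
# Hypothetical inspection of one consolidated shard; the path is made up for
# illustration and the shard name follows the common HF naming convention.
from safetensors.torch import load_file

shard = load_file("outputs/checkpoint/step-1000/model-00001-of-00004.safetensors")
for fqn, tensor in shard.items():
    print(fqn, tuple(tensor.shape), tensor.dtype)
```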
@@ -12,12 +12,17 @@
import shutil
import threading
import time
from concurrent.futures import Future
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
+from torch.distributed.checkpoint import (
+   HuggingFaceStorageReader,
+   HuggingFaceStorageWriter,
+)
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
@@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum):
    ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


+class CheckpointType(str, enum.Enum):
+   DCP = "DCP"
+   SAFETENSORS = "safetensors"
+
+
# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
@@ -92,12 +102,6 @@ class SaveDone:
    pass

-
-@torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-   dcp.save(state, checkpoint_id=checkpoint_id)
-   GarbageCollection.collect("GC collection invoked by checkpointer.")
-

def purge_thread(purge_queue: queue.Queue):
    """Thread to purge the old checkpoints.
@@ -190,6 +194,9 @@ def __init__(
    ) -> None:
        ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
+       self.last_save_in_safetensors_format = (
+           ckpt_config.last_save_in_safetensors_format
+       )
        self.ft_manager = (
            ft_manager.manager if ft_manager and ft_manager.enabled else None
        )
@@ -314,6 +321,98 @@ def close(self):
        if self.stager is not None:
            self.stager.close()

+   @torch.no_grad()
+   def dcp_save(
+       self,
+       state_dict: dict[str, Any],
+       checkpoint_id: str,
+       async_mode: AsyncMode,
+       enable_garbage_collection: bool = False,
+       save_in_safetensors_format: bool = False,
+   ) -> Future | None:
+       """Save the checkpoint with dcp.
+       Args:
+           state_dict (dict): The state dict to save.
+           checkpoint_id (str): The checkpoint id to save.
+           async_mode (AsyncMode): The async mode to use for this save.
+           enable_garbage_collection (bool): Whether to enable garbage collection after save.
+           save_in_safetensors_format (bool): Whether to save in safetensors format.
+       Returns:
+           Future: The future object if the checkpoint is async, otherwise None.
+       """
+
+       ret: Future | None = None
+
+       storage_writer: HuggingFaceStorageWriter | None = None
+       checkpoint_save_id: str | None = None
+       if save_in_safetensors_format:
+           fqn_to_index_mapping = {}
+           num_fqns_per_file = 30
+           # The use of 30 is just a heuristic for now.
+           # Once these fqns map to HF ones, we can use the fqn mapping
+           # from the model.safetensors.index.json file.
+           for i, key in enumerate(state_dict.keys()):
+               group_num = (i // num_fqns_per_file) + 1
+               fqn_to_index_mapping[key] = group_num
+
+           storage_writer = HuggingFaceStorageWriter(
+               path=checkpoint_id,
+               save_distributed=True,
+               fqn_to_index_mapping=fqn_to_index_mapping,
+               enable_consolidation=True,
+               thread_count_consolidation=5,
+           )
+       else:
+           checkpoint_save_id = checkpoint_id

+       if async_mode == AsyncMode.ASYNC:
+           ret = dcp.async_save(
+               state_dict,
+               storage_writer=storage_writer,
+               checkpoint_id=checkpoint_save_id,
+               process_group=self.pg,
+           )
+       elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+           ret = dcp.async_save(
+               state_dict,
+               storage_writer=storage_writer,
+               checkpoint_id=checkpoint_save_id,
+               process_group=self.pg,
+               async_checkpointer_type=AsyncCheckpointerType.PROCESS,
+               async_stager=self.stager,
+           )
+       else:
+           ret = dcp.save(
+               state_dict,
+               storage_writer=storage_writer,
+               checkpoint_id=checkpoint_save_id,
+           )

+       if enable_garbage_collection:
+           GarbageCollection.collect("GC collection invoked by checkpointer.")

+       return ret
+
+   def dcp_load(
+       self,
+       state_dict: dict[str, Any],
+       checkpoint_id: str,
+       checkpoint_type: CheckpointType,
+   ) -> None:
+       """Load the checkpoint with dcp.
+       Args:
+           state_dict (dict): The state dict to load.
+           checkpoint_id (str): The checkpoint id to load.
+           checkpoint_type (CheckpointType): Whether the checkpoint is in DCP or HuggingFace safetensors format.
+       """

+       if checkpoint_type == CheckpointType.SAFETENSORS:
+           storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+           dcp.load(state_dict, storage_reader=storage_reader)
Review comment: I have a question on how to combine this PR with the torchtitan model <-> HF model definition conversion mappings, so that we can load / save with HF models and train with torchtitan. Let's say we have the two mappings. I could imagine how we do this for save -- when HF format is used, we just don't call the current … How are we supposed to do load in the HF definition?
+       else:
+           dcp.load(state_dict, checkpoint_id=checkpoint_id)
+
    @torch.no_grad()
    def save(self, curr_step: int, last_step: bool = False) -> None:
        """Save the checkpoint for the current step.
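
To make the sharding heuristic in `dcp_save` concrete, here is a standalone sketch (not part of the PR) of how the FQN-to-file mapping is built and handed to `HuggingFaceStorageWriter`. The toy model and output directory are assumptions; the writer arguments mirror the ones used above and assume a PyTorch build recent enough to ship `HuggingFaceStorageWriter` with this signature.

```python
# Sketch only: build the same ~30-FQNs-per-file mapping as dcp_save and write
# the state dict in HF safetensors format. Model and path are placeholders.
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import HuggingFaceStorageWriter

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
state_dict = model.state_dict()

num_fqns_per_file = 30  # same heuristic as the PR: roughly 30 tensors per file
fqn_to_index_mapping = {
    fqn: (i // num_fqns_per_file) + 1 for i, fqn in enumerate(state_dict.keys())
}

storage_writer = HuggingFaceStorageWriter(
    path="outputs/hf_checkpoint",  # hypothetical output directory
    save_distributed=True,
    fqn_to_index_mapping=fqn_to_index_mapping,
    enable_consolidation=True,  # merge per-rank output into HF-style shards
    thread_count_consolidation=5,
)
dcp.save(state_dict, storage_writer=storage_writer)
```

The mapping only controls which tensors land in which numbered shard; as the inline comment in the PR notes, a real FQN mapping taken from `model.safetensors.index.json` could replace the heuristic once torchtitan FQNs map onto HF ones.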
@@ -354,23 +453,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
            GarbageCollection.collect("GC collection invoked by checkpointer.")
            if self.stager is None:
                self.stager = DefaultStager(StagingOptions(True, True, True, True))
-           result = dcp.async_save(
+           result = self.dcp_save(
                states,
                checkpoint_id=checkpoint_id,
-               process_group=self.pg,
-               async_checkpointer_type=AsyncCheckpointerType.PROCESS,
-               async_stager=self.stager,
+               async_mode=self.async_mode,
            )
            self.save_future = result.upload_completion
            self.staging_future = result.staging_completion
        elif self.async_mode == AsyncMode.ASYNC:
            GarbageCollection.collect("GC collection invoked by checkpointer.")
-           self.save_future = dcp.async_save(
-               states, checkpoint_id=checkpoint_id, process_group=self.pg
+           self.save_future = self.dcp_save(
+               states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
            )
            GarbageCollection.collect("GC collection invoked by checkpointer.")
        else:
-           save_with_gc(states, checkpoint_id=checkpoint_id)
+           self.dcp_save(
+               states,
+               checkpoint_id=checkpoint_id,
+               async_mode=AsyncMode.DISABLED,
+               enable_garbage_collection=True,
+           )
        self._purge_stale_checkpoints()

        logger.info(
@@ -432,10 +534,19 @@ def load(self, step: int = -1) -> bool:
                f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
            )

+       checkpoint_type = self._find_checkpoint_type(checkpoint_id)
+       if checkpoint_type == CheckpointType.SAFETENSORS:
+           assert (
+               model_only
+           ), "Only model weights can be loaded when loading from safetensors checkpoint."
        logger.info(f"Loading the checkpoint from {checkpoint_id}.")
        begin = time.monotonic()
        states = self._states_to_load(model_only)
-       dcp.load(states, checkpoint_id=checkpoint_id)
+       self.dcp_load(
+           states,
+           checkpoint_id=checkpoint_id,
+           checkpoint_type=checkpoint_type,
+       )
        GarbageCollection.collect("GC collection for checkpoint loading.")
        logger.info(
            f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
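
For the load direction, the safetensors branch of `dcp_load` goes through `HuggingFaceStorageReader`. Below is a minimal sketch of a model-weights-only restore outside the Checkpointer, assuming the checkpoint was written with FQNs that match the module (the model and path are placeholders):

```python
# Sketch only: load an HF-format checkpoint into a module via DCP. Assumes the
# checkpoint's FQNs match the module's (see the review question on FQN mapping).
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import HuggingFaceStorageReader
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
state_dict = get_model_state_dict(model)

dcp.load(state_dict, storage_reader=HuggingFaceStorageReader(path="outputs/hf_checkpoint"))
set_model_state_dict(model, state_dict)
```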
@@ -470,13 +581,33 @@ def _find_load_step(self, folder: str = "") -> int:

        for filename in os.listdir(folder):
            match = re.search(pattern, filename)
-           metadata_probe = os.path.join(folder, filename, ".metadata")
-           if match and os.path.isfile(metadata_probe):
+           dcp_metadata_probe = os.path.join(folder, filename, ".metadata")
+           safetensors_metadata_probe = os.path.join(
+               folder, filename, "model.safetensors.index.json"
+           )
+           if match and os.path.isfile(dcp_metadata_probe):
                step_counts.append(int(match.group(1)))
+           elif match and os.path.isfile(safetensors_metadata_probe):
Review comment: oh for safetensors do we also require …

Reply: from your other comment, it seems like you think the HF checkpoint in step 10, 20, etc. isn't a valid use case, so I will just follow your suggestion.
+               step_counts.append(int(match.group(1)))
        if not step_counts:
            return -1
        return max(step_counts)

+   def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:
+       """Find the checkpoint type for the given id.
+       Args:
+           checkpoint_id (str): The folder to find the checkpoint type for.
+       Returns:
+           CheckpointType: The checkpoint type for the given folder.
+       """
+
+       for filename in os.listdir(checkpoint_id):
+           if filename == "model.safetensors.index.json":
+               return CheckpointType.SAFETENSORS
+       return CheckpointType.DCP
+
    def _ft_folder(self) -> str:
        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
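
The two probes above encode the on-disk convention the PR relies on: a DCP checkpoint folder contains a `.metadata` file, while a consolidated safetensors checkpoint contains `model.safetensors.index.json`. A small standalone restatement of that detection logic (the step folder path is hypothetical):

```python
# Standalone restatement of the format detection used by _find_load_step and
# _find_checkpoint_type; defaults to DCP, like the PR. The path is made up.
import os

def probe_checkpoint_type(checkpoint_id: str) -> str:
    if os.path.isfile(os.path.join(checkpoint_id, "model.safetensors.index.json")):
        return "safetensors"  # consolidated HuggingFace-format checkpoint
    return "DCP"  # anything else is treated as a native DCP checkpoint

print(probe_checkpoint_type("outputs/checkpoint/step-1000"))
```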
@@ -488,8 +619,8 @@ def _ft_save(self, step: int) -> None:
        begin = time.monotonic()
        self._async_wait()
        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-       self.save_future = dcp.async_save(
-           self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+       self.save_future = self.dcp_save(
+           self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
        )
        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
@@ -501,7 +632,12 @@ def _ft_load(self) -> None:
        begin = time.monotonic()
        logger.info(f"Loading the FT checkpoint at step {step}.")
        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-       dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+       self.dcp_load(
+           self.ft_states,
+           checkpoint_id=checkpoint_id,
+           # FT checkpoints are always DCP because FT currently only saves/loads the dataloader.
+           checkpoint_type=CheckpointType.DCP,
+       )
        GarbageCollection.collect("GC collection for checkpoint loading.")
        logger.info(
            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -570,7 +706,18 @@ def _save_last_step(self, curr_step: int) -> None:
        logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
        states = self._flattened_model_states_sd()

-       save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
+       if self.last_save_in_safetensors_format:
+           assert (
+               self.last_save_model_weights_only
+           ), "Only model weights can be saved when saving in safetensors format."
+
+       self.dcp_save(
+           states,
+           checkpoint_id=self._create_checkpoint_id(curr_step),
+           async_mode=AsyncMode.DISABLED,
+           enable_garbage_collection=True,
+           save_in_safetensors_format=self.last_save_in_safetensors_format,
+       )

    def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
        if not self.enable_checkpoint: