
Commit 5ae4823

fixes
1 parent f6e0022 commit 5ae4823

File tree

1 file changed (+49, -46 lines)


torchtitan/components/checkpoint.py

Lines changed: 49 additions & 46 deletions
@@ -98,45 +98,6 @@ class SaveDone:
     pass
 
 
-def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
-    """Process to save the checkpoint in the background.
-
-    This is only used when async_checkpoint_with_pinned_memory is enabled.
-
-    Args:
-        recv (mp.Queue): The queue to receive the state_dict and Terminate signal.
-        send (mp.Queue): The queue to send the SaveDone signal.
-    """
-    init_logger()
-    os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
-    os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    dist.init_process_group()
-    try:
-        while True:
-            logger.debug("Checkpoint background process is done.")
-            send.put(SaveDone())
-            logger.debug("Wait for the new state_dict.")
-            obj = recv.get()
-            logger.debug("Received the new state_dict.")
-            if isinstance(obj, Terminate):
-                logger.info("Terminating the checkpoint background process.")
-                return
-            assert isinstance(obj, tuple)
-            begin = time.monotonic()
-            state, checkpoint_id = obj
-            save_with_gc(
-                state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
-            )
-            logger.info(
-                "Finish saving the checkpoint in the background process in %.2f seconds.",
-                time.monotonic() - begin,
-            )
-    finally:
-        logger.info("Destroying the process group.")
-        dist.destroy_process_group()
-
-
 @torch.no_grad()
 def dcp_save(
     state_dict: dict[str, Any],
@@ -145,13 +106,13 @@ def dcp_save(
     hf_safetensors_format: bool,
     pg: Optional[dist.ProcessGroup] = None,
 ) -> Optional[Future]:
-    """Save the checkpoint for the current step.
-
-
+    """Save the checkpoint with dcp.
     Args:
         state_dict (dict): The state dict to save.
         checkpoint_id (str): The checkpoint id to save.
         is_async (bool): Whether the checkpoint is async.
+        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
+        pg (Optional[dist.ProcessGroup]): The process group to use.
     """
     if hf_safetensors_format:
         storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
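
For context on the save path touched by this docstring fix: dcp_save wraps torch.distributed.checkpoint (dcp). The sketch below shows roughly what the sync and async branches reduce to, assuming only the stock dcp.save / dcp.async_save entry points; the helper name and defaults are illustrative, not torchtitan's exact code, and the hf_safetensors_format branch above would additionally pass the HuggingFaceStorageWriter via storage_writer=.

# Minimal sketch, assuming stock torch.distributed.checkpoint entry points;
# not torchtitan's exact implementation of dcp_save.
from concurrent.futures import Future
from typing import Any, Optional

import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def dcp_save_sketch(
    state_dict: dict[str, Any],
    checkpoint_id: str,
    is_async: bool = False,
    pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
    if is_async:
        # async_save stages the state_dict and writes it in the background,
        # returning a Future the caller can wait on before the next save.
        return dcp.async_save(state_dict, checkpoint_id=checkpoint_id, process_group=pg)
    dcp.save(state_dict, checkpoint_id=checkpoint_id, process_group=pg)
    return None
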
@@ -173,12 +134,11 @@ def dcp_save(
 def dcp_load(
     state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
 ) -> None:
-    """Save the checkpoint for the current step.
-
-
+    """Load the checkpoint with dcp.
     Args:
         state_dict (dict): The state dict to load.
         checkpoint_id (str): The checkpoint id to load.
+        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
     """
     if hf_safetensors_format:
         storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
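
Similarly, dcp_load's default path reduces to a single dcp.load call; the tiny sketch below assumes the stock entry point and leaves out the HuggingFaceStorageReader branch shown above (which would go in via storage_reader=).

# Minimal sketch of the load path, assuming stock dcp; not torchtitan's exact code.
from typing import Any

import torch.distributed.checkpoint as dcp


def dcp_load_sketch(state_dict: dict[str, Any], checkpoint_id: str) -> None:
    # dcp.load restores the saved tensors in place into the provided state_dict.
    dcp.load(state_dict, checkpoint_id=checkpoint_id)
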
@@ -200,6 +160,45 @@ def save_with_gc(
     GarbageCollection.collect("GC collection invoked by checkpointer.")
 
 
+def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
+    """Process to save the checkpoint in the background.
+
+    This is only used when async_checkpoint_with_pinned_memory is enabled.
+
+    Args:
+        recv (mp.Queue): The queue to receive the state_dict and Terminate signal.
+        send (mp.Queue): The queue to send the SaveDone signal.
+    """
+    init_logger()
+    os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
+    os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    dist.init_process_group()
+    try:
+        while True:
+            logger.debug("Checkpoint background process is done.")
+            send.put(SaveDone())
+            logger.debug("Wait for the new state_dict.")
+            obj = recv.get()
+            logger.debug("Received the new state_dict.")
+            if isinstance(obj, Terminate):
+                logger.info("Terminating the checkpoint background process.")
+                return
+            assert isinstance(obj, tuple)
+            begin = time.monotonic()
+            state, checkpoint_id = obj
+            save_with_gc(
+                state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
+            )
+            logger.info(
+                "Finish saving the checkpoint in the background process in %.2f seconds.",
+                time.monotonic() - begin,
+            )
+    finally:
+        logger.info("Destroying the process group.")
+        dist.destroy_process_group()
+
+
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.
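
The relocated checkpoint_mp above is the worker half of a two-queue handshake: it announces readiness by putting SaveDone on its send queue, then blocks on its recv queue for either a (state_dict, checkpoint_id) tuple or a Terminate sentinel. A hedged sketch of the driver side follows; the function and variable names, the spawn context, and the daemon flag are assumptions for illustration, not torchtitan's actual checkpoint manager.

# Illustrative driver-side wiring for checkpoint_mp; everything except
# checkpoint_mp itself is an assumption, not torchtitan's actual manager code.
import torch.multiprocessing as mp

from torchtitan.components.checkpoint import checkpoint_mp


def start_background_checkpointer():
    ctx = mp.get_context("spawn")
    to_worker = ctx.Queue()    # main -> worker: (state_dict, checkpoint_id) or Terminate
    from_worker = ctx.Queue()  # worker -> main: SaveDone readiness signals
    # checkpoint_mp(recv, send): it receives work on its first queue argument
    # and acknowledges on its second.
    proc = ctx.Process(target=checkpoint_mp, args=(to_worker, from_worker), daemon=True)
    proc.start()
    return proc, to_worker, from_worker

On shutdown the trainer would put a Terminate sentinel on to_worker and join() the process, which lines up with the function's finally block destroying its private process group.
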
@@ -466,7 +465,11 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            self.save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(
+                self.states,
+                checkpoint_id=checkpoint_id,
+                hf_safetensors_format=self.enable_hf_safetensors_format,
+            )
         self._purge_stale_checkpoints()
 
         logger.info(
