Commit 6b11290

Only calls destroy_process_group if the trainer exits successfully (#1342)
If we call destroy_process_group when some trainers have hit exceptions while others are still running collectives, the cleanup itself deadlocks.

stacktrace:

```
Thread 0x7F81445A8440 (active): "MainThread"
    destroy_process_group (torch/distributed/distributed_c10d.py:2184)
    <module> (torchtitan/train.py:554)
    _run_code (runpy.py:86)
    _run_module_as_main (runpy.py:196)
Thread 0x7F7E83CFF640 (active): "Thread-1 (_read_thread)"
    _recv_msg (torch/_inductor/compile_worker/subproc_pool.py:61)
    _read_thread (torch/_inductor/compile_worker/subproc_pool.py:195)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 0x7F7D9CFF9640 (idle): "Thread-2"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
```
1 parent 5d4cc9a · commit 6b11290
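For context, below is a minimal runnable sketch of the control flow this commit moves to. The stub `Trainer` and the module-level `logger` are stand-ins for torchtitan's real objects, which are not part of this hunk. The key point is that `destroy_process_group()` now lives in the `else:` branch: a rank that raised does only local cleanup and re-raises, instead of blocking in collective teardown while its peers may still be mid-collective.

```python
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


class Trainer:
    """Minimal stub standing in for torchtitan's Trainer; not part of this commit."""

    def train(self) -> None: ...

    def close(self) -> None: ...


trainer = None
try:
    trainer = Trainer()
    trainer.train()
except Exception:
    # Local cleanup only, then re-raise: a failing rank must not enter
    # destroy_process_group while other ranks may still be inside collectives.
    if trainer:
        trainer.close()
    raise
else:
    # Reached only when training finished cleanly on this rank, so the
    # process-group teardown cannot hang waiting on a crashed peer.
    trainer.close()
    # Guard kept only so this standalone sketch runs without distributed init;
    # the commit drops it, presumably because the trainer always initializes
    # the process group before reaching this point.
    if dist.is_initialized():
        dist.destroy_process_group()
    logger.info("Process group destroyed.")
```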

2 files changed: +12 -10 lines changed


torchtitan/experiments/flux/train.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -225,10 +225,11 @@ def train_step(
             logger.info("Created seed checkpoint")
         else:
             trainer.train()
-    finally:
+    except Exception:
         if trainer:
             trainer.close()
-
-        if torch.distributed.is_initialized():
-            torch.distributed.destroy_process_group()
-            logger.info("Process group destroyed.")
+        raise
+    else:
+        trainer.close()
+        torch.distributed.destroy_process_group()
+        logger.info("Process group destroyed.")
```

torchtitan/train.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -556,10 +556,11 @@ def close(self) -> None:
             logger.info("Created seed checkpoint")
         else:
             trainer.train()
-    finally:
+    except Exception:
         if trainer:
             trainer.close()
-
-        if torch.distributed.is_initialized():
-            torch.distributed.destroy_process_group()
-            logger.info("Process group destroyed.")
+        raise
+    else:
+        trainer.close()
+        torch.distributed.destroy_process_group()
+        logger.info("Process group destroyed.")
```
