Skip to content

Commit 40340a8

Browse files
committed
Add inductor config knobs for comms optimizations to torchtitan
1 parent 0ec2b2f commit 40340a8

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

torchtitan/config_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,28 @@ class Experimental:
576576
needs to ensure that the path can be imported.
577577
"""
578578

579+
reorder_for_compute_comm_overlap: bool = False
580+
"""
581+
Whether to enable inductor comm reordering passes
582+
"""
583+
584+
reorder_for_compute_comm_overlap_passes: list[str] = field(
585+
default_factory=lambda: [
586+
"sink_waits",
587+
"reorder_communication_preserving_peak_memory",
588+
]
589+
)
590+
"""
591+
Sequence of reordering passes (names of functions inside _inductor.comms) to call,
592+
if reorder_for_compute_comm_overlap is enabled.
593+
"""
594+
595+
reorder_prefetch_limit: int | None = None
596+
"""
597+
How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory'
598+
pass is enabled. default of None means unlimited
599+
"""
600+
579601

580602
@dataclass
581603
class JobConfig:

torchtitan/train.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,17 @@ def __init__(self, job_config: JobConfig):
113113
gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
114114
)
115115

116+
# allow configuring inductor comms optimizations from torchtitan commandline
117+
torch._inductor.config.reorder_for_compute_comm_overlap = (
118+
job_config.experimental.reorder_for_compute_comm_overlap
119+
)
120+
torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
121+
job_config.experimental.reorder_for_compute_comm_overlap_passes
122+
)
123+
torch._inductor.config.reorder_prefetch_limit = (
124+
job_config.experimental.reorder_prefetch_limit
125+
)
126+
116127
# Set random seed, and maybe enable deterministic mode
117128
# (mainly for debugging, expect perf loss).
118129
dist_utils.set_determinism(

0 commit comments

Comments
 (0)