
Commit 3ccd12c

[WIP] Integrate autoparallel into torchtitan
TODO
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop

Hack an init_fn for llama3 and observe loss decreasing with autoparallel:

```
[rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step:  1  loss:  8.1880  memory:  4.88GiB(6.16%)  tps: 28
[rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step:  2  loss:  8.1610  memory:  4.90GiB(6.20%)  tps: 13,785
[rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step:  3  loss:  8.0871  memory:  4.90GiB(6.20%)  tps: 14,006
[rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step:  4  loss:  7.9516  memory:  4.90GiB(6.20%)  tps: 13,770
[rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step:  5  loss:  7.8552  memory:  4.90GiB(6.20%)  tps: 13,959
[rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step:  6  loss:  7.7732  memory:  4.90GiB(6.20%)  tps: 13,859
[rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step:  7  loss:  7.6987  memory:  4.90GiB(6.20%)  tps: 13,664
[rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step:  8  loss:  7.6779  memory:  4.90GiB(6.20%)  tps: 13,985
[rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step:  9  loss:  7.6043  memory:  4.90GiB(6.20%)  tps: 13,962
[rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10  loss:  7.5778  memory:  4.90GiB(6.20%)  tps: 13,891
```

Adopt new autoparallel API with meta-init model

Allows reverting a lot of the hacks in the original integration that were caused by not creating a model object in train.py, since a model_fn builder was passed to autop instead.

Fixes to align with latest autoparallel.

Add inductor config knobs for comms optimizations to torchtitan.

Make inductor always run compile passes

Basically, this is an annoying workaround for debugging iteratively:
1. you run the model, it compiles, but something weird happens;
2. you enable some logging or tlparse and rerun, but inductor decides not to run your pass anymore because its results are cached.

Since (2) has confused me horribly on more than one occasion, I just disable caching for now.

Drop hacky llama3_init_fn and use autop init_weights feature

Relying on pytorch-labs/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model.

Verified this works with
`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`:

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step:  1  loss:  8.1848  memory:  1.09GiB(1.14%)  tps: 77  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step:  2  loss:  8.1619  memory:  1.15GiB(1.21%)  tps: 48,138  tflops: 3.46  mfu: 0.35%
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step:  3  loss:  8.1140  memory:  1.15GiB(1.21%)  tps: 88,440  tflops: 6.36  mfu: 0.64%
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step:  4  loss:  8.0099  memory:  1.15GiB(1.21%)  tps: 82,626  tflops: 5.94  mfu: 0.60%
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step:  5  loss:  7.8928  memory:  1.15GiB(1.21%)  tps: 81,594  tflops: 5.87  mfu: 0.59%
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step:  6  loss:  7.7758  memory:  1.15GiB(1.21%)  tps: 79,607  tflops: 5.72  mfu: 0.58%
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step:  7  loss:  7.6221  memory:  1.15GiB(1.21%)  tps: 81,448  tflops: 5.86  mfu: 0.59%
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step:  8  loss:  7.5578  memory:  1.15GiB(1.21%)  tps: 79,732  tflops: 5.73  mfu: 0.58%
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step:  9  loss:  7.3851  memory:  1.15GiB(1.21%)  tps: 85,655  tflops: 6.16  mfu: 0.62%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10  loss:  7.3361  memory:  1.15GiB(1.21%)  tps: 81,855  tflops: 5.89  mfu: 0.60%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```

fix lint
1 parent 3ca7041 commit 3ccd12c

File tree: 6 files changed (+159, -6 lines)


torchtitan/config_manager.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -664,6 +664,28 @@ class Experimental:
     needs to ensure that the path can be imported.
     """
 
+    reorder_for_compute_comm_overlap: bool = False
+    """
+    Whether to enable the inductor comm reordering passes.
+    """
+
+    reorder_for_compute_comm_overlap_passes: list[str] = field(
+        default_factory=lambda: [
+            "sink_waits",
+            "reorder_communication_preserving_peak_memory",
+        ]
+    )
+    """
+    Sequence of reordering passes (names of functions inside _inductor.comms) to call,
+    if reorder_for_compute_comm_overlap is enabled.
+    """
+
+    reorder_prefetch_limit: int | None = None
+    """
+    How many ops to allow moving any individual collective, if the
+    'reorder_communication_preserving_peak_memory' pass is enabled. The default of None means unlimited.
+    """
+
 
 @dataclass
 class Validation:
```
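These fields only take effect once they are copied into inductor's configuration, which the `torchtitan/train.py` hunk later in this commit does at trainer init. A minimal standalone sketch of the same wiring (the knob names are the ones added above; the values are illustrative):

```python
import torch

# Mirror of the train.py change in this commit: forward the experimental
# torchtitan knobs into torch._inductor.config before compiling the model.
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    "sink_waits",
    "reorder_communication_preserving_peak_memory",
]
# None lets 'reorder_communication_preserving_peak_memory' move a collective
# past an unlimited number of ops.
torch._inductor.config.reorder_prefetch_limit = None
```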

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -4,5 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torchtitan.experiments.auto_parallel  # noqa: F401
 import torchtitan.experiments.llama4  # noqa: F401
 import torchtitan.experiments.simple_fsdp  # noqa: F401
```
Lines changed: 7 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,7 @@
+## Auto Parallel
+
+requires installing git@github.com:pytorch-labs/autoparallel.git
+
+`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`
+
+(or llama3-8b.toml)
```
Lines changed: 31 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from .parallelize_llama import parallelize_llama
+
+register_train_spec(
+    TrainSpec(
+        name="llama3_auto_parallel",
+        cls=Transformer,
+        config=llama3_configs,
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
```
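For context, registering the spec above is what makes `--model.name llama3_auto_parallel` resolvable when the trainer starts. A minimal sketch of how the registered spec would be consumed, assuming `get_train_spec` is the lookup counterpart of `register_train_spec` in `torchtitan.protocols.train_spec` and that `"debugmodel"` is a valid llama3 flavor key (both are assumptions, not part of this diff):

```python
import torch

from torchtitan.protocols.train_spec import get_train_spec  # assumed lookup helper

spec = get_train_spec("llama3_auto_parallel")  # name selected via --model.name
model_args = spec.config["debugmodel"]         # illustrative flavor from llama3_configs
with torch.device("meta"):                     # meta-init, as train.py does
    model = spec.cls(model_args)
# The trainer then hands this meta model to spec.parallelize_fn
# (parallelize_llama below), which runs AutoParallel and returns the
# sharded, optionally compiled module.
```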
Lines changed: 77 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import torch
+
+from autoparallel.api import AutoParallel
+
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
+
+from torchtitan.tools.logging import logger
+
+
+def parallelize_llama(
+    model,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    """
+    Apply tensor parallelism, activation checkpointing, torch.compile, and data
+    parallelism to the model.
+
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
+    """
+
+    def input_fn():
+        global_batch_size = job_config.training.global_batch_size
+        if global_batch_size < 0:
+            # This global batch size results in 1 gradient accumulation
+            # step.
+            dp_degree = world_mesh["dp"].size()
+            global_batch_size = job_config.training.local_batch_size * dp_degree
+        return torch.rand(
+            (global_batch_size, job_config.training.seq_len), device="cuda"
+        )
+
+    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
+    assert (
+        len(world_mesh.shape) == 2
+    ), "Only support 2D mesh (DP, TP) for now - OK if one has size=1"
+    assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet"
+    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
+    assert parallel_dims.cp_enabled is False, "CP not supported yet"
+    assert parallel_dims.pp_enabled is False, "PP not supported yet"
+
+    # bail out
+    # model = model_fn()
+    # return model
+
+    autop = AutoParallel(model, input_fn, world_mesh)
+    autop.add_parameter_memory_constraint(low=None, high=None)
+
+    x_sharding = (Shard(0), Replicate())
+
+    autop.add_input_constraints([x_sharding])
+    autop.add_output_constraints([x_sharding])
+    t0 = time.time()
+    sharding_placement = autop.optimize_placement()
+    t1 = time.time()
+    logger.info(f"AutoParallel took {t1 - t0} seconds")
+    parallel_mod = autop.apply_placement(sharding_placement)
+
+    if job_config.training.compile:
+        torch._inductor.config.reorder_for_peak_memory = False
+        parallel_mod = torch.compile(parallel_mod, fullgraph=True)
+
+    return parallel_mod
```
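The `x_sharding = (Shard(0), Replicate())` constraint above pins the model input to be batch-sharded along the "dp" mesh axis and replicated along the "tp" axis. A minimal sketch, not part of this commit, of what that placement means for a DTensor on a 2-D mesh (the mesh sizes and tensor shapes are hypothetical; it needs a torchrun launch with 4 ranks):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard

# Hypothetical 2x2 mesh: 2 data-parallel groups x 2 tensor-parallel ranks.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

x = torch.randn(8, 2048)  # (global_batch_size, seq_len)

# Shard(0) on "dp": each dp group owns 8 / 2 = 4 rows of the batch.
# Replicate() on "tp": those rows are identical copies across tp ranks.
dx = distribute_tensor(x, mesh, placements=[Shard(0), Replicate()])
print(dx.to_local().shape)  # torch.Size([4, 2048]) on every rank
```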

torchtitan/train.py

Lines changed: 21 additions & 6 deletions

```diff
@@ -12,6 +12,7 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor import DTensor
 
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
@@ -116,6 +117,21 @@ def __init__(self, job_config: JobConfig):
             gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
         )
 
+        # TODO(whc)
+        # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering
+        torch._inductor.config.force_disable_caches = True
+
+        # allow configuring inductor comms optimizations from the torchtitan commandline
+        torch._inductor.config.reorder_for_compute_comm_overlap = (
+            job_config.experimental.reorder_for_compute_comm_overlap
+        )
+        torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
+            job_config.experimental.reorder_for_compute_comm_overlap_passes
+        )
+        torch._inductor.config.reorder_prefetch_limit = (
+            job_config.experimental.reorder_prefetch_limit
+        )
+
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
         dist_utils.set_determinism(
@@ -141,20 +157,19 @@ def __init__(self, job_config: JobConfig):
         )
 
         # build model (using meta init)
-        model_cls = self.train_spec.cls
         model_args = self.train_spec.config[job_config.model.flavor]
+        model_cls = self.train_spec.cls
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)
-
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
+
         with torch.device("meta"):
             model = model_cls(model_args)
-
-        # Build the collection of model converters. No-op if `model.converters` empty
-        model_converters = build_model_converters(job_config, parallel_dims)
-        model_converters.convert(model)
+            # Build the collection of model converters. No-op if `model.converters` empty
+            model_converters = build_model_converters(job_config, parallel_dims)
+            model_converters.convert(model)
 
         # metrics logging
         build_metrics_processor_fn = (
```