Commit 5e93263
[WIP] Integrate autoparallel into torchtitan
TODO:
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop
1 parent 7c8301a
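
The first TODO item, converting model parameters into fake tensors, is not implemented in this commit. A rough sketch of what it could look like, assuming FakeTensorMode from torch._subclasses.fake_tensor is the conversion mechanism (the helper name below is hypothetical):

# Rough sketch (not this commit's code) of the "convert model params into fake
# tensors" TODO: swap each parameter for a FakeTensor of the same shape/dtype
# so planning can happen without touching real storage. fakeify_parameters is
# a hypothetical helper; FakeTensorMode is the assumed mechanism.
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

def fakeify_parameters(model: nn.Module) -> nn.Module:
    fake_mode = FakeTensorMode()
    for name, param in list(model.named_parameters()):
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        fake = fake_mode.from_tensor(param.detach())
        setattr(parent, attr, nn.Parameter(fake, requires_grad=param.requires_grad))
    return model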

6 files changed, +146 −28 lines

torchtitan/components/metrics.py

Lines changed: 7 additions & 7 deletions
@@ -354,7 +354,7 @@ def log(
         global_max_loss: float,
         extra_metrics: dict[str, Any] | None = None,
     ):
-        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
+        # assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
 
         time_delta = time.perf_counter() - self.time_last_log
 
@@ -365,8 +365,8 @@ def log(
         # model FLOPS utilization
         # For its definition and calculation, please refer to the PaLM paper:
         # https://arxiv.org/abs/2204.02311
-        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
-        tflops = self.num_flops_per_token * tps / 1e12
+        # mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
+        # tflops = self.num_flops_per_token * tps / 1e12
 
         time_end_to_end = time_delta / self.job_config.metrics.log_freq
         time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times)
@@ -378,8 +378,8 @@ def log(
             "loss_metrics/global_avg_loss": global_avg_loss,
             "loss_metrics/global_max_loss": global_max_loss,
             "throughput(tps)": tps,
-            "tflops": tflops,
-            "mfu(%)": mfu,
+            # "tflops": tflops,
+            # "mfu(%)": mfu,
             "time_metrics/end_to_end(s)": time_end_to_end,
             "time_metrics/data_loading(s)": time_data_loading,
             "time_metrics/data_loading(%)": time_data_loading_pct,
@@ -403,8 +403,8 @@ def log(
             f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
             f"({device_mem_stats.max_reserved_pct:.2f}%) "
             f"{color.blue}tps: {round(tps):,} "
-            f"{color.cyan}tflops: {tflops:,.2f} "
-            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
+            # f"{color.cyan}tflops: {tflops:,.2f} "
+            # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
         )
 
         self.ntokens_since_last_log = 0

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,5 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torchtitan.experiments.auto_parallel  # noqa: F401
 import torchtitan.experiments.llama4  # noqa: F401
 import torchtitan.experiments.simple_fsdp  # noqa: F401

torchtitan/experiments/auto_parallel/README.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+## Auto Parallel
+
+Requires installing git@github.com:pytorch-labs/autoparallel.git
+
+`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`
+
+(or llama3-8b.toml)
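
The README does not spell out an install method; one plausible option, assuming the repository supports a standard pip Git install (not stated in this commit), is `pip install git+https://github.com/pytorch-labs/autoparallel.git`.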

torchtitan/experiments/auto_parallel/__init__.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from .parallelize_llama import parallelize_llama
+
+register_train_spec(
+    TrainSpec(
+        name="llama3_auto_parallel",
+        cls=Transformer,
+        config=llama3_configs,
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_tiktoken_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
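
The registered name above is what the README's --model.name llama3_auto_parallel flag selects. A minimal lookup sketch, assuming get_train_spec is the accessor paired with register_train_spec in torchtitan.protocols.train_spec (not shown in this diff):

# Sketch only: resolve the spec registered above by its name.
# Assumes torchtitan.protocols.train_spec exposes get_train_spec alongside
# register_train_spec; the name matches the TrainSpec in this commit.
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3_auto_parallel")
print(spec.parallelize_fn)  # the parallelize_llama defined in this package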

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import torch
+
+from autoparallel.api import AutoParallel
+
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
+
+from torchtitan.tools.logging import logger
+
+
+def parallelize_llama(
+    model_fn,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    """
+    Apply tensor parallelism, activation checkpointing, torch.compile, and data
+    parallelism to the model.
+
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
+    """
+    # model = model.to_empty(device="cuda")
+
+    # TODO: make auto-p work with already created model object?
+
+    def input_fn():
+        global_batch_size = job_config.training.global_batch_size
+        if global_batch_size < 0:
+            # This global batch size results in 1 gradient accumulation
+            # step.
+            dp_degree = world_mesh["dp"].size()
+            global_batch_size = job_config.training.local_batch_size * dp_degree
+        return torch.rand(
+            (global_batch_size, job_config.training.seq_len), device="cuda"
+        )
+
+    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
+    assert (
+        len(world_mesh.shape) == 2
+    ), "Only support 2D mesh (DP, TP) for now - OK if one has size=1"
+    assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet"
+    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
+    assert parallel_dims.cp_enabled is False, "CP not supported yet"
+    assert parallel_dims.pp_enabled is False, "PP not supported yet"
+
+    autop = AutoParallel(model_fn, input_fn, world_mesh)
+    autop.add_parameter_memory_constraint(low=None, high=None)
+
+    x_sharding = (Shard(0), Replicate())
+
+    autop.add_input_constraints([x_sharding])
+    autop.add_output_constraints([x_sharding])
+    t0 = time.time()
+    sharding_placement = autop.optimize_placement()
+    t1 = time.time()
+    logger.info(f"AutoParallel took {t1 - t0} seconds")
+    parallel_mod = autop.apply_placement(sharding_placement)
+
+    if job_config.training.compile:
+        torch._inductor.config.reorder_for_peak_memory = False
+        parallel_mod = torch.compile(parallel_mod, fullgraph=True)
+
+    return parallel_mod
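
The assertions above expect a two-dimensional device mesh with named "dp" and "tp" dimensions, where either dimension may have size 1. For reference, a minimal sketch (not part of this commit) of how such a mesh is typically constructed with torch.distributed:

# Sketch of the 2-D ("dp", "tp") mesh layout parallelize_llama expects.
# Assumes torch.distributed is already initialized and
# world_size == dp_degree * tp_degree.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def build_world_mesh(tp_degree: int):
    world_size = dist.get_world_size()
    dp_degree = world_size // tp_degree
    # input_fn() above reads world_mesh["dp"].size() from the "dp" dimension.
    return init_device_mesh(
        "cuda", (dp_degree, tp_degree), mesh_dim_names=("dp", "tp")
    )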

torchtitan/train.py

Lines changed: 24 additions & 21 deletions
@@ -24,7 +24,8 @@
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.protocols.model_converter import build_model_converters
+
+# from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.profiling import (
@@ -138,20 +139,22 @@ def __init__(self, job_config: JobConfig):
         )
 
         # build model (using meta init)
-        model_cls = self.train_spec.cls
         model_args = self.train_spec.config[job_config.model.flavor]
+        model_cls = self.train_spec.cls
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)
-
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
-        with torch.device("meta"):
-            model = model_cls.from_model_args(model_args)
 
+        def model_fn():
+            return model_cls.from_model_args(model_args).cuda()
+
+        # with torch.device("meta"):
+        #     model = model_fn()
         # Build the collection of model converters. No-op if `model.converters` empty
-        model_converters = build_model_converters(job_config, parallel_dims)
-        model_converters.convert(model)
+        # model_converters = build_model_converters(job_config, parallel_dims)
+        # model_converters.convert(model)
 
         # metrics logging
         build_metrics_processor_fn = (
@@ -165,15 +168,15 @@ def __init__(self, job_config: JobConfig):
         color = self.metrics_processor.color
 
         # calculate model size and flops per token
-        (
-            model_param_count,
-            self.metrics_processor.num_flops_per_token,
-        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
+        # (
+        #     model_param_count,
+        #     self.metrics_processor.num_flops_per_token,
+        # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
 
-        logger.info(
-            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
-            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
-        )
+        # logger.info(
+        #     f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
+        #     f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
+        # )
 
         # move sharded model to CPU/GPU and initialize weights via DTensor
         if job_config.checkpoint.create_seed_checkpoint:
@@ -251,7 +254,7 @@ def __init__(self, job_config: JobConfig):
         else:
             # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
             model = self.train_spec.parallelize_fn(
-                model, world_mesh, parallel_dims, job_config
+                model_fn, world_mesh, parallel_dims, job_config
            )
 
            model.to_empty(device=init_device)
@@ -288,11 +291,11 @@ def __init__(self, job_config: JobConfig):
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # where it issues a single all-reduce for all parameters at once for better performance
-        self.optimizers.register_step_post_hook(
-            lambda *args, **kwargs: model_converters.post_optimizer_hook(
-                self.model_parts
-            )
-        )
+        # self.optimizers.register_step_post_hook(
+        #     lambda *args, **kwargs: model_converters.post_optimizer_hook(
+        #         self.model_parts
+        #     )
+        # )
         self.metrics_processor.optimizers = self.optimizers
 
         # Initialize trainer states that will be saved in checkpoint.
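
The "figure out init fn" item from the commit message lands here: the parallelized module still has its storage materialized via to_empty, but nothing re-runs weight initialization on the AutoParallel-produced module. A rough sketch of one possible shape for that step, assuming the wrapped module still exposes an init_weights method like torchtitan's Transformer (this helper is hypothetical, not part of the commit):

# Hypothetical init step for the AutoParallel-produced module (not in this
# commit). Assumes the module keeps the original Transformer's init_weights().
import torch
import torch.nn as nn

def init_parallel_model(parallel_mod: nn.Module, init_device: str = "cuda") -> nn.Module:
    # Materialize storage for any meta-device parameters/buffers.
    parallel_mod.to_empty(device=init_device)
    with torch.no_grad():
        if hasattr(parallel_mod, "init_weights"):
            parallel_mod.init_weights()
    return parallel_mod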
