
Autoparallel support for DP-only, DP+TP, or TP-only #1349

Status: Open. Wants to merge 5 commits into base branch `autoparallel`.

1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -4,5 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torchtitan.experiments.auto_parallel # noqa: F401
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.simple_fsdp # noqa: F401
7 changes: 7 additions & 0 deletions torchtitan/experiments/auto_parallel/README.md
@@ -0,0 +1,7 @@
## Auto Parallel

Requires installing `git@github.com:pytorch-labs/autoparallel.git`.

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`

(or use `llama3-8b.toml` in place of `debug_model.toml`)
31 changes: 31 additions & 0 deletions torchtitan/experiments/auto_parallel/__init__.py
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
from .parallelize_llama import parallelize_llama

register_train_spec(
    TrainSpec(
        name="llama3_auto_parallel",
        cls=Transformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
81 changes: 81 additions & 0 deletions torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch

from autoparallel.api import AutoParallel

from torch.distributed import DeviceMesh
from torch.distributed.tensor.placement_types import Replicate, Shard

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


def parallelize_llama(
    model,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply data and tensor parallelism to the model via AutoParallel, and
    optionally compile it with torch.compile.

    NOTE: The passed-in model should preferably be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
Review thread (on the dp_degree computation above):

Contributor: Making dp_replicate and dp_shard configurable seems weird, as 2 * 4 or 4 * 2 makes no difference. Should we just stick to one and assert when the other is assigned to be larger than 1?

Contributor Author: What do you mean they make no difference? I want to use them to control the memory constraints so autop behaves more like DDP vs. FSDP vs. HSDP.

Contributor: I was thinking of having more explicit arguments to tune memory constraints. But I understand now.

Contributor Author: Yea, I think I will add explicit knobs to get the most out of autoparallel. I was first trying to slot autoparallel into the existing torchtitan as 'seamlessly' as possible.

One thing is that if users ignore the --dp_replicate_degree and similar cmdline args, and use other args for influencing autoparallel, we have the problem of how to decide which mesh dims to create. I will have to think about what to do for that.

            global_batch_size = job_config.training.local_batch_size * dp_degree
        return torch.rand(
            (global_batch_size, job_config.training.seq_len), device="cuda"
        )
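
The review thread above raises the question of which mesh dims to create when users steer AutoParallel through knobs other than the `--*_degree` flags. A minimal sketch of one possible mapping from configured degrees to named mesh dims (the `build_world_mesh` helper and the rule of dropping size-1 dims are illustrative assumptions, not torchtitan's actual mesh construction):

```python
from torch.distributed.device_mesh import init_device_mesh


def build_world_mesh(dp_replicate: int, dp_shard: int, tp: int):
    # Dim names mirror the ones this file expects ("dp_replicate", "dp_shard", "tp").
    dims = {"dp_replicate": dp_replicate, "dp_shard": dp_shard, "tp": tp}
    # Keep only the dims that were actually requested (degree > 1); this sketch
    # assumes at least one degree is > 1.
    active = {name: deg for name, deg in dims.items() if deg > 1}
    return init_device_mesh(
        "cuda",
        mesh_shape=tuple(active.values()),
        mesh_dim_names=tuple(active.keys()),
    )
```

For example, `dp_shard=4, tp=2` on 8 GPUs would yield a 2-D mesh with dims `("dp_shard", "tp")`, matching the names that `possible_input_shardings` below supports.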

    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
    assert parallel_dims.cp_enabled is False, "CP not supported yet"
    assert parallel_dims.pp_enabled is False, "PP not supported yet"

    # bail out
    # model = model_fn()
    # return model

    autop = AutoParallel(model, input_fn, world_mesh)
    autop.add_parameter_memory_constraint(low=None, high=None)
Review thread (on the add_parameter_memory_constraint call above):

@fmassa (Member, Jun 27, 2025): I believe you might also need to tweak the memory constraint if you want to get DP behavior. For example, you can get DDP with low=1.0, high=1.0.

By default (i.e., low=None, high=None), we get low=0, high=1 / mesh.size(), which says "let's shard the sum of all parameters so that each GPU has 1 / mesh.size()".

Contributor Author: Yea, I will try that next.
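
Following the explanation above, a hedged sketch of how the constraint could be used to toggle between DDP-like and the default sharded behavior (the `ddp_like` flag is a hypothetical config knob, not part of this PR):

```python
# Illustrative only: low/high bound the fraction of the total parameter bytes
# each rank may keep, per the comment above.
if ddp_like:
    # fully replicate parameters on every rank (DDP-like)
    autop.add_parameter_memory_constraint(low=1.0, high=1.0)
else:
    # defaults resolve to low=0, high=1 / mesh.size(): shard parameters so
    # each rank holds roughly 1 / world_size of them (FSDP-like)
    autop.add_parameter_memory_constraint(low=None, high=None)
```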


    possible_input_shardings = {
        # keyed by the mesh dim names used in torchtitan
        "dp_replicate": Shard(0),
        "dp_shard": Shard(0),
        "tp": Replicate(),
    }
    assert all(name in possible_input_shardings for name in world_mesh.mesh_dim_names), (
        f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
    )
    x_sharding = tuple(possible_input_shardings[name] for name in world_mesh.mesh_dim_names)
    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([x_sharding])
    t0 = time.time()
    sharding_placement = autop.optimize_placement()
    t1 = time.time()
    logger.info(f"AutoParallel took {t1 - t0} seconds")
    parallel_mod = autop.apply_placement(sharding_placement)

    if job_config.training.compile:
        torch._inductor.config.reorder_for_peak_memory = False
        parallel_mod = torch.compile(parallel_mod, fullgraph=True)

    return parallel_mod
90 changes: 82 additions & 8 deletions torchtitan/train.py
@@ -12,7 +12,6 @@

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
@@ -24,6 +23,7 @@
)
from torchtitan.config_manager import ConfigManager, JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torch.distributed.tensor import DTensor
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
@@ -138,20 +138,92 @@ def __init__(self, job_config: JobConfig):
        )

        # build model (using meta init)
        model_args = self.train_spec.config[job_config.model.flavor]
        model_cls = self.train_spec.cls
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        logger.info(
            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
        )


        def llama3_autoparallel_init_fn(model):
            # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
            # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
            # the new FQN structures in autoparallel.
            # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
            def param(name):
                return model.get_parameter(f"params.{name}")

            from torchtitan.models.llama3.model import precompute_freqs_cis

            model.buffers_.get_buffer("freqs_cis").copy_(
                DTensor.from_local(
                    precompute_freqs_cis(
                        model_args.dim // model_args.n_heads,
                        model_args.max_seq_len,
                        model_args.rope_theta,
                    ),
                    device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh,
                )
            )

            torch.nn.init.normal_(param("tok_embeddings/weight"))

            def init_layer(i):
                for norm in ("attention_norm", "ffn_norm"):
                    torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))

                if model_args.depth_init:
                    weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
                else:
                    weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5

                for linear in ("wq", "wk", "wv"):
                    torch.nn.init.trunc_normal_(
                        param(f"layers/{i}/attention/{linear}/weight"),
                        mean=0.0,
                        std=0.02,
                    )
                torch.nn.init.trunc_normal_(
                    param(f"layers/{i}/attention/wo/weight"),
                    mean=0.0,
                    std=weight_init_std,
                )

                torch.nn.init.trunc_normal_(
                    param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
                )
                for linear in ("w2", "w3"):
                    torch.nn.init.trunc_normal_(
                        param(f"layers/{i}/feed_forward/{linear}/weight"),
                        mean=0.0,
                        std=weight_init_std,
                    )

            for i in range(model_args.n_layers):
                init_layer(i)

            if param("norm/weight") is not None:
                torch.nn.init.ones_(param("norm/weight"))

            final_out_std = model_args.dim**-0.5
            cutoff_factor = 3

            if param("output/weight") is not None:
                torch.nn.init.trunc_normal_(
                    param("output/weight"),
                    mean=0.0,
                    std=final_out_std,
                    a=-cutoff_factor * final_out_std,
                    b=cutoff_factor * final_out_std,
                )

        with torch.device("meta"):
            model = model_cls.from_model_args(model_args)

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # metrics logging
        build_metrics_processor_fn = (
@@ -256,7 +328,9 @@ def __init__(self, job_config: JobConfig):

        model.to_empty(device=init_device)
        with torch.no_grad():
            # TODO(whc) make model.init_weights work with autoparallel
            llama3_autoparallel_init_fn(model)
            # model.init_weights(buffer_device=buffer_device)
        model.train()

        self.model_parts = [model]