Autoparallel support for DP-only, DP+TP, or TP-only #1349
base: autoparallel
Changes from all commits: 5e93263, 42d5da6, c8fb6b5, 0ec2b2f, 6861adb
@@ -0,0 +1,7 @@
## Auto Parallel

requires installing git@github.com:pytorch-labs/autoparallel.git

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`

(or llama3-8b.toml)
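The same launcher should cover the DP-only and DP+TP cases named in the PR title by changing the degree flags. The data-parallel flag below follows torchtitan's existing `--parallelism.*` options and is shown as an illustrative sketch, not as part of this diff:

- DP-only (8 GPUs): `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.data_parallel_shard_degree 8`
- DP+TP (4 x 2): `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.data_parallel_shard_degree 4 --parallelism.tensor_parallel_degree 2`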
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
from .parallelize_llama import parallelize_llama

register_train_spec(
    TrainSpec(
        name="llama3_auto_parallel",
        cls=Transformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
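For context, the spec registered above is selected at runtime via `--model.name llama3_auto_parallel`. A minimal sketch of how the trainer side resolves it, assuming torchtitan exposes a `get_train_spec` lookup and that `llama3_configs` has a `"debugmodel"` entry (both assumptions, not shown in this diff; the experiment's import path is hypothetical):

```python
# Sketch only: resolving the spec registered above by its name.
import torch

import torchtitan.experiments.auto_parallel  # noqa: F401  hypothetical path; importing runs register_train_spec()
from torchtitan.protocols.train_spec import get_train_spec  # assumed lookup helper

spec = get_train_spec("llama3_auto_parallel")  # matches --model.name on the command line
model_args = spec.config["debugmodel"]         # or "8B" when using llama3-8b.toml
with torch.device("meta"):
    model = spec.cls(model_args)               # meta-device init, as parallelize_llama prefers

# The trainer then hands the model to the registered parallelize_fn,
# i.e. parallelize_llama from this PR:
# parallel_mod = spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
```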
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch

from autoparallel.api import AutoParallel

from torch.distributed import DeviceMesh
from torch.distributed.tensor.placement_types import Replicate, Shard

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


def parallelize_llama(
    model,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
            global_batch_size = job_config.training.local_batch_size * dp_degree
        return torch.rand(
            (global_batch_size, job_config.training.seq_len), device="cuda"
        )

    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
    assert parallel_dims.cp_enabled is False, "CP not supported yet"
    assert parallel_dims.pp_enabled is False, "PP not supported yet"

    # bail out
    # model = model_fn()
    # return model

    autop = AutoParallel(model, input_fn, world_mesh)
    autop.add_parameter_memory_constraint(low=None, high=None)
Review comment (on `add_parameter_memory_constraint`): I believe you might also need to tweak the memory constraint if you want to get […]. By default (i.e., […])

Reply: yea, i will try that next.
    possible_input_shardings = {
        # maps relative to mesh dim names used in torchtitan
        "dp_replicate": Shard(0),
        "dp_shard": Shard(0),
        "tp": Replicate(),
    }
    assert all(name in possible_input_shardings for name in world_mesh.mesh_dim_names), (
        f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
    )
    x_sharding = tuple(possible_input_shardings[name] for name in world_mesh.mesh_dim_names)
    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([x_sharding])
    t0 = time.time()
    sharding_placement = autop.optimize_placement()
    t1 = time.time()
    logger.info(f"AutoParallel took {t1 - t0} seconds")
    parallel_mod = autop.apply_placement(sharding_placement)

    if job_config.training.compile:
        torch._inductor.config.reorder_for_peak_memory = False
        parallel_mod = torch.compile(parallel_mod, fullgraph=True)

    return parallel_mod
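The input/output constraints above require every mesh dim name to be one of `dp_replicate`, `dp_shard`, or `tp`. In torchtitan the mesh is built from `ParallelDims`, but a minimal standalone sketch of a compatible mesh, using PyTorch's `init_device_mesh` with illustrative sizes for 8 GPUs, looks like this:

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs arranged as 4-way sharded data parallel x 2-way tensor parallel.
# The dim names must be keys of possible_input_shardings in parallelize_llama.
world_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(4, 2),
    mesh_dim_names=("dp_shard", "tp"),
)

# DP-only or TP-only variants use a 1-D mesh instead:
# init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))
# init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
```

With such a mesh, the `x_sharding` tuple becomes `(Shard(0), Replicate())`: inputs are split along the batch dim across `dp_shard` and replicated across `tp`, matching the DP+TP case in the PR title.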
Review comment: Making `dp_replicate` and `dp_shard` configurable seems odd, as 2 * 4 vs. 4 * 2 makes no difference. Should we just stick to one and assert when the other is set larger than 1?

Reply: what do you mean they make no difference? i want to use them to control the memory constraints so autop behaves more like ddp vs fsdp vs hsdp.

Reply: I was thinking to have more explicit arguments to tune memory constraints. But I understand now.

Reply: yea, i think i will add explicit knobs to get the most out of autoparallel. i was first trying to slot autoparallel into the existing torchtitan as 'seamlessly' as possible. one thing is that if users ignore the --dp_replicate_degree and similar cmdline args, and use other args for influencing autoparallel, we have the problem of how to decide which mesh dims to create. I will have to think about what to do for that.
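One hypothetical way to answer the "which mesh dims to create" question from the existing degree arguments, before any autoparallel-specific knobs are added (the `tp` attribute name on `ParallelDims` is an assumption; `dp_replicate` and `dp_shard` appear in the diff above):

```python
def autoparallel_mesh_dim_names(parallel_dims) -> tuple[str, ...]:
    """Hypothetical helper: keep only the non-trivial dims so AutoParallel
    sees exactly the mesh the --parallelism.* degrees describe."""
    names = []
    if parallel_dims.dp_replicate > 1:
        names.append("dp_replicate")
    if parallel_dims.dp_shard > 1:
        names.append("dp_shard")
    if getattr(parallel_dims, "tp", 1) > 1:  # attribute name assumed
        names.append("tp")
    return tuple(names)
```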