diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index bb4543ac1..eda0d5690 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -576,6 +576,28 @@ class Experimental: needs to ensure that the path can be imported. """ + reorder_for_compute_comm_overlap: bool = False + """ + Whether to enable inductor comm reordering passes + """ + + reorder_for_compute_comm_overlap_passes: list[str] = field( + default_factory=lambda: [ + "sink_waits", + "reorder_communication_preserving_peak_memory", + ] + ) + """ + Sequence of reordering passes (names of functions inside _inductor.comms) to call, + if reorder_for_compute_comm_overlap is enabled. + """ + + reorder_prefetch_limit: int | None = None + """ + How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory' + pass is enabled. default of None means unlimited + """ + @dataclass class JobConfig: diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 4c54bdc13..b7ff983e9 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,5 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import torchtitan.experiments.auto_parallel # noqa: F401 import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 000000000..ef66a5916 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,7 @@ +## Auto Parallel + +Requires installing the `autoparallel` package from git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +(or point `CONFIG_FILE` at llama3-8b.toml instead) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py new file mode 100644 index 000000000..7a2682a9f --- /dev/null +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from .parallelize_llama import parallelize_llama + +register_train_spec( + TrainSpec( + name="llama3_auto_parallel", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py new file mode 100644 index 000000000..88463b74c --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Use AutoParallel to choose and apply a sharding placement for the model, then + optionally torch.compile it; float8 conversion runs first if configured. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = world_mesh["dp"].size() + global_batch_size = job_config.training.local_batch_size * dp_degree + return torch.rand( + (global_batch_size, job_config.training.seq_len), device="cuda" + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert ( + len(world_mesh.shape) == 2 + ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" + assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # TODO: there are multiple float8 recipes, this just hardcodes one + enable_float8_linear = "float8" in job_config.model.converters + if enable_float8_linear: + import copy + from torchao.float8.float8_linear_utils import convert_to_float8_training + from torchao.float8.config import Float8LinearConfig + model = convert_to_float8_training(copy.deepcopy(model), config=Float8LinearConfig()) + + # bail out + # model 
= model_fn() + # return model + + autop = AutoParallel(model, input_fn, world_mesh) + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Replicate()) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if job_config.training.compile: + torch._inductor.config.reorder_for_peak_memory = False + parallel_mod = torch.compile(parallel_mod, fullgraph=True) + + return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 9340671d7..c82e534b3 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,6 +12,7 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor import DTensor import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module @@ -113,6 +114,21 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # TODO(whc) -- NOTE(review): appears to apply to every run, even with comm reordering disabled; confirm + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + + # allow configuring inductor comms optimizations from torchtitan commandline + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). 
dist_utils.set_determinism( @@ -138,20 +154,19 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] + model_cls = self.train_spec.cls # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) - logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) + with torch.device("meta"): model = model_cls.from_model_args(model_args) - - # Build the collection of model converters. No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + # Build the collection of model converters. No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = (