diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py
index 4c54bdc13..b7ff983e9 100644
--- a/torchtitan/experiments/__init__.py
+++ b/torchtitan/experiments/__init__.py
@@ -4,5 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torchtitan.experiments.auto_parallel  # noqa: F401
 import torchtitan.experiments.llama4  # noqa: F401
 import torchtitan.experiments.simple_fsdp  # noqa: F401
diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md
new file mode 100644
index 000000000..ef66a5916
--- /dev/null
+++ b/torchtitan/experiments/auto_parallel/README.md
@@ -0,0 +1,7 @@
+## Auto Parallel
+
+Requires installing AutoParallel from `git@github.com:pytorch-labs/autoparallel.git`.
+
+`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`
+
+(or use `llama3-8b.toml` in place of `debug_model.toml`)
diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py
new file mode 100644
index 000000000..7a2682a9f
--- /dev/null
+++ b/torchtitan/experiments/auto_parallel/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from .parallelize_llama import parallelize_llama
+
+register_train_spec(
+    TrainSpec(
+        name="llama3_auto_parallel",
+        cls=Transformer,
+        config=llama3_configs,
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_tiktoken_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
new file mode 100644
index 000000000..e06ae4fbf
--- /dev/null
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import torch
+
+from autoparallel.api import AutoParallel
+
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
+
+from torchtitan.tools.logging import logger
+
+
+def parallelize_llama(
+    model,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    """
+    Apply tensor parallelism, activation checkpointing, torch.compile, and data
+    parallelism to the model.
+
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
+    """
+    def input_fn():
+        global_batch_size = job_config.training.global_batch_size
+        if global_batch_size < 0:
+            # This global batch size results in 1 gradient accumulation
+            # step.
+            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
+            global_batch_size = job_config.training.local_batch_size * dp_degree
+        return torch.rand(
+            (global_batch_size, job_config.training.seq_len), device="cuda"
+        )
+
+    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
+    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
+    assert parallel_dims.cp_enabled is False, "CP not supported yet"
+    assert parallel_dims.pp_enabled is False, "PP not supported yet"
+
+
+    # bail out
+    # model = model_fn()
+    # return model
+
+    autop = AutoParallel(model, input_fn, world_mesh)
+    autop.add_parameter_memory_constraint(low=None, high=None)
+
+    possible_input_shardings = {
+        # maps relative to mesh dim names used in torchtitan
+        "dp_replicate": Shard(0),
+        "dp_shard": Shard(0),
+        "tp": Replicate(),
+    }
+    assert all(name in possible_input_shardings for name in world_mesh.mesh_dim_names), (
+        f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
+    )
+    x_sharding = tuple(possible_input_shardings[name] for name in world_mesh.mesh_dim_names)
+    autop.add_input_constraints([x_sharding])
+    autop.add_output_constraints([x_sharding])
+    t0 = time.time()
+    sharding_placement = autop.optimize_placement()
+    t1 = time.time()
+    logger.info(f"AutoParallel took {t1 - t0} seconds")
+    parallel_mod = autop.apply_placement(sharding_placement)
+
+    if job_config.training.compile:
+        torch._inductor.config.reorder_for_peak_memory = False
+        parallel_mod = torch.compile(parallel_mod, fullgraph=True)
+
+    return parallel_mod
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9340671d7..2734fc8e1 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -12,7 +12,6 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
-
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -24,6 +23,7 @@
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torch.distributed.tensor import DTensor
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
@@ -138,20 +138,92 @@ def __init__(self, job_config: JobConfig):
         )
 
         # build model (using meta init)
-        model_cls = self.train_spec.cls
         model_args = self.train_spec.config[job_config.model.flavor]
+        model_cls = self.train_spec.cls
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)
-
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
+
+
+        def llama3_autoparallel_init_fn(model):
+            # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
+            # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
+            # the new FQN structures in autoparallel.
+            # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
+            def param(name):
+                return model.get_parameter(f"params.{name}")
+
+            from torchtitan.models.llama3.model import precompute_freqs_cis
+
+            model.buffers_.get_buffer("freqs_cis").copy_(
+                DTensor.from_local(
+                    precompute_freqs_cis(
+                        model_args.dim // model_args.n_heads,
+                        model_args.max_seq_len,
+                        model_args.rope_theta,
+                    ),
+                    device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh,
+                )
+            )
+
+            torch.nn.init.normal_(param("tok_embeddings/weight"))
+
+            def init_layer(i):
+                for norm in ("attention_norm", "ffn_norm"):
+                    torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))
+
+                if model_args.depth_init:
+                    weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
+                else:
+                    weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5
+
+                for linear in ("wq", "wk", "wv"):
+                    torch.nn.init.trunc_normal_(
+                        param(f"layers/{i}/attention/{linear}/weight"),
+                        mean=0.0,
+                        std=0.02,
+                    )
+                torch.nn.init.trunc_normal_(
+                    param(f"layers/{i}/attention/wo/weight"),
+                    mean=0.0,
+                    std=weight_init_std,
+                )
+
+                torch.nn.init.trunc_normal_(
+                    param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
+                )
+                for linear in ("w2", "w3"):
+                    torch.nn.init.trunc_normal_(
+                        param(f"layers/{i}/feed_forward/{linear}/weight"),
+                        mean=0.0,
+                        std=weight_init_std,
+                    )
+
+            for i in range(model_args.n_layers):
+                init_layer(i)
+
+            if param("norm/weight") is not None:
+                torch.nn.init.ones_(param("norm/weight"))
+
+            final_out_std = model_args.dim**-0.5
+            cutoff_factor = 3
+
+            if param("output/weight") is not None:
+                torch.nn.init.trunc_normal_(
+                    param("output/weight"),
+                    mean=0.0,
+                    std=final_out_std,
+                    a=-cutoff_factor * final_out_std,
+                    b=cutoff_factor * final_out_std,
+                )
+
         with torch.device("meta"):
             model = model_cls.from_model_args(model_args)
-
-        # Build the collection of model converters. No-op if `model.converters` empty
-        model_converters = build_model_converters(job_config, parallel_dims)
-        model_converters.convert(model)
+            # Build the collection of model converters. No-op if `model.converters` empty
+            model_converters = build_model_converters(job_config, parallel_dims)
+            model_converters.convert(model)
 
         # metrics logging
         build_metrics_processor_fn = (
@@ -256,7 +328,9 @@ def __init__(self, job_config: JobConfig):
 
             model.to_empty(device=init_device)
             with torch.no_grad():
-                model.init_weights(buffer_device=buffer_device)
+                # TODO(whc) make model.init_weights work with autoparallel
+                llama3_autoparallel_init_fn(model)
+                # model.init_weights(buffer_device=buffer_device)
             model.train()
 
             self.model_parts = [model]
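`llama3_autoparallel_init_fn` works because the AutoParallel-produced module exposes the wrapped parameters under a flat `params.` namespace with `/`-separated paths (e.g. `layers/0/attention/wq/weight`) and the buffers under `buffers_`, as seen in the `param()` helper and the `freqs_cis` access above. The sketch below is a hypothetical illustration, not part of this PR, of how original llama3 FQNs could be translated into that scheme; the mapping rule is an assumption generalized from the FQNs used in the init function, and the helper names are made up for the example.

```python
# Hypothetical sketch (not part of this PR): translate original llama3 parameter
# FQNs into the flat naming assumed above ("params." prefix, "/"-separated path),
# so init logic could be driven by the original parameter names.
import torch
import torch.nn as nn


def autop_param_name(original_fqn: str) -> str:
    # e.g. "layers.0.attention.wq.weight" -> "params.layers/0/attention/wq/weight"
    return "params." + original_fqn.replace(".", "/")


def init_from_original_fqns(parallel_mod: nn.Module, std_by_fqn: dict[str, float]) -> None:
    """Re-initialize parameters on the AutoParallel-produced module, keyed by original FQN."""
    with torch.no_grad():
        for fqn, std in std_by_fqn.items():
            p = parallel_mod.get_parameter(autop_param_name(fqn))
            torch.nn.init.trunc_normal_(p, mean=0.0, std=std)


# Example usage (names are illustrative; parallel_mod comes from autop.apply_placement(...)):
# init_from_original_fqns(parallel_mod, {"layers.0.attention.wq.weight": 0.02})
```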