From 5e93263e207ac1e0ebe94ec06c88151c67d50c9d Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 12 Jun 2025 21:08:11 -0700 Subject: [PATCH 1/5] [WIP] Integrate autoparallel into torchtitan TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop --- torchtitan/components/metrics.py | 14 ++-- torchtitan/experiments/__init__.py | 1 + .../experiments/auto_parallel/README.md | 7 ++ .../experiments/auto_parallel/__init__.py | 31 ++++++++ .../auto_parallel/parallelize_llama.py | 76 +++++++++++++++++++ torchtitan/train.py | 45 ++++++----- 6 files changed, 146 insertions(+), 28 deletions(-) create mode 100644 torchtitan/experiments/auto_parallel/README.md create mode 100644 torchtitan/experiments/auto_parallel/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/parallelize_llama.py diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 084c2c4ff..e93a6561e 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -354,7 +354,7 @@ def log( global_max_loss: float, extra_metrics: dict[str, Any] | None = None, ): - assert self.num_flops_per_token > 0, "num_flops_per_token must be set" + # assert self.num_flops_per_token > 0, "num_flops_per_token must be set" time_delta = time.perf_counter() - self.time_last_log @@ -365,8 +365,8 @@ def log( # model FLOPS utilization # For its definition and calculation, please refer to the PaLM paper: # https://arxiv.org/abs/2204.02311 - mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops - tflops = self.num_flops_per_token * tps / 1e12 + # mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops + # tflops = self.num_flops_per_token * tps / 1e12 time_end_to_end = time_delta / self.job_config.metrics.log_freq time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times) @@ -378,8 +378,8 @@ def log( "loss_metrics/global_avg_loss": global_avg_loss, "loss_metrics/global_max_loss": global_max_loss, "throughput(tps)": tps, - "tflops": tflops, - "mfu(%)": mfu, + # "tflops": tflops, + # "mfu(%)": mfu, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -403,8 +403,8 @@ def log( f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}tflops: {tflops:,.2f} " - f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + # f"{color.cyan}tflops: {tflops:,.2f} " + # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) self.ntokens_since_last_log = 0 diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 4c54bdc13..b7ff983e9 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,5 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import torchtitan.experiments.auto_parallel # noqa: F401 import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 000000000..ef66a5916 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,7 @@ +## Auto Parallel + +requires installing git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +(or llama3-8b.toml) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py new file mode 100644 index 000000000..7a2682a9f --- /dev/null +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from .parallelize_llama import parallelize_llama + +register_train_spec( + TrainSpec( + name="llama3_auto_parallel", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py new file mode 100644 index 000000000..227935108 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model_fn, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + # model = model.to_empty(device="cuda") + + # TODO: make auto-p work with already created model object? 
+ + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = world_mesh["dp"].size() + global_batch_size = job_config.training.local_batch_size * dp_degree + return torch.rand( + (global_batch_size, job_config.training.seq_len), device="cuda" + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert ( + len(world_mesh.shape) == 2 + ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" + assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + autop = AutoParallel(model_fn, input_fn, world_mesh) + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Replicate()) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if job_config.training.compile: + torch._inductor.config.reorder_for_peak_memory = False + parallel_mod = torch.compile(parallel_mod, fullgraph=True) + + return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 9340671d7..051cdd5d5 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -24,7 +24,8 @@ ) from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.protocols.model_converter import build_model_converters + +# from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( @@ -138,20 +139,22 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] + model_cls = self.train_spec.cls # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) - logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - with torch.device("meta"): - model = model_cls.from_model_args(model_args) + def model_fn(): + return model_cls.from_model_args(model_args).cuda() + + # with torch.device("meta"): + # model = model_fn() # Build the collection of model converters. 
No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + # model_converters = build_model_converters(job_config, parallel_dims) + # model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( @@ -165,15 +168,15 @@ def __init__(self, job_config: JobConfig): color = self.metrics_processor.color # calculate model size and flops per token - ( - model_param_count, - self.metrics_processor.num_flops_per_token, - ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + # ( + # model_param_count, + # self.metrics_processor.num_flops_per_token, + # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) - logger.info( - f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " - f"{color.red}size: {model_param_count:,} total parameters{color.reset}" - ) + # logger.info( + # f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + # f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + # ) # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: @@ -251,7 +254,7 @@ def __init__(self, job_config: JobConfig): else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn( - model, world_mesh, parallel_dims, job_config + model_fn, world_mesh, parallel_dims, job_config ) model.to_empty(device=init_device) @@ -288,11 +291,11 @@ def __init__(self, job_config: JobConfig): # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance - self.optimizers.register_step_post_hook( - lambda *args, **kwargs: model_converters.post_optimizer_hook( - self.model_parts - ) - ) + # self.optimizers.register_step_post_hook( + # lambda *args, **kwargs: model_converters.post_optimizer_hook( + # self.model_parts + # ) + # ) self.metrics_processor.optimizers = self.optimizers # Initialize trainer states that will be saved in checkpoint. From 42d5da6db85bd8b0cbab3854c5462fd041282ed8 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 16 Jun 2025 16:04:11 -0700 Subject: [PATCH 2/5] Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. 
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ --- torchtitan/components/metrics.py | 1 + .../auto_parallel/parallelize_llama.py | 15 +++- torchtitan/train.py | 79 ++++++++++++++++++- 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index e93a6561e..d87747cda 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -405,6 +405,7 @@ def log( f"{color.blue}tps: {round(tps):,} " # f"{color.cyan}tflops: {tflops:,.2f} " # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + f"{color.reset}" ) self.ntokens_since_last_log = 0 diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 227935108..acc5f5764 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -21,6 +21,7 @@ def parallelize_llama( model_fn, + init_fn, # TODO(whc) hack to pass stuff to autoparallel world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, @@ -32,9 +33,14 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - # model = model.to_empty(device="cuda") - # TODO: make auto-p work with already created model object? + # wherever the auto-parallel code that creates a FakeTensorMode is... + # fake_mode = ... + # for k, v in m.named_parameters(): + # # swap each param in your model with a fake tensor version + # # warning - we probably need to do this before initializing the optimizer? 
+ # setattr(m, k, fake_mode.from_tensor(v)) + # # also do the same for named_buffers def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -56,6 +62,10 @@ def input_fn(): assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + # bail out + # model = model_fn() + # return model + autop = AutoParallel(model_fn, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) @@ -73,4 +83,5 @@ def input_fn(): torch._inductor.config.reorder_for_peak_memory = False parallel_mod = torch.compile(parallel_mod, fullgraph=True) + init_fn(parallel_mod) return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 051cdd5d5..f7af6a869 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -148,8 +148,79 @@ def __init__(self, job_config: JobConfig): ) def model_fn(): + # WHC - allow auto_p to construct the model object under its own fake_mode. + # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode return model_cls.from_model_args(model_args).cuda() + def init_fn(model): + # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying + # code from the llama3 init_weights functions throughout the model components, and adjusting them to use + # the new FQN structures in autoparallel. + # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module + def param(name): + return model.get_parameter(f"params.{name}") + + from torchtitan.models.llama3.model import precompute_freqs_cis + + model.buffers_.get_buffer("freqs_cis").copy_( + precompute_freqs_cis( + model_args.dim // model_args.n_heads, + model_args.max_seq_len, + model_args.rope_theta, + ) + ) + + torch.nn.init.normal_(param("tok_embeddings/weight")) + + def init_layer(i): + for norm in ("attention_norm", "ffn_norm"): + torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) + + if model_args.depth_init: + weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 + else: + weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + for linear in ("wq", "wk", "wv"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/{linear}/weight"), + mean=0.0, + std=0.02, + ) + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/wo/weight"), + mean=0.0, + std=weight_init_std, + ) + + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 + ) + for linear in ("w2", "w3"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/{linear}/weight"), + mean=0.0, + std=weight_init_std, + ) + + for i in range(model_args.n_layers): + init_layer(i) + + if param("norm/weight") is not None: + torch.nn.init.ones_(param("norm/weight")) + + final_out_std = model_args.dim**-0.5 + cutoff_factor = 3 + + if param("output/weight") is not None: + torch.nn.init.trunc_normal_( + param("output/weight"), + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + # with torch.device("meta"): # model = model_fn() # Build the collection of model converters. 
No-op if `model.converters` empty @@ -254,12 +325,12 @@ def model_fn(): else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn( - model_fn, world_mesh, parallel_dims, job_config + model_fn, init_fn, world_mesh, parallel_dims, job_config ) - model.to_empty(device=init_device) - with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + # model.to_empty(device=init_device) + # with torch.no_grad(): + # model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From c8fb6b54f39e93a5f790eb09c240ff7c2ac8a776 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 18 Jun 2025 16:19:20 -0700 Subject: [PATCH 3/5] Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. --- torchtitan/components/metrics.py | 15 +++-- .../auto_parallel/parallelize_llama.py | 15 +---- torchtitan/train.py | 55 +++++++++---------- 3 files changed, 35 insertions(+), 50 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index d87747cda..084c2c4ff 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -354,7 +354,7 @@ def log( global_max_loss: float, extra_metrics: dict[str, Any] | None = None, ): - # assert self.num_flops_per_token > 0, "num_flops_per_token must be set" + assert self.num_flops_per_token > 0, "num_flops_per_token must be set" time_delta = time.perf_counter() - self.time_last_log @@ -365,8 +365,8 @@ def log( # model FLOPS utilization # For its definition and calculation, please refer to the PaLM paper: # https://arxiv.org/abs/2204.02311 - # mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops - # tflops = self.num_flops_per_token * tps / 1e12 + mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops + tflops = self.num_flops_per_token * tps / 1e12 time_end_to_end = time_delta / self.job_config.metrics.log_freq time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times) @@ -378,8 +378,8 @@ def log( "loss_metrics/global_avg_loss": global_avg_loss, "loss_metrics/global_max_loss": global_max_loss, "throughput(tps)": tps, - # "tflops": tflops, - # "mfu(%)": mfu, + "tflops": tflops, + "mfu(%)": mfu, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -403,9 +403,8 @@ def log( f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - # f"{color.cyan}tflops: {tflops:,.2f} " - # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" - f"{color.reset}" + f"{color.cyan}tflops: {tflops:,.2f} " + f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) self.ntokens_since_last_log = 0 diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index acc5f5764..31830613a 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -20,8 +20,7 @@ def parallelize_llama( - model_fn, - init_fn, # TODO(whc) hack to pass stuff to autoparallel + model, world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, @@ -33,15 +32,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should 
be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - # TODO: make auto-p work with already created model object? - # wherever the auto-parallel code that creates a FakeTensorMode is... - # fake_mode = ... - # for k, v in m.named_parameters(): - # # swap each param in your model with a fake tensor version - # # warning - we probably need to do this before initializing the optimizer? - # setattr(m, k, fake_mode.from_tensor(v)) - # # also do the same for named_buffers - def input_fn(): global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: @@ -66,7 +56,7 @@ def input_fn(): # model = model_fn() # return model - autop = AutoParallel(model_fn, input_fn, world_mesh) + autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type) autop.add_parameter_memory_constraint(low=None, high=None) x_sharding = (Shard(0), Replicate()) @@ -83,5 +73,4 @@ def input_fn(): torch._inductor.config.reorder_for_peak_memory = False parallel_mod = torch.compile(parallel_mod, fullgraph=True) - init_fn(parallel_mod) return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index f7af6a869..ac44a73fd 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,7 +12,6 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record - import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager @@ -25,7 +24,7 @@ from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils -# from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( @@ -147,12 +146,8 @@ def __init__(self, job_config: JobConfig): f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - def model_fn(): - # WHC - allow auto_p to construct the model object under its own fake_mode. - # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode - return model_cls.from_model_args(model_args).cuda() - def init_fn(model): + def llama3_autoparallel_init_fn(model): # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying # code from the llama3 init_weights functions throughout the model components, and adjusting them to use # the new FQN structures in autoparallel. @@ -221,11 +216,11 @@ def init_layer(i): b=cutoff_factor * final_out_std, ) - # with torch.device("meta"): - # model = model_fn() - # Build the collection of model converters. No-op if `model.converters` empty - # model_converters = build_model_converters(job_config, parallel_dims) - # model_converters.convert(model) + with torch.device("meta"): + model = model_cls.from_model_args(model_args) + # Build the collection of model converters. 
No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( @@ -239,15 +234,15 @@ def init_layer(i): color = self.metrics_processor.color # calculate model size and flops per token - # ( - # model_param_count, - # self.metrics_processor.num_flops_per_token, - # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + ( + model_param_count, + self.metrics_processor.num_flops_per_token, + ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) - # logger.info( - # f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " - # f"{color.red}size: {model_param_count:,} total parameters{color.reset}" - # ) + logger.info( + f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + ) # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: @@ -325,12 +320,14 @@ def init_layer(i): else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn( - model_fn, init_fn, world_mesh, parallel_dims, job_config + model, world_mesh, parallel_dims, job_config ) - # model.to_empty(device=init_device) - # with torch.no_grad(): - # model.init_weights(buffer_device=buffer_device) + model.to_empty(device=init_device) + with torch.no_grad(): + # TODO(whc) make model.init_weights work with autoparallel + llama3_autoparallel_init_fn(model) + # model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] @@ -362,11 +359,11 @@ def init_layer(i): # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance - # self.optimizers.register_step_post_hook( - # lambda *args, **kwargs: model_converters.post_optimizer_hook( - # self.model_parts - # ) - # ) + self.optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.model_parts + ) + ) self.metrics_processor.optimizers = self.optimizers # Initialize trainer states that will be saved in checkpoint. 
From 0ec2b2fcc452bf7d56f7ad7fa8c212898bd2b331 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 25 Jun 2025 17:10:09 -0700 Subject: [PATCH 4/5] Fixes to align with latest autoparallel --- .../experiments/auto_parallel/parallelize_llama.py | 2 +- torchtitan/train.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 31830613a..af279de59 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -56,7 +56,7 @@ def input_fn(): # model = model_fn() # return model - autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type) + autop = AutoParallel(model, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) x_sharding = (Shard(0), Replicate()) diff --git a/torchtitan/train.py b/torchtitan/train.py index ac44a73fd..2734fc8e1 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -23,7 +23,7 @@ ) from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils - +from torch.distributed.tensor import DTensor from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -158,10 +158,13 @@ def param(name): from torchtitan.models.llama3.model import precompute_freqs_cis model.buffers_.get_buffer("freqs_cis").copy_( - precompute_freqs_cis( - model_args.dim // model_args.n_heads, - model_args.max_seq_len, - model_args.rope_theta, + DTensor.from_local( + precompute_freqs_cis( + model_args.dim // model_args.n_heads, + model_args.max_seq_len, + model_args.rope_theta, + ), + device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, ) ) From 6861adbe9d86e76057cf8200fe0803f39dff30c6 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 26 Jun 2025 18:34:34 -0700 Subject: [PATCH 5/5] Autoparallel support for DP-only, DP+TP, or TP-only lets existing torchtitan knobs which govern DP/TP mesh creation and mesh size influence the sharding constraints of autoparallel, allowing it to support these different sharding configurations. Examples: `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 1 --training.dataset c4` https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpUf57BL/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ``` [rank0]:[titan] 2025-06-26 18:12:46,592 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-06-26 18:12:46,593 - root - INFO - [GC] Initial GC collection. 0.00 seconds. 
[rank0]:[titan] 2025-06-26 18:23:14,389 - root - INFO - step: 1 loss: 8.1996 memory: 2.65GiB(2.79%) tps: 1,772 tflops: 0.13 mfu: 0.01% [rank0]:[titan] 2025-06-26 18:23:14,486 - root - INFO - step: 2 loss: 8.1709 memory: 2.66GiB(2.80%) tps: 168,877 tflops: 12.14 mfu: 1.23% [rank0]:[titan] 2025-06-26 18:23:14,580 - root - INFO - step: 3 loss: 8.1121 memory: 2.66GiB(2.80%) tps: 175,100 tflops: 12.59 mfu: 1.27% [rank0]:[titan] 2025-06-26 18:23:14,677 - root - INFO - step: 4 loss: 8.0119 memory: 2.66GiB(2.80%) tps: 170,227 tflops: 12.24 mfu: 1.24% [rank0]:[titan] 2025-06-26 18:23:14,771 - root - INFO - step: 5 loss: 7.8920 memory: 2.66GiB(2.80%) tps: 174,614 tflops: 12.56 mfu: 1.27% [rank0]:[titan] 2025-06-26 18:23:14,867 - root - INFO - step: 6 loss: 7.7511 memory: 2.66GiB(2.80%) tps: 170,863 tflops: 12.29 mfu: 1.24% [rank0]:[titan] 2025-06-26 18:23:14,963 - root - INFO - step: 7 loss: 7.6531 memory: 2.66GiB(2.80%) tps: 172,868 tflops: 12.43 mfu: 1.26% [rank0]:[titan] 2025-06-26 18:23:15,060 - root - INFO - step: 8 loss: 7.5231 memory: 2.66GiB(2.80%) tps: 168,378 tflops: 12.11 mfu: 1.22% [rank0]:[titan] 2025-06-26 18:23:15,157 - root - INFO - step: 9 loss: 7.3795 memory: 2.66GiB(2.80%) tps: 170,250 tflops: 12.24 mfu: 1.24% [rank0]:[titan] 2025-06-26 18:23:15,251 - root - INFO - step: 10 loss: 7.3036 memory: 2.66GiB(2.80%) tps: 175,755 tflops: 12.64 mfu: 1.28% ``` `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp981ifR/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ``` [rank0]:[titan] 2025-06-26 18:24:05,617 - root - INFO - Building 2-D device mesh with ['dp_shard', 'tp'], [2, 4] [rank0]:[titan] 2025-06-26 18:27:44,952 - root - INFO - step: 1 loss: 8.2345 memory: 1.08GiB(1.14%) tps: 74 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-06-26 18:27:45,003 - root - INFO - step: 2 loss: 8.2156 memory: 1.15GiB(1.21%) tps: 80,543 tflops: 5.79 mfu: 0.59% [rank0]:[titan] 2025-06-26 18:27:45,054 - root - INFO - step: 3 loss: 8.1867 memory: 1.15GiB(1.21%) tps: 81,472 tflops: 5.86 mfu: 0.59% [rank0]:[titan] 2025-06-26 18:27:45,099 - root - INFO - step: 4 loss: 8.1072 memory: 1.15GiB(1.21%) tps: 90,961 tflops: 6.54 mfu: 0.66% [rank0]:[titan] 2025-06-26 18:27:45,145 - root - INFO - step: 5 loss: 8.0360 memory: 1.15GiB(1.21%) tps: 90,280 tflops: 6.49 mfu: 0.66% [rank0]:[titan] 2025-06-26 18:27:45,193 - root - INFO - step: 6 loss: 7.9681 memory: 1.15GiB(1.21%) tps: 84,915 tflops: 6.11 mfu: 0.62% [rank0]:[titan] 2025-06-26 18:27:45,241 - root - INFO - step: 7 loss: 7.8870 memory: 1.15GiB(1.21%) tps: 86,096 tflops: 6.19 mfu: 0.63% [rank0]:[titan] 2025-06-26 18:27:45,292 - root - INFO - step: 8 loss: 7.8493 memory: 1.15GiB(1.21%) tps: 81,182 tflops: 5.84 mfu: 0.59% [rank0]:[titan] 2025-06-26 18:27:45,341 - root - INFO - step: 9 loss: 7.7431 memory: 1.15GiB(1.21%) tps: 84,341 tflops: 6.06 mfu: 0.61% [rank0]:[titan] 2025-06-26 18:27:45,396 - root - INFO - step: 10 loss: 7.7052 memory: 1.15GiB(1.21%) tps: 74,973 tflops: 5.39 mfu: 0.55% ``` `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 8 --training.dataset c4` 
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpgPuMRF/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ``` [rank0]:[titan] 2025-06-26 18:32:37,789 - root - INFO - Building 1-D device mesh with ['tp'], [8] [rank0]:[titan] 2025-06-26 18:33:00,183 - root - INFO - step: 1 loss: 8.2190 memory: 0.81GiB(0.85%) tps: 205 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-06-26 18:33:00,251 - root - INFO - step: 2 loss: 8.1733 memory: 0.87GiB(0.92%) tps: 30,431 tflops: 2.19 mfu: 0.22% [rank0]:[titan] 2025-06-26 18:33:00,297 - root - INFO - step: 3 loss: 8.1438 memory: 0.87GiB(0.92%) tps: 44,284 tflops: 3.18 mfu: 0.32% [rank0]:[titan] 2025-06-26 18:33:00,342 - root - INFO - step: 4 loss: 8.0361 memory: 0.87GiB(0.92%) tps: 45,921 tflops: 3.30 mfu: 0.33% [rank0]:[titan] 2025-06-26 18:33:00,384 - root - INFO - step: 5 loss: 7.9559 memory: 0.87GiB(0.92%) tps: 49,178 tflops: 3.54 mfu: 0.36% [rank0]:[titan] 2025-06-26 18:33:00,426 - root - INFO - step: 6 loss: 7.8346 memory: 0.87GiB(0.92%) tps: 49,172 tflops: 3.54 mfu: 0.36% [rank0]:[titan] 2025-06-26 18:33:00,462 - root - INFO - step: 7 loss: 7.7266 memory: 0.87GiB(0.92%) tps: 58,273 tflops: 4.19 mfu: 0.42% [rank0]:[titan] 2025-06-26 18:33:00,499 - root - INFO - step: 8 loss: 7.6807 memory: 0.87GiB(0.92%) tps: 54,435 tflops: 3.91 mfu: 0.40% [rank0]:[titan] 2025-06-26 18:33:00,537 - root - INFO - step: 9 loss: 7.5616 memory: 0.87GiB(0.92%) tps: 55,232 tflops: 3.97 mfu: 0.40% [rank0]:[titan] 2025-06-26 18:33:00,575 - root - INFO - step: 10 loss: 7.5090 memory: 0.87GiB(0.92%) tps: 54,284 tflops: 3.90 mfu: 0.39% ``` --- .../auto_parallel/parallelize_llama.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index af279de59..e06ae4fbf 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -37,21 +37,18 @@ def input_fn(): if global_batch_size < 0: # This global batch size results in 1 gradient accumulation # step. 
- dp_degree = world_mesh["dp"].size() + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree return torch.rand( (global_batch_size, job_config.training.seq_len), device="cuda" ) # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP - assert ( - len(world_mesh.shape) == 2 - ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" - assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + # bail out # model = model_fn() # return model @@ -59,8 +56,16 @@ def input_fn(): autop = AutoParallel(model, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) - x_sharding = (Shard(0), Replicate()) - + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + assert all(name in possible_input_shardings for name in world_mesh.mesh_dim_names), ( + f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + ) + x_sharding = tuple(possible_input_shardings[name] for name in world_mesh.mesh_dim_names) autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) t0 = time.time()