From 5e93263e207ac1e0ebe94ec06c88151c67d50c9d Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 12 Jun 2025 21:08:11 -0700 Subject: [PATCH 1/9] [WIP] Integrate autoparallel into torchtitan TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop --- torchtitan/components/metrics.py | 14 ++-- torchtitan/experiments/__init__.py | 1 + .../experiments/auto_parallel/README.md | 7 ++ .../experiments/auto_parallel/__init__.py | 31 ++++++++ .../auto_parallel/parallelize_llama.py | 76 +++++++++++++++++++ torchtitan/train.py | 45 ++++++----- 6 files changed, 146 insertions(+), 28 deletions(-) create mode 100644 torchtitan/experiments/auto_parallel/README.md create mode 100644 torchtitan/experiments/auto_parallel/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/parallelize_llama.py diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 084c2c4ff..e93a6561e 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -354,7 +354,7 @@ def log( global_max_loss: float, extra_metrics: dict[str, Any] | None = None, ): - assert self.num_flops_per_token > 0, "num_flops_per_token must be set" + # assert self.num_flops_per_token > 0, "num_flops_per_token must be set" time_delta = time.perf_counter() - self.time_last_log @@ -365,8 +365,8 @@ def log( # model FLOPS utilization # For its definition and calculation, please refer to the PaLM paper: # https://arxiv.org/abs/2204.02311 - mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops - tflops = self.num_flops_per_token * tps / 1e12 + # mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops + # tflops = self.num_flops_per_token * tps / 1e12 time_end_to_end = time_delta / self.job_config.metrics.log_freq time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times) @@ -378,8 +378,8 @@ def log( "loss_metrics/global_avg_loss": global_avg_loss, "loss_metrics/global_max_loss": global_max_loss, "throughput(tps)": tps, - "tflops": tflops, - "mfu(%)": mfu, + # "tflops": tflops, + # "mfu(%)": mfu, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -403,8 +403,8 @@ def log( f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}tflops: {tflops:,.2f} " - f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + # f"{color.cyan}tflops: {tflops:,.2f} " + # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) self.ntokens_since_last_log = 0 diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 4c54bdc13..b7ff983e9 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,5 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import torchtitan.experiments.auto_parallel # noqa: F401 import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 000000000..ef66a5916 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,7 @@ +## Auto Parallel + +requires installing git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +(or llama3-8b.toml) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py new file mode 100644 index 000000000..7a2682a9f --- /dev/null +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from .parallelize_llama import parallelize_llama + +register_train_spec( + TrainSpec( + name="llama3_auto_parallel", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py new file mode 100644 index 000000000..227935108 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model_fn, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + # model = model.to_empty(device="cuda") + + # TODO: make auto-p work with already created model object? 
+ + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = world_mesh["dp"].size() + global_batch_size = job_config.training.local_batch_size * dp_degree + return torch.rand( + (global_batch_size, job_config.training.seq_len), device="cuda" + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert ( + len(world_mesh.shape) == 2 + ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" + assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + autop = AutoParallel(model_fn, input_fn, world_mesh) + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Replicate()) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if job_config.training.compile: + torch._inductor.config.reorder_for_peak_memory = False + parallel_mod = torch.compile(parallel_mod, fullgraph=True) + + return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 9340671d7..051cdd5d5 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -24,7 +24,8 @@ ) from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.protocols.model_converter import build_model_converters + +# from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( @@ -138,20 +139,22 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] + model_cls = self.train_spec.cls # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) - logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - with torch.device("meta"): - model = model_cls.from_model_args(model_args) + def model_fn(): + return model_cls.from_model_args(model_args).cuda() + + # with torch.device("meta"): + # model = model_fn() # Build the collection of model converters. 
No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + # model_converters = build_model_converters(job_config, parallel_dims) + # model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( @@ -165,15 +168,15 @@ def __init__(self, job_config: JobConfig): color = self.metrics_processor.color # calculate model size and flops per token - ( - model_param_count, - self.metrics_processor.num_flops_per_token, - ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + # ( + # model_param_count, + # self.metrics_processor.num_flops_per_token, + # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) - logger.info( - f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " - f"{color.red}size: {model_param_count:,} total parameters{color.reset}" - ) + # logger.info( + # f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + # f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + # ) # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: @@ -251,7 +254,7 @@ def __init__(self, job_config: JobConfig): else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn( - model, world_mesh, parallel_dims, job_config + model_fn, world_mesh, parallel_dims, job_config ) model.to_empty(device=init_device) @@ -288,11 +291,11 @@ def __init__(self, job_config: JobConfig): # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance - self.optimizers.register_step_post_hook( - lambda *args, **kwargs: model_converters.post_optimizer_hook( - self.model_parts - ) - ) + # self.optimizers.register_step_post_hook( + # lambda *args, **kwargs: model_converters.post_optimizer_hook( + # self.model_parts + # ) + # ) self.metrics_processor.optimizers = self.optimizers # Initialize trainer states that will be saved in checkpoint. From 42d5da6db85bd8b0cbab3854c5462fd041282ed8 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 16 Jun 2025 16:04:11 -0700 Subject: [PATCH 2/9] Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. 
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ --- torchtitan/components/metrics.py | 1 + .../auto_parallel/parallelize_llama.py | 15 +++- torchtitan/train.py | 79 ++++++++++++++++++- 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index e93a6561e..d87747cda 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -405,6 +405,7 @@ def log( f"{color.blue}tps: {round(tps):,} " # f"{color.cyan}tflops: {tflops:,.2f} " # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + f"{color.reset}" ) self.ntokens_since_last_log = 0 diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 227935108..acc5f5764 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -21,6 +21,7 @@ def parallelize_llama( model_fn, + init_fn, # TODO(whc) hack to pass stuff to autoparallel world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, @@ -32,9 +33,14 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - # model = model.to_empty(device="cuda") - # TODO: make auto-p work with already created model object? + # wherever the auto-parallel code that creates a FakeTensorMode is... + # fake_mode = ... + # for k, v in m.named_parameters(): + # # swap each param in your model with a fake tensor version + # # warning - we probably need to do this before initializing the optimizer? 
+ # setattr(m, k, fake_mode.from_tensor(v)) + # # also do the same for named_buffers def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -56,6 +62,10 @@ def input_fn(): assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + # bail out + # model = model_fn() + # return model + autop = AutoParallel(model_fn, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) @@ -73,4 +83,5 @@ def input_fn(): torch._inductor.config.reorder_for_peak_memory = False parallel_mod = torch.compile(parallel_mod, fullgraph=True) + init_fn(parallel_mod) return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 051cdd5d5..f7af6a869 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -148,8 +148,79 @@ def __init__(self, job_config: JobConfig): ) def model_fn(): + # WHC - allow auto_p to construct the model object under its own fake_mode. + # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode return model_cls.from_model_args(model_args).cuda() + def init_fn(model): + # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying + # code from the llama3 init_weights functions throughout the model components, and adjusting them to use + # the new FQN structures in autoparallel. + # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module + def param(name): + return model.get_parameter(f"params.{name}") + + from torchtitan.models.llama3.model import precompute_freqs_cis + + model.buffers_.get_buffer("freqs_cis").copy_( + precompute_freqs_cis( + model_args.dim // model_args.n_heads, + model_args.max_seq_len, + model_args.rope_theta, + ) + ) + + torch.nn.init.normal_(param("tok_embeddings/weight")) + + def init_layer(i): + for norm in ("attention_norm", "ffn_norm"): + torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) + + if model_args.depth_init: + weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 + else: + weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + for linear in ("wq", "wk", "wv"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/{linear}/weight"), + mean=0.0, + std=0.02, + ) + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/wo/weight"), + mean=0.0, + std=weight_init_std, + ) + + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 + ) + for linear in ("w2", "w3"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/{linear}/weight"), + mean=0.0, + std=weight_init_std, + ) + + for i in range(model_args.n_layers): + init_layer(i) + + if param("norm/weight") is not None: + torch.nn.init.ones_(param("norm/weight")) + + final_out_std = model_args.dim**-0.5 + cutoff_factor = 3 + + if param("output/weight") is not None: + torch.nn.init.trunc_normal_( + param("output/weight"), + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + # with torch.device("meta"): # model = model_fn() # Build the collection of model converters. 
No-op if `model.converters` empty @@ -254,12 +325,12 @@ def model_fn(): else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn( - model_fn, world_mesh, parallel_dims, job_config + model_fn, init_fn, world_mesh, parallel_dims, job_config ) - model.to_empty(device=init_device) - with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + # model.to_empty(device=init_device) + # with torch.no_grad(): + # model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From c8fb6b54f39e93a5f790eb09c240ff7c2ac8a776 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 18 Jun 2025 16:19:20 -0700 Subject: [PATCH 3/9] Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. --- torchtitan/components/metrics.py | 15 +++-- .../auto_parallel/parallelize_llama.py | 15 +---- torchtitan/train.py | 55 +++++++++---------- 3 files changed, 35 insertions(+), 50 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index d87747cda..084c2c4ff 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -354,7 +354,7 @@ def log( global_max_loss: float, extra_metrics: dict[str, Any] | None = None, ): - # assert self.num_flops_per_token > 0, "num_flops_per_token must be set" + assert self.num_flops_per_token > 0, "num_flops_per_token must be set" time_delta = time.perf_counter() - self.time_last_log @@ -365,8 +365,8 @@ def log( # model FLOPS utilization # For its definition and calculation, please refer to the PaLM paper: # https://arxiv.org/abs/2204.02311 - # mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops - # tflops = self.num_flops_per_token * tps / 1e12 + mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops + tflops = self.num_flops_per_token * tps / 1e12 time_end_to_end = time_delta / self.job_config.metrics.log_freq time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times) @@ -378,8 +378,8 @@ def log( "loss_metrics/global_avg_loss": global_avg_loss, "loss_metrics/global_max_loss": global_max_loss, "throughput(tps)": tps, - # "tflops": tflops, - # "mfu(%)": mfu, + "tflops": tflops, + "mfu(%)": mfu, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -403,9 +403,8 @@ def log( f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - # f"{color.cyan}tflops: {tflops:,.2f} " - # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" - f"{color.reset}" + f"{color.cyan}tflops: {tflops:,.2f} " + f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) self.ntokens_since_last_log = 0 diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index acc5f5764..31830613a 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -20,8 +20,7 @@ def parallelize_llama( - model_fn, - init_fn, # TODO(whc) hack to pass stuff to autoparallel + model, world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, @@ -33,15 +32,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should 
be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - # TODO: make auto-p work with already created model object? - # wherever the auto-parallel code that creates a FakeTensorMode is... - # fake_mode = ... - # for k, v in m.named_parameters(): - # # swap each param in your model with a fake tensor version - # # warning - we probably need to do this before initializing the optimizer? - # setattr(m, k, fake_mode.from_tensor(v)) - # # also do the same for named_buffers - def input_fn(): global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: @@ -66,7 +56,7 @@ def input_fn(): # model = model_fn() # return model - autop = AutoParallel(model_fn, input_fn, world_mesh) + autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type) autop.add_parameter_memory_constraint(low=None, high=None) x_sharding = (Shard(0), Replicate()) @@ -83,5 +73,4 @@ def input_fn(): torch._inductor.config.reorder_for_peak_memory = False parallel_mod = torch.compile(parallel_mod, fullgraph=True) - init_fn(parallel_mod) return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index f7af6a869..ac44a73fd 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,7 +12,6 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record - import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager @@ -25,7 +24,7 @@ from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils -# from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( @@ -147,12 +146,8 @@ def __init__(self, job_config: JobConfig): f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - def model_fn(): - # WHC - allow auto_p to construct the model object under its own fake_mode. - # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode - return model_cls.from_model_args(model_args).cuda() - def init_fn(model): + def llama3_autoparallel_init_fn(model): # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying # code from the llama3 init_weights functions throughout the model components, and adjusting them to use # the new FQN structures in autoparallel. @@ -221,11 +216,11 @@ def init_layer(i): b=cutoff_factor * final_out_std, ) - # with torch.device("meta"): - # model = model_fn() - # Build the collection of model converters. No-op if `model.converters` empty - # model_converters = build_model_converters(job_config, parallel_dims) - # model_converters.convert(model) + with torch.device("meta"): + model = model_cls.from_model_args(model_args) + # Build the collection of model converters. 
No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( @@ -239,15 +234,15 @@ def init_layer(i): color = self.metrics_processor.color # calculate model size and flops per token - # ( - # model_param_count, - # self.metrics_processor.num_flops_per_token, - # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + ( + model_param_count, + self.metrics_processor.num_flops_per_token, + ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) - # logger.info( - # f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " - # f"{color.red}size: {model_param_count:,} total parameters{color.reset}" - # ) + logger.info( + f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + ) # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: @@ -325,12 +320,14 @@ def init_layer(i): else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn( - model_fn, init_fn, world_mesh, parallel_dims, job_config + model, world_mesh, parallel_dims, job_config ) - # model.to_empty(device=init_device) - # with torch.no_grad(): - # model.init_weights(buffer_device=buffer_device) + model.to_empty(device=init_device) + with torch.no_grad(): + # TODO(whc) make model.init_weights work with autoparallel + llama3_autoparallel_init_fn(model) + # model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] @@ -362,11 +359,11 @@ def init_layer(i): # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance - # self.optimizers.register_step_post_hook( - # lambda *args, **kwargs: model_converters.post_optimizer_hook( - # self.model_parts - # ) - # ) + self.optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.model_parts + ) + ) self.metrics_processor.optimizers = self.optimizers # Initialize trainer states that will be saved in checkpoint. 
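Note: a minimal sketch (not part of any patch) of the end-to-end flow this series converges on, assuming the AutoParallel calls shown in the hunks above; exact signatures and defaults may differ from the installed autoparallel version. Patch 3 still initializes weights through the bespoke llama3 init fn, while patch 7 later switches to the module's own `init_weights`, so the init step is passed in here as a callback.

```python
import torch
from autoparallel.api import AutoParallel
from torch.distributed.tensor.placement_types import Replicate, Shard


def autoparallelize(model_cls, model_args, world_mesh, local_batch_size, seq_len, init_fn):
    # Build the model on meta device so AutoParallel can trace it without real storage.
    with torch.device("meta"):
        model = model_cls.from_model_args(model_args)

    def input_fn():
        # Global batch = local batch * DP degree, matching input_fn in parallelize_llama.py.
        dp_degree = world_mesh["dp"].size()
        return torch.rand((local_batch_size * dp_degree, seq_len), device="cuda")

    autop = AutoParallel(model, input_fn, world_mesh)
    autop.add_parameter_memory_constraint(low=None, high=None)

    # Inputs/outputs: batch dim sharded over DP, replicated over TP.
    x_sharding = (Shard(0), Replicate())
    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([x_sharding])

    parallel_mod = autop.apply_placement(autop.optimize_placement())

    # Materialize the sharded parameters, then initialize them.
    parallel_mod.to_empty(device="cuda")
    with torch.no_grad():
        init_fn(parallel_mod)
    return parallel_mod
```
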
From 0ec2b2fcc452bf7d56f7ad7fa8c212898bd2b331 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 25 Jun 2025 17:10:09 -0700 Subject: [PATCH 4/9] Fixes to align with latest autoparallel --- .../experiments/auto_parallel/parallelize_llama.py | 2 +- torchtitan/train.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 31830613a..af279de59 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -56,7 +56,7 @@ def input_fn(): # model = model_fn() # return model - autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type) + autop = AutoParallel(model, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) x_sharding = (Shard(0), Replicate()) diff --git a/torchtitan/train.py b/torchtitan/train.py index ac44a73fd..2734fc8e1 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -23,7 +23,7 @@ ) from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils - +from torch.distributed.tensor import DTensor from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -158,10 +158,13 @@ def param(name): from torchtitan.models.llama3.model import precompute_freqs_cis model.buffers_.get_buffer("freqs_cis").copy_( - precompute_freqs_cis( - model_args.dim // model_args.n_heads, - model_args.max_seq_len, - model_args.rope_theta, + DTensor.from_local( + precompute_freqs_cis( + model_args.dim // model_args.n_heads, + model_args.max_seq_len, + model_args.rope_theta, + ), + device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, ) ) From 40340a86f4d8eed603cf0be50c8494df6d29c068 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 1 Jul 2025 14:09:43 -0700 Subject: [PATCH 5/9] Add inductor config knobs for comms optimizations to torchtitan --- torchtitan/config_manager.py | 22 ++++++++++++++++++++++ torchtitan/train.py | 11 +++++++++++ 2 files changed, 33 insertions(+) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index bb4543ac1..eda0d5690 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -576,6 +576,28 @@ class Experimental: needs to ensure that the path can be imported. """ + reorder_for_compute_comm_overlap: bool = False + """ + Whether to enable inductor comm reordering passes + """ + + reorder_for_compute_comm_overlap_passes: list[str] = field( + default_factory=lambda: [ + "sink_waits", + "reorder_communication_preserving_peak_memory", + ] + ) + """ + Sequence of reordering passes (names of functions inside _inductor.comms) to call, + if reorder_for_compute_comm_overlap is enabled. + """ + + reorder_prefetch_limit: int | None = None + """ + How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory' + pass is enabled. 
default of None means unlimited + """ + @dataclass class JobConfig: diff --git a/torchtitan/train.py b/torchtitan/train.py index 2734fc8e1..50c5424cb 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -113,6 +113,17 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # allow configuring inductor comms optimizations from torchtitan commandline + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( From d8b9802962c700f0a13e7e27480ca4b764ae2745 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 2 Jul 2025 14:56:15 -0700 Subject: [PATCH 6/9] Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now --- torchtitan/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchtitan/train.py b/torchtitan/train.py index 50c5424cb..2ad6a30ac 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -113,6 +113,10 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # TODO(whc) + # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches=True + # allow configuring inductor comms optimizations from torchtitan commandline torch._inductor.config.reorder_for_compute_comm_overlap = ( job_config.experimental.reorder_for_compute_comm_overlap From 8739c23e013ede1d387367309c328cd404fc6cd9 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 2 Jul 2025 16:19:57 -0700 Subject: [PATCH 7/9] Drop hacky llama3_init_fn and use autop init_weights feature Relying on https://github.com/pytorch-labs/autoparallel/pull/20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. 
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` --- torchtitan/train.py | 77 +-------------------------------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 2ad6a30ac..6d1ead6c9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -161,79 +161,6 @@ def __init__(self, job_config: JobConfig): f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - - def llama3_autoparallel_init_fn(model): - # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying - # code from the llama3 init_weights functions throughout the model components, and adjusting them to use - # the new FQN structures in autoparallel. 
- # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module - def param(name): - return model.get_parameter(f"params.{name}") - - from torchtitan.models.llama3.model import precompute_freqs_cis - - model.buffers_.get_buffer("freqs_cis").copy_( - DTensor.from_local( - precompute_freqs_cis( - model_args.dim // model_args.n_heads, - model_args.max_seq_len, - model_args.rope_theta, - ), - device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, - ) - ) - - torch.nn.init.normal_(param("tok_embeddings/weight")) - - def init_layer(i): - for norm in ("attention_norm", "ffn_norm"): - torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) - - if model_args.depth_init: - weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 - else: - weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 - - for linear in ("wq", "wk", "wv"): - torch.nn.init.trunc_normal_( - param(f"layers/{i}/attention/{linear}/weight"), - mean=0.0, - std=0.02, - ) - torch.nn.init.trunc_normal_( - param(f"layers/{i}/attention/wo/weight"), - mean=0.0, - std=weight_init_std, - ) - - torch.nn.init.trunc_normal_( - param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 - ) - for linear in ("w2", "w3"): - torch.nn.init.trunc_normal_( - param(f"layers/{i}/feed_forward/{linear}/weight"), - mean=0.0, - std=weight_init_std, - ) - - for i in range(model_args.n_layers): - init_layer(i) - - if param("norm/weight") is not None: - torch.nn.init.ones_(param("norm/weight")) - - final_out_std = model_args.dim**-0.5 - cutoff_factor = 3 - - if param("output/weight") is not None: - torch.nn.init.trunc_normal_( - param("output/weight"), - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - with torch.device("meta"): model = model_cls.from_model_args(model_args) # Build the collection of model converters. No-op if `model.converters` empty @@ -343,9 +270,7 @@ def init_layer(i): model.to_empty(device=init_device) with torch.no_grad(): - # TODO(whc) make model.init_weights work with autoparallel - llama3_autoparallel_init_fn(model) - # model.init_weights(buffer_device=buffer_device) + model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From 38ee98c0c93857e7841af1148441ed7701974073 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 2 Jul 2025 16:22:32 -0700 Subject: [PATCH 8/9] fix lint --- torchtitan/experiments/auto_parallel/parallelize_llama.py | 1 + torchtitan/train.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index af279de59..bb7f1204d 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -32,6 +32,7 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. 
""" + def input_fn(): global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: diff --git a/torchtitan/train.py b/torchtitan/train.py index 6d1ead6c9..c82e534b3 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,6 +12,8 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor import DTensor + import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager @@ -23,7 +25,6 @@ ) from torchtitan.config_manager import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils -from torch.distributed.tensor import DTensor from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -115,7 +116,7 @@ def __init__(self, job_config: JobConfig): # TODO(whc) # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering - torch._inductor.config.force_disable_caches=True + torch._inductor.config.force_disable_caches = True # allow configuring inductor comms optimizations from torchtitan commandline torch._inductor.config.reorder_for_compute_comm_overlap = ( From 56400c2386e9b03cd1b059fdc33fbbad356d2aeb Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 9 Jul 2025 17:08:25 -0700 Subject: [PATCH 9/9] add float8 support --- torchtitan/experiments/auto_parallel/parallelize_llama.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index bb7f1204d..88463b74c 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -53,6 +53,14 @@ def input_fn(): assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + # TODO: there are multiple float8 recipes, this just hardcodes one + enable_float8_linear = "float8" in job_config.model.converters + if enable_float8_linear: + import copy + from torchao.float8.float8_linear_utils import convert_to_float8_training + from torchao.float8.config import Float8LinearConfig + model = convert_to_float8_training(copy.deepcopy(model), config=Float8LinearConfig()) + # bail out # model = model_fn() # return model