# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTManager, init_ft_manager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils

from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, get_train_spec


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
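    """Assembles and owns the training components described by a ForgeTrainSpec.

    Construction sets up the distributed environment and parallelism meshes from
    the given ForgeJobConfig, then builds the (optionally pipelined) model parts,
    loss function, optimizers, LR schedulers, and checkpointer. Subclassing
    torch.distributed.checkpoint's Stateful lets the engine itself be tracked as
    checkpointable state (see ``states={"train_state": self}`` below).
    """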
    # core configs
    job_config: ForgeJobConfig
    parallel_dims: ParallelDims
    train_spec: ForgeTrainSpec

    # swappable training components in ForgeTrainSpec
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    # non-swappable training components
    checkpointer: CheckpointManager
    ft_manager: FTManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    gradient_accumulation_steps: int
    train_context: Generator[None, None, None]
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # Fields in ForgeEngine that are not in the original Trainer
    # for dataloading
    tokenizer: train_spec_module.BaseTokenizer | None
    dp_degree: int
    dp_rank: int
    # for logging
    model_args: BaseModelArgs
    num_flops_per_token: float
    model_param_count: int
    global_batch_size: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: ForgeJobConfig):
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(job_config)
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        self.ft_manager = init_ft_manager(job_config)
        # If TorchFT is enabled, the dp_rank and dp_degree used by the
        # dataloader must be adjusted.
        if self.ft_manager.enabled:
            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        self.dp_degree, self.dp_rank = dp_degree, dp_rank

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
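        # Look up the ForgeTrainSpec registered under the configured model name;
        # it supplies the model class/args and the builder functions used below.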
        self.train_spec = get_train_spec(job_config.model.name)

        # build tokenizer
        self.tokenizer = tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        # build model (using meta init)
        self.model_args = model_args = self.train_spec.model_args[
            job_config.model.flavor
        ]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

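        # Constructing the model under torch.device("meta") allocates no real
        # parameter storage; weights are materialized later via to_empty() and
        # init_weights() once parallelisms have been applied.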
        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # Build the collection of model converters. No-op if `model.converters` is empty.
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # calculate model size and flops per token
        (
            self.model_param_count,
            self.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be a multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )
        self.global_batch_size = global_batch_size

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )
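        # Worked example (hypothetical numbers): with local_batch_size=2, dp_degree=4,
        # and global_batch_size=16, this gives 16 // (2 * 4) = 2 accumulation steps,
        # and rescale_accumulated_loss rescales each micro-batch loss by that count
        # so the accumulated gradient matches one step over the global batch.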

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, the `model` object is no longer used after this
            # point; model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        if (
            self.ft_manager.enabled
            and job_config.fault_tolerance.semi_sync_method is None
        ):
            self.ft_manager.set_all_reduce_hook(self.model_parts)

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, parallel_dims, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )
        # Post-optimizer-step model converters hook,
        # e.g. to calculate float8 dynamic amax/scale for all parameters under FSDP2,
        # which issues a single all-reduce for all parameters at once for better performance.
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )

        self.checkpointer = CheckpointManager(
            dataloader=None,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=self.ft_manager,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()
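

# Example usage (a sketch, not part of the original module). A ForgeJobConfig is
# typically parsed from CLI/TOML arguments by the caller before being passed in:
#
#     job_config = ...  # build or parse a ForgeJobConfig
#     engine = ForgeEngine(job_config)
#     try:
#         ...  # drive the training loop using engine.model_parts, engine.optimizers,
#         ...  # engine.lr_schedulers, and engine.checkpointer
#     finally:
#         engine.close()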