add the forge folder #1387

base: main

Changes from all commits
@@ -0,0 +1,14 @@
## `ForgeEngine`

The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications.

The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to
- Initialize an SPMD distributed training environment
- Construct and scale models via n-D parallelisms and meta-device initialization
- Provide necessary training components and utilities (see the sketch below)

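A minimal construction sketch is shown below. It is illustrative only: the import path assumes the package is importable as `forge`, and the config fields and values are assumptions rather than verified `ForgeJobConfig` API (see [job_config.py](job_config.py) for the actual fields).

```python
# Hypothetical sketch; the import path, field names, and values are assumptions.
from forge import ForgeEngine, ForgeJobConfig

job_config = ForgeJobConfig()           # assumed to be default-constructible
job_config.model.name = "llama3"        # assumed model name
job_config.model.flavor = "debugmodel"  # assumed model flavor

# The constructor initializes the distributed environment, builds and
# parallelizes the model, and creates optimizers, lr schedulers, and
# the checkpointer.
engine = ForgeEngine(job_config)
try:
    # Drive a custom loop using engine.model_parts, engine.optimizers,
    # engine.lr_schedulers, engine.train_context, engine.checkpointer, ...
    ...
finally:
    engine.close()
```
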
**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.

Additionally, the folder provides a train spec registration method, [`register_train_spec`](train_spec.py), that allows users to extend beyond the core set of models and training components available in torchtitan, enabling greater flexibility and customization for specific training requirements (see the registration sketch below).

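A hedged registration sketch follows. The `ForgeTrainSpec` fields are inferred from how `ForgeEngine` consumes the spec (`model_cls`, `model_args`, `parallelize_fn`, and the `build_*_fn` builders) and may not match the actual dataclass signature; all `My*`/`my_*` names are placeholders you would implement.

```python
# Hypothetical sketch; field names are inferred from engine.py, not verified.
from forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",                        # assumed spec identifier
        model_cls=MyTransformer,                # your torch.nn.Module subclass
        model_args={"default": MyModelArgs()},  # keyed by model flavor
        parallelize_fn=my_parallelize_fn,
        build_optimizers_fn=my_build_optimizers,
        build_lr_schedulers_fn=my_build_lr_schedulers,
        build_loss_fn=my_build_loss,
        build_tokenizer_fn=my_build_tokenizer,  # optional; may be None
    )
)
```
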
The [example_train.py](./example_train.py) demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py).
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .engine import ForgeEngine
from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, register_train_spec

__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"]
@@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTManager, init_ft_manager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils

from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, get_train_spec


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
    # core configs
    job_config: ForgeJobConfig
    parallel_dims: ParallelDims
    train_spec: ForgeTrainSpec

    # swappable training components in ForgeTrainSpec
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    # non-swappable training components
    checkpointer: CheckpointManager
    ft_manager: FTManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    gradient_accumulation_steps: int
    train_context: Generator[None, None, None]
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # Fields in ForgeEngine which are not in original Trainer
    # for dataloading
    tokenizer: train_spec_module.BaseTokenizer | None
    dp_degree: int
    dp_rank: int
    # for logging
    model_args: BaseModelArgs
    num_flops_per_token: float
    model_param_count: int
    global_batch_size: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: ForgeJobConfig):
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(job_config)
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )
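        # Illustrative example (values are assumptions, not from this PR): with
        # world_size=8, one could set data_parallel_shard_degree=2 and
        # tensor_parallel_degree=4, leaving the remaining degrees at 1.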

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        self.ft_manager = init_ft_manager(job_config)
        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
        # the dataloader, must be changed.
        if self.ft_manager.enabled:
            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
+98
to
+102
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar comment here.. personally I would leave this out for the initial version (again, lmk if you disagree) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, will remove |
||
|
||
        self.dp_degree, self.dp_rank = dp_degree, dp_rank

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = get_train_spec(job_config.model.name)

        # build tokenizer
        self.tokenizer = tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

Comment on lines +122 to +126:
- This should be done downstream, right?
- discussed offline, probably makes more sense to keep for now

        # build model (using meta init)
        self.model_args = model_args = self.train_spec.model_args[
            job_config.model.flavor
        ]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # calculate model size and flops per token
        (
            self.model_param_count,
            self.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None

Comment on lines +149 to +151:
- I'm not sure I understand the purposes of this. Is the idea to do this in lieu of a separate checkpoint conversion script?
- This is more for debugging now. I can remove

        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )
        self.global_batch_size = global_batch_size

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
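        # e.g. (illustrative values): global_batch_size=64, local_batch_size=4,
        # dp_degree=8 -> 64 // (4 * 8) = 2 gradient accumulation steps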
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        if (
            self.ft_manager.enabled
            and job_config.fault_tolerance.semi_sync_method is None
        ):
            self.ft_manager.set_all_reduce_hook(self.model_parts)

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, parallel_dims, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )
        # Post optimizer step model converters hook.
        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )

        self.checkpointer = CheckpointManager(
            dataloader=None,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=self.ft_manager,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()

Comment:
- We probably don't need this, right?
- yes, will remove