add the forge folder #1387

Open · wants to merge 1 commit into main

torchtitan/experiments/forge/README.md (new file, +14 lines)

## `ForgeEngine`

The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications.

The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to
- Initialize an SPMD distributed training environment
- Construct and scale models via n-D parallelisms and meta-device initialization
- Provide necessary training components and utilities
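
For orientation, here is a minimal sketch of how a downstream component might drive the engine. It assumes an already-populated `ForgeJobConfig` (how the config is parsed or constructed is defined in [job_config.py](job_config.py) and not shown here); the attribute names match those set in [`ForgeEngine.__init__`](engine.py).

```python
# Illustrative sketch only; assumes `job_config` is a populated ForgeJobConfig
# (construction/parsing of the config is defined in job_config.py).
from torchtitan.experiments.forge import ForgeEngine, ForgeJobConfig


def build_engine(job_config: ForgeJobConfig) -> ForgeEngine:
    # The constructor performs all of the setup listed above: distributed init,
    # meta-device model construction, parallelization, and creation of the
    # optimizers, LR schedulers, loss function, and checkpointer.
    engine = ForgeEngine(job_config)

    # Components a downstream trainer typically pulls off the engine:
    #   engine.model_parts      list[nn.Module]; more than one entry only with PP
    #   engine.optimizers       OptimizersContainer
    #   engine.lr_schedulers    LRSchedulersContainer
    #   engine.loss_fn          already rescaled for gradient accumulation
    #   engine.checkpointer     CheckpointManager (no dataloader attached)
    #   engine.tokenizer        BaseTokenizer | None, for building a dataloader
    return engine
```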

**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.

Additionally, the folder provides a train spec registration function, [`register_train_spec`](train_spec.py), which lets users extend beyond the core set of models and training components available in torchtitan, enabling greater flexibility and customization for specific training requirements.
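
The sketch below illustrates the intent only; the exact `ForgeTrainSpec` fields and the `register_train_spec` signature are defined in [train_spec.py](train_spec.py), and the field names here are inferred from how [`ForgeEngine`](engine.py) consumes the spec. `MyTransformer`, `MyModelArgs`, and the `build_*`/`parallelize_*` helpers are hypothetical user-provided components.

```python
# Hypothetical sketch: field names are inferred from how engine.py reads the
# spec; see train_spec.py for the authoritative ForgeTrainSpec API.
from torchtitan.experiments.forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",                      # looked up via job_config.model.name
        model_cls=MyTransformer,              # instantiated on the meta device
        model_args={"8B": MyModelArgs()},     # keyed by job_config.model.flavor
        parallelize_fn=parallelize_my_model,  # called as fn(model, parallel_dims, job_config)
        pipelining_fn=None,                   # required only when PP is enabled
        build_optimizers_fn=build_optimizers,        # fn(model_parts, job_config, parallel_dims, ft_manager)
        build_lr_schedulers_fn=build_lr_schedulers,  # fn(optimizers, job_config)
        build_loss_fn=build_loss_fn,          # fn(job_config)
        build_tokenizer_fn=None,              # optional; fn(job_config) when provided
    )
)
```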

The [example_train.py](./example_train.py) script demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py).
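
For context, a hedged sketch of the kind of step loop such a trainer builds on top of the engine (non-pipeline-parallel case, omitting gradient clipping, logging, and checkpointing). `data_iter` stands in for a dataloader, which forge leaves to the downstream application; the attributes used are the ones documented above.

```python
# Hedged sketch of one training step on top of ForgeEngine (no PP, no grad
# clipping/checkpointing); `data_iter` is a stand-in for a real dataloader.
import torch


def train_step(engine, data_iter) -> torch.Tensor:
    model = engine.model_parts[0]  # a single model part when PP is disabled
    engine.optimizers.zero_grad()

    accumulated_loss = torch.zeros((), device=engine.device)
    # engine.loss_fn is rescaled by the accumulation step count, so summing the
    # microbatch losses approximates the mean loss over the global batch.
    for _ in range(engine.gradient_accumulation_steps):
        inputs, labels = next(data_iter)
        preds = model(inputs.to(engine.device))
        loss = engine.loss_fn(preds, labels.to(engine.device))
        loss.backward()
        accumulated_loss += loss.detach()

    engine.optimizers.step()
    engine.lr_schedulers.step()
    return accumulated_loss
```
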
torchtitan/experiments/forge/__init__.py (new file, +11 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .engine import ForgeEngine
from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, register_train_spec

__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"]

torchtitan/experiments/forge/engine.py (new file, +276 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTManager, init_ft_manager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils

from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, get_train_spec


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
# core configs
job_config: ForgeJobConfig
parallel_dims: ParallelDims
train_spec: ForgeTrainSpec

# swappable training components in ForgeTrainSpec
model_parts: list[torch.nn.Module]
loss_fn: train_spec_module.LossFunction
optimizers: train_spec_module.OptimizersContainer
lr_schedulers: train_spec_module.LRSchedulersContainer

# non-swappable training components
checkpointer: CheckpointManager
ft_manager: FTManager

# runtime utilities
device: torch.device
gc_handler: utils.GarbageCollection
gradient_accumulation_steps: int
train_context: Generator[None, None, None]
pp_has_first_stage: bool
pp_has_last_stage: bool

# Fields in ForgeEngine which are not in original Trainer
# for dataloading
tokenizer: train_spec_module.BaseTokenizer | None
dp_degree: int
dp_rank: int
# for logging
model_args: BaseModelArgs
num_flops_per_token: float
model_param_count: int
global_batch_size: int

# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def __init__(self, job_config: ForgeJobConfig):
torch._C._log_api_usage_once("torchtitan.train")

self.job_config = job_config

if job_config.experimental.custom_import:
importlib.import_module(job_config.experimental.custom_import)
Comment on lines +69 to +70
Contributor: We probably don't need this, right?
Contributor Author: yes, will remove

device_module, device_type = utils.device_module, utils.device_type
self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
# Device has to be set before creating TorchFT manager.
device_module.set_device(self.device)

# init distributed and build meshes
dist_utils.init_distributed(job_config)
world_size = int(os.environ["WORLD_SIZE"])
parallelism_config = job_config.parallelism
self.parallel_dims = parallel_dims = ParallelDims(
dp_shard=parallelism_config.data_parallel_shard_degree,
dp_replicate=parallelism_config.data_parallel_replicate_degree,
cp=parallelism_config.context_parallel_degree,
tp=parallelism_config.tensor_parallel_degree,
pp=parallelism_config.pipeline_parallel_degree,
ep=parallelism_config.expert_parallel_degree,
world_size=world_size,
)

world_mesh = parallel_dims.world_mesh
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0

self.ft_manager = init_ft_manager(job_config)
# If TorchFT is enabled, the dp_rank and dp_degree, which are used for
# dataloader must be changed.
if self.ft_manager.enabled:
dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
Comment on lines +98 to +102
Contributor: Similar comment here.. personally I would leave this out for the initial version (again, lmk if you disagree)
Contributor Author: sure, will remove

self.dp_degree, self.dp_rank = dp_degree, dp_rank

# take control of garbage collection to avoid stragglers
self.gc_handler = utils.GarbageCollection(
gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
)

# Set random seed, and maybe enable deterministic mode
# (mainly for debugging, expect perf loss).
dist_utils.set_determinism(
world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
)
self.train_spec = get_train_spec(job_config.model.name)

# build tokenizer
self.tokenizer = tokenizer = (
self.train_spec.build_tokenizer_fn(job_config)
if self.train_spec.build_tokenizer_fn is not None
else None
)
Comment on lines +122 to +126
Contributor: This should be done downstream, right?
Contributor Author: discussed offline, probably makes more sense to keep for now

# build model (using meta init)
self.model_args = model_args = self.train_spec.model_args[
job_config.model.flavor
]
# set the model args from training job configs
model_args.update_from_config(job_config, tokenizer)

with torch.device("meta"):
model = self.train_spec.model_cls(model_args)

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

# calculate model size and flops per token
(
self.model_param_count,
self.num_flops_per_token,
) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

# move sharded model to CPU/GPU and initialize weights via DTensor
if job_config.checkpoint.create_seed_checkpoint:
init_device = "cpu"
buffer_device = None
Comment on lines +149 to +151
Contributor: I'm not sure I understand the purposes of this. Is the idea to do this in lieu of a separate checkpoint conversion script?
Contributor Author: This is more for debugging now. I can remove

elif job_config.training.enable_cpu_offload:
init_device = "cpu"
buffer_device = device_type
else:
init_device = device_type
buffer_device = None

self.loss_fn = self.train_spec.build_loss_fn(job_config)

# verify batch sizes
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
# This global batch size results in 1 gradient accumulation
# step.
global_batch_size = job_config.training.local_batch_size * dp_degree
assert global_batch_size > 0
assert (
global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
), (
f"global batch size must be multiple of local batch size times "
f"data-parallel degree ({global_batch_size} "
f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
)
self.global_batch_size = global_batch_size

# calculate gradient accumulation steps
self.gradient_accumulation_steps = global_batch_size // (
job_config.training.local_batch_size * dp_degree
)
assert self.gradient_accumulation_steps > 0
self.loss_fn = rescale_accumulated_loss(
self.loss_fn, self.gradient_accumulation_steps
)

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
if not self.train_spec.pipelining_fn:
raise RuntimeError(
f"Pipeline Parallel is enabled but {self.train_spec.name} "
f"does not support pipelining"
)

# apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
(
self.pp_schedule,
self.model_parts,
self.pp_has_first_stage,
self.pp_has_last_stage,
) = self.train_spec.pipelining_fn(
model,
parallel_dims,
job_config,
self.device,
model_args,
self.train_spec.parallelize_fn,
self.loss_fn,
)
# when PP is enabled, `model` obj is no longer used after this point,
# model_parts is used instead
del model

for m in self.model_parts:
m.to_empty(device=init_device)
with torch.no_grad():
m.init_weights(buffer_device=buffer_device)
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

model.to_empty(device=init_device)
with torch.no_grad():
model.init_weights(buffer_device=buffer_device)
model.train()

self.model_parts = [model]

if (
self.ft_manager.enabled
and job_config.fault_tolerance.semi_sync_method is None
):
self.ft_manager.set_all_reduce_hook(self.model_parts)

# build optimizer after applying parallelisms to the model
self.optimizers = self.train_spec.build_optimizers_fn(
self.model_parts, job_config, parallel_dims, self.ft_manager
)
self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
self.optimizers, job_config
)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
# where it issues a single all-reduce for all parameters at once for better performance
self.optimizers.register_step_post_hook(
lambda *args, **kwargs: model_converters.post_optimizer_hook(
self.model_parts
)
)

self.checkpointer = CheckpointManager(
dataloader=None,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states={"train_state": self},
job_config=job_config,
ft_manager=self.ft_manager,
)

loss_parallel_enabled = (
parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
)
self.train_context = dist_utils.get_train_context(
loss_parallel_enabled,
parallelism_config.enable_compiled_autograd,
)
self.maybe_enable_amp = dist_utils.maybe_enable_amp(
parallel_dims,
job_config.training.mixed_precision_param,
device_type,
)

def close(self) -> None:
if self.checkpointer:
self.checkpointer.close()