Commit 5ad8ec0

add the forge folder
1 parent cbccb38 commit 5ad8ec0

6 files changed: +773, -0 lines changed

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
## `ForgeEngine`

The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications.

The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to
- Initialize an SPMD distributed training environment
- Construct and scale models via n-D parallelisms and meta-device initialization
- Provide necessary training components and utilities
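For illustration, constructing the engine might look like the following minimal sketch (the import path and the way the `ForgeJobConfig` gets populated, e.g. parsed from a TOML config, are assumptions and not part of this commit):

```python
# Minimal sketch, assuming the package is importable as `forge` and that a
# ForgeJobConfig has already been populated elsewhere (e.g. from a TOML file).
from forge import ForgeEngine, ForgeJobConfig


def main(job_config: ForgeJobConfig) -> None:
    # Sets up the SPMD environment and builds the parallelized model,
    # optimizers, LR schedulers, and checkpointer.
    engine = ForgeEngine(job_config)
    try:
        ...  # trainer / generator / parameter-server logic built on the engine
    finally:
        engine.close()  # release checkpointer resources
```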

**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.

Additionally, the folder provides a train spec registration method, [`register_train_spec`](train_spec.py), which lets users extend beyond the core set of models and training components available in torchtitan for specific training requirements.
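As a rough illustration, registering a custom spec might look like the sketch below. All names are hypothetical; the fields mirror the attributes `ForgeEngine` reads from its train spec, but the exact `ForgeTrainSpec` signature should be checked against [train_spec.py](train_spec.py):

```python
# Hypothetical sketch of registering a custom train spec. MyModel, MyModelArgs,
# and the build_* / parallelize_* helpers are placeholders, not torchtitan APIs.
from forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",
        model_cls=MyModel,                  # constructible on the meta device
        model_args={"8B": MyModelArgs()},   # keyed by job_config.model.flavor
        parallelize_fn=parallelize_my_model,
        pipelining_fn=None,                 # set if pipeline parallelism is supported
        build_optimizers_fn=build_my_optimizers,
        build_lr_schedulers_fn=build_my_lr_schedulers,
        build_loss_fn=build_my_loss_fn,
        build_tokenizer_fn=build_my_tokenizer,
    )
)
```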

The [example_train.py](./example_train.py) demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py).
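In broad strokes, a pretraining step built on the engine's components might look like this simplified sketch (not the actual example_train.py; dataloading, pipeline-parallel scheduling, metrics, and checkpointing are elided, and `dataloader`/`num_steps` are placeholders):

```python
# Simplified sketch of a training loop using ForgeEngine components.
engine = ForgeEngine(job_config)
model = engine.model_parts[0]  # a single part when pipeline parallelism is off
data_iter = iter(dataloader)   # placeholder: the engine does not build a dataloader

for step in range(1, num_steps + 1):
    engine.optimizers.zero_grad()
    # One optimizer step spans `gradient_accumulation_steps` micro-batches;
    # the engine already rescales the loss via rescale_accumulated_loss.
    for _ in range(engine.gradient_accumulation_steps):
        inputs, labels = next(data_iter)
        loss = engine.loss_fn(model(inputs.to(engine.device)), labels.to(engine.device))
        loss.backward()
    engine.optimizers.step()
    engine.lr_schedulers.step()
```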
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .engine import ForgeEngine
from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, register_train_spec

__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"]
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTManager, init_ft_manager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils

from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, get_train_spec


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
    # core configs
    job_config: ForgeJobConfig
    parallel_dims: ParallelDims
    train_spec: ForgeTrainSpec

    # swappable training components in ForgeTrainSpec
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    # non-swappable training components
    checkpointer: CheckpointManager
    ft_manager: FTManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    gradient_accumulation_steps: int
    train_context: Generator[None, None, None]
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # Fields in ForgeEngine which are not in original Trainer
    # for dataloading
    tokenizer: train_spec_module.BaseTokenizer | None
    dp_degree: int
    dp_rank: int
    # for logging
    model_args: BaseModelArgs
    num_flops_per_token: float
    model_param_count: int
    global_batch_size: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: ForgeJobConfig):
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(job_config)
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        self.ft_manager = init_ft_manager(job_config)
        # If TorchFT is enabled, the dp_rank and dp_degree, which are used by the
        # dataloader, must be changed.
        if self.ft_manager.enabled:
            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        self.dp_degree, self.dp_rank = dp_degree, dp_rank

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = get_train_spec(job_config.model.name)

        # build tokenizer
        self.tokenizer = tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        # build model (using meta init)
        self.model_args = model_args = self.train_spec.model_args[
            job_config.model.flavor
        ]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # Build the collection of model converters. No-op if `model.converters` is empty.
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # calculate model size and flops per token
        (
            self.model_param_count,
            self.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )
        self.global_batch_size = global_batch_size

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, the `model` obj is no longer used after this point;
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        if (
            self.ft_manager.enabled
            and job_config.fault_tolerance.semi_sync_method is None
        ):
            self.ft_manager.set_all_reduce_hook(self.model_parts)

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, parallel_dims, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )
        # Post-optimizer-step model converters hook,
        # e.g. calculate float8 dynamic amax/scale for all parameters for FSDP2,
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )

        self.checkpointer = CheckpointManager(
            dataloader=None,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=self.ft_manager,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()
