
Commit 1ed4f48

add the forge folder
1 parent cbccb38 commit 1ed4f48

5 files changed, +728 -0 lines changed
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
## `ForgeEngine`

The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications.

The [`ForgeEngine`](engine.py) helps to:
- Initialize an SPMD distributed training environment
- Construct and scale models via n-D parallelisms and meta-device initialization
- Provide necessary training components and utilities

**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.
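
As a rough illustration (not part of this commit), constructing the engine from a populated torchtitan `JobConfig` might look like the sketch below; it assumes the process was launched the usual way (e.g. via `torchrun`), since the constructor reads `LOCAL_RANK`/`WORLD_SIZE` and initializes the distributed environment itself:

```python
# Minimal sketch, assuming `job_config` is a fully populated torchtitan JobConfig.
from torchtitan.experiments.forge import ForgeEngine

engine = ForgeEngine(job_config)

# After construction, the engine exposes ready-to-use training components:
model_parts = engine.model_parts      # parallelized, device-initialized model part(s)
optimizers = engine.optimizers        # OptimizersContainer built by the train spec
lr_schedulers = engine.lr_schedulers  # matching LR schedulers
checkpointer = engine.checkpointer    # CheckpointManager wired to the components above
```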
Additionally, the folder provides a train spec registration method [`register_train_spec`](train_spec.py) that allows users to extend beyond the core set of models and training components available in torchtitan, enabling greater flexibility and customization for specific training requirements.
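
For illustration only, a hypothetical registration could look like the following sketch; the exact `ForgeTrainSpec` fields are defined in [train_spec.py](train_spec.py), the names below mirror the attributes `ForgeEngine` reads from its train spec, and every `my_*` / `My*` symbol is a placeholder the user would supply:

```python
# Hedged sketch of extending the registry with a custom model; all `my_*` / `My*`
# symbols are hypothetical user-provided code, not part of torchtitan.
from torchtitan.experiments.forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",                            # matched against job_config.model.name
        model_cls=MyModel,                          # nn.Module constructible on the meta device
        model_args={"flavor_a": MyModelArgs()},     # keyed by job_config.model.flavor
        parallelize_fn=my_parallelize_fn,           # applies TP/FSDP/compile, etc.
        pipelining_fn=None,                         # or a pipelining function if PP is supported
        build_optimizers_fn=my_build_optimizers_fn,
        build_lr_schedulers_fn=my_build_lr_schedulers_fn,
        build_loss_fn=my_build_loss_fn,
        build_tokenizer_fn=my_build_tokenizer_fn,   # may be None if no tokenizer is needed
    )
)
```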
The [`example_train.py`](./example_train.py) demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py).
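
For orientation, a bare (non-pipeline-parallel) training step built on the engine's components could look roughly like the sketch below; `inputs` and `labels` stand in for batches from whatever dataloader the downstream trainer provides:

```python
# Hedged sketch of a single training step when pipeline parallelism is disabled;
# see example_train.py for the complete loop (data loading, gradient accumulation,
# clipping, checkpointing, metrics, etc.).
model = engine.model_parts[0]            # a single model part when PP is off
engine.optimizers.zero_grad()
pred = model(inputs)                     # `inputs`/`labels` are hypothetical batch tensors
loss = engine.loss_fn(pred, labels)      # loss_fn is already rescaled for grad accumulation
loss.backward()
engine.optimizers.step()
engine.lr_schedulers.step()
```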
TODO: define a minimal version of `ForgeJobConfig`.
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .engine import ForgeEngine
from .train_spec import ForgeTrainSpec, register_train_spec

__all__ = ["ForgeEngine", "ForgeTrainSpec", "register_train_spec"]
Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.forge.train_spec import ForgeTrainSpec, get_train_spec
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
    # core configs
    job_config: JobConfig
    parallel_dims: ParallelDims
    train_spec: ForgeTrainSpec

    # swappable training components in ForgeTrainSpec
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    # non-swappable training components
    checkpointer: CheckpointManager
    ft_manager: ft.FTManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    gradient_accumulation_steps: int
    train_context: Generator[None, None, None]
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # Fields in ForgeEngine which are not in original Trainer
    # for dataloading
    tokenizer: train_spec_module.BaseTokenizer | None
    dp_degree: int
    dp_rank: int
    # for logging
    model_args: BaseModelArgs
    num_flops_per_token: float
    model_param_count: int
    global_batch_size: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: JobConfig):
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(job_config)
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        self.ft_manager = ft.init_ft_manager(job_config)
        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
        # dataloader must be changed.
        if self.ft_manager.enabled:
            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        self.dp_degree, self.dp_rank = dp_degree, dp_rank

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = get_train_spec(job_config.model.name)

        # build tokenizer
        self.tokenizer = tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        # build model (using meta init)
        self.model_args = model_args = self.train_spec.model_args[
            job_config.model.flavor
        ]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # calculate model size and flops per token
        (
            self.model_param_count,
            self.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )
        self.global_batch_size = global_batch_size

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        if (
            self.ft_manager.enabled
            and job_config.fault_tolerance.semi_sync_method is None
        ):
            self.ft_manager.set_all_reduce_hook(self.model_parts)

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, parallel_dims, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )
        # Post optimizer step model converters hook.
        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )

        self.checkpointer = CheckpointManager(
            dataloader=None,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=self.ft_manager,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()
