Commit f8f2e88

committed
add the forge folder
1 parent db52d57 commit f8f2e88

6 files changed: +725 -0 lines changed

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
## `ForgeEngine`

The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications.

The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to
- Initialize an SPMD distributed training environment
- Construct and scale models via n-D parallelisms and meta-device initialization
- Provide necessary training components and utilities

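A minimal construction sketch follows; it assumes the package is importable as `forge` (the actual import path depends on where the folder lives in your tree) and that the `ForgeJobConfig` would normally be populated from a TOML file or CLI arguments rather than built with defaults:

```python
# Minimal sketch, not the folder's actual entry point: the import path and the
# way job_config gets populated depend on how the forge folder is integrated.
from forge import ForgeEngine, ForgeJobConfig

job_config = ForgeJobConfig()  # normally filled in from a TOML config / CLI args
engine = ForgeEngine(job_config)

# After construction, the building blocks of a training loop are ready:
# engine.model_parts, engine.loss_fn, engine.optimizers, engine.lr_schedulers,
# engine.checkpointer, engine.device, engine.gradient_accumulation_steps, ...
engine.close()
```
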
**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.

Additionally, the folder provides a train spec registration method, [`register_train_spec`](train_spec.py), that lets users register models and training components beyond the core set available in torchtitan, enabling customization for specific training requirements.

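A sketch of how such a registration might look is below; the field names mirror the attributes `ForgeEngine` reads from its train spec, but the exact `ForgeTrainSpec` constructor and the user-side helpers are illustrative, not part of this folder:

```python
# Illustrative only: MyTransformer, MyModelArgs, and the parallelize/build_*
# helpers stand in for user-defined components; check train_spec.py for the
# real ForgeTrainSpec fields.
from forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",
        model_cls=MyTransformer,
        model_args={"debugmodel": MyModelArgs(dim=256, n_layers=4)},
        parallelize_fn=parallelize_my_model,
        pipelining_fn=None,
        build_optimizers_fn=build_my_optimizers,
        build_lr_schedulers_fn=build_my_lr_schedulers,
        build_loss_fn=build_my_loss_fn,
        build_tokenizer_fn=build_my_tokenizer,
    )
)
```
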
The [example_train.py](./example_train.py) demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py) (except for quantization or fault tolerance).
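
For orientation, the rough shape of such a loop over the engine's components might look like the following sketch (assuming no pipeline parallelism; `build_dataloader`, `num_steps`, and `job_config` are placeholders, and example_train.py remains the reference implementation):

```python
# Rough outline only; data loading and step count are placeholders.
engine = ForgeEngine(job_config)
data_iter = iter(
    build_dataloader(job_config, engine.tokenizer, engine.dp_rank, engine.dp_degree)
)

for step in range(1, num_steps + 1):
    engine.gc_handler.run(step)  # periodic, synchronized garbage collection
    engine.optimizers.zero_grad()
    # accumulate gradients over micro-batches; loss_fn is already rescaled by
    # gradient_accumulation_steps inside ForgeEngine
    for _ in range(engine.gradient_accumulation_steps):
        inputs, labels = next(data_iter)
        inputs, labels = inputs.to(engine.device), labels.to(engine.device)
        with engine.maybe_enable_amp:
            loss = engine.loss_fn(engine.model_parts[0](inputs), labels)
        loss.backward()
    engine.optimizers.step()
    engine.lr_schedulers.step()

engine.checkpointer.save(curr_step=num_steps, last_step=True)
engine.close()
```
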
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .engine import ForgeEngine
from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, register_train_spec

__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"]
Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils

from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, get_train_spec


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
    # core configs
    job_config: ForgeJobConfig
    parallel_dims: ParallelDims
    train_spec: ForgeTrainSpec

    # swappable training components in ForgeTrainSpec
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    # non-swappable training components
    checkpointer: CheckpointManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    gradient_accumulation_steps: int
    train_context: Generator[None, None, None]
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # Fields in ForgeEngine which are not in original Trainer
    # for dataloading
    tokenizer: train_spec_module.BaseTokenizer | None
    dp_degree: int
    dp_rank: int
    # for logging
    model_args: BaseModelArgs
    num_flops_per_token: float
    model_param_count: int
    global_batch_size: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: ForgeJobConfig):
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(job_config)
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0
        self.dp_degree, self.dp_rank = dp_degree, dp_rank

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = get_train_spec(job_config.model.name)

        # build tokenizer
        self.tokenizer = tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        # build model (using meta init)
        self.model_args = model_args = self.train_spec.model_args[
            job_config.model.flavor
        ]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # calculate model size and flops per token
        (
            self.model_param_count,
            self.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )
        self.global_batch_size = global_batch_size

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, parallel_dims
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )

        self.checkpointer = CheckpointManager(
            dataloader=None,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()
