
Commit 6204cdf

refactor ParallelDims and CheckpointManager (#1384)
This PR does the following:

1. move `world_mesh` into `ParallelDims`, as they have a close relationship
2. move `enable_loss_parallel` out of the `ParallelDims` constructor
3. add a convenient property `seq_len_divisor` to `ParallelDims`
4. set `dataloader` and `ft_manager` as optional in `CheckpointManager`
5. some minor improvements on typing and code organization
1 parent 8908970 commit 6204cdf
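
For orientation, here is a minimal sketch of what `ParallelDims` looks like after the refactor. It is an illustration, not the actual implementation: the caching field, the `_build_mesh` helper, and the exact `seq_len_divisor` formula are assumptions based on the PR description and the diffs below.

from dataclasses import dataclass, field

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    ep: int
    world_size: int

    # cached mesh, built on first access (illustrative field name)
    _world_mesh: DeviceMesh | None = field(default=None, init=False, repr=False)

    @property
    def tp_enabled(self) -> bool:
        return self.tp > 1

    @property
    def world_mesh(self) -> DeviceMesh:
        # The mesh now lives on ParallelDims and is built lazily, so callers
        # stop threading a separately built world_mesh through every function.
        if self._world_mesh is None:
            self._world_mesh = self._build_mesh()
        return self._world_mesh

    @property
    def seq_len_divisor(self) -> int:
        # Convenience property from the PR description; the concrete rule here
        # (TP degree times 2x CP degree) is an assumption for illustration.
        return self.tp * (self.cp * 2)

    def _build_mesh(self) -> DeviceMesh:
        # Simplified construction: keep only the dims that are actually > 1
        # (assumes at least one parallel degree is greater than 1).
        names, sizes = [], []
        for name, size in (
            ("pp", self.pp),
            ("dp_replicate", self.dp_replicate),
            ("dp_shard", self.dp_shard),
            ("cp", self.cp),
            ("tp", self.tp),
        ):
            if size > 1:
                names.append(name)
                sizes.append(size)
        return init_device_mesh("cuda", tuple(sizes), mesh_dim_names=tuple(names))

The diffs below show the corresponding call-site changes: `build_mesh(device_type=...)` disappears and callers read `parallel_dims.world_mesh` instead.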

30 files changed: +220 additions, -179 deletions

scripts/estimate/estimation.py

Lines changed: 16 additions & 18 deletions
@@ -39,6 +39,12 @@ def estimate_memory(job_config: JobConfig):
     job_config.training.compile = False
     job_config.parallelism.enable_compiled_autograd = False

+    # init fake pg
+    store = FakeStore()
+    torch.distributed.init_process_group(
+        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
+    )
+
     parallelism_config = job_config.parallelism
     parallel_dims = ParallelDims(
         dp_shard=parallelism_config.data_parallel_shard_degree,
@@ -48,8 +54,9 @@ def estimate_memory(job_config: JobConfig):
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
-        enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
+    # ParallelDims.build_mesh has to happen outside of the FakeTensorMode
+    _ = parallel_dims.world_mesh

     # only FSDP and HSDP are supported
     if (
@@ -68,28 +75,21 @@ def estimate_memory(job_config: JobConfig):
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)

-    # init fake pg
-    store = FakeStore()
-    torch.distributed.init_process_group(
-        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
-    )
-
     train_spec = get_train_spec(job_config.model.name)

-    # build meshes
-    world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
     # build tokenizer
     tokenizer = train_spec.build_tokenizer_fn(job_config)

+    loss_parallel_enabled = (
+        parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+    )
     train_context = dist_utils.get_train_context(
-        parallel_dims.loss_parallel_enabled,
+        loss_parallel_enabled,
        job_config.parallelism.enable_compiled_autograd,
     )

     # build model (using meta init)
-    model_cls = train_spec.cls
-    model_args = train_spec.config[job_config.model.flavor]
+    model_args = train_spec.model_args[job_config.model.flavor]
     model_args.update_from_config(job_config, tokenizer)

     with (
@@ -101,14 +101,14 @@ def estimate_memory(job_config: JobConfig):
             f"Building {train_spec.name} {job_config.model.flavor} with {model_args}"
         )
         with torch.device("meta"):
-            model = model_cls(model_args)
+            model = train_spec.model_cls(model_args)

         # Build the collection of model converters. No-op if `model.converters` empty
         model_converters = build_model_converters(job_config, parallel_dims)
         model_converters.convert(model)

         # apply PT-D DP/TP parallelisms and activation checkpointing
-        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
+        train_spec.parallelize_fn(model, parallel_dims, job_config)

         model.to_empty(device="cuda")
         if not active_fake_mode():
@@ -117,9 +117,7 @@ def estimate_memory(job_config: JobConfig):

         # build optimizer after applying parallelisms to the model
         ft_manager = init_ft_manager(job_config)
-        optimizers = build_optimizers(
-            [model], job_config, parallel_dims, world_mesh, ft_manager
-        )
+        optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)
         lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
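
The reordering above exists because the device mesh is now created on first access of `parallel_dims.world_mesh`, and a (fake) process group must already exist at that point; the access is also forced before entering `FakeTensorMode`. A standalone sketch of that ordering, with values hardcoded for illustration (the real script reads `WORLD_SIZE` and `LOCAL_RANK` from the environment and its config):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.distributed.fake_pg import FakeStore

# 1) fake process group first, so a DeviceMesh can be built without real ranks
world_size = 8
torch.distributed.init_process_group(
    "fake", rank=0, world_size=world_size, store=FakeStore()
)

# 2) build the mesh eagerly (parallel_dims.world_mesh in the real script) ...
mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp_shard",))

# 3) ... and only then enter FakeTensorMode for meta-init and estimation
with FakeTensorMode():
    fake_weight = torch.randn(1024, 1024, device="cuda")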

scripts/generate/test_generate.py

Lines changed: 4 additions & 7 deletions
@@ -106,14 +106,13 @@ def test_generate(
     # Tokenizer setup
     tokenizer = train_spec.build_tokenizer_fn(config)

-    model_cls = train_spec.cls
-    model_args = train_spec.config[config.model.flavor]
+    model_args = train_spec.model_args[config.model.flavor]
     model_args.update_from_config(config, tokenizer)

     init_device = "meta" if world_size > 1 else device
     with torch.device(init_device):
         logger.info(f"Init model on init_device: {init_device}")
-        model = model_cls(model_args)
+        model = train_spec.model_cls(model_args)

     world_mesh = None
     # Init distributed env
@@ -127,14 +126,12 @@
             pp=1,
             ep=1,
             world_size=world_size,
-            enable_loss_parallel=False,
         )
-        # Build world mesh for parallelism
-        world_mesh = parallel_dims.build_mesh(device_type=device_type)
+        world_mesh = parallel_dims.world_mesh

         # apply_tp (with Sequence Parallel) on unevenly sharded
         # sequences would require https://github.com/pytorch/torchtitan/pull/686
-        apply_tp_minus_sp(model, world_mesh["tp"])
+        apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])

     dist_utils.set_determinism(world_mesh, device, seed, deterministic)
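
`parallel_dims.world_mesh["tp"]` uses standard DeviceMesh named-dimension indexing to pull out the 1-D tensor-parallel submesh. A tiny illustration, assuming 8 ranks launched under torchrun (the 2x4 shape and dim names are made up):

from torch.distributed.device_mesh import init_device_mesh

# 2-D mesh over 8 ranks; indexing by name returns the 1-D "tp" submesh,
# which is what apply_tp_minus_sp receives above
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh["tp"]
print(tp_mesh.size())  # 4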

tests/unit_tests/test_model_converter.py

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ def build_parallel_dims(job_config, world_size):
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
-        enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
     return parallel_dims

tests/unit_tests/test_train_spec.py

Lines changed: 22 additions & 11 deletions
@@ -9,12 +9,14 @@
 import pytest
 import torch
 import torch.nn as nn
+from torchtitan.components.ft import FTManager
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
 from torchtitan.protocols.train_spec import (
     apply_to_train_specs,
@@ -39,7 +41,10 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:


 def fake_build_optimizers(
-    model_parts: list[nn.Module], job_config: JobConfig
+    model_parts: list[nn.Module],
+    job_config: JobConfig,
+    parallel_dims: ParallelDims,
+    ft_manager: FTManager,
 ) -> OptimizersContainer:
     optimizer_kwargs = {
         "lr": 0.1,
@@ -57,11 +62,11 @@ def fake_build_optimizers(

 class TestTrainSpec:
     def test_register_train_spec(self):
-        fake_config = {"fake": None}
+        fake_config = {"fake": BaseModelArgs()}
         spec = TrainSpec(
             name="fake",
-            cls=FakeModel,
-            config=fake_config,
+            model_cls=FakeModel,
+            model_args=fake_config,
             parallelize_fn=parallelize_llama,
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=build_optimizers,
@@ -78,11 +83,11 @@ def test_register_train_spec(self):
         new_spec = get_train_spec("fake2")

     def test_optim_hook(self):
-        fake_config = {"fake": None}
+        fake_config = {"fake": BaseModelArgs()}
         spec = TrainSpec(
             name="fake2",
-            cls=FakeModel,
-            config=fake_config,
+            model_cls=FakeModel,
+            model_args=fake_config,
             parallelize_fn=parallelize_llama,
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=fake_build_optimizers,
@@ -111,21 +116,27 @@ def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
             original_build_optimizers_fn = spec.build_optimizers_fn

             def my_build_optimizer_fn(
-                model_parts: list[nn.Module], job_config: JobConfig
+                model_parts: list[nn.Module],
+                job_config: JobConfig,
+                parallel_dims: ParallelDims,
+                ft_manager: FTManager,
             ) -> OptimizersContainer:
-                optimizers = original_build_optimizers_fn(model_parts, job_config)
+                optimizers = original_build_optimizers_fn(
+                    model_parts, job_config, parallel_dims, ft_manager
+                )
                 optimizers.register_step_post_hook(
                     partial(my_hook, model_parts=model_parts)
                 )
                 return optimizers

             spec.build_optimizers_fn = my_build_optimizer_fn
+            return spec

         apply_to_train_specs(register_optimizer_hook_to_spec)

-        model = new_spec.cls(BaseModelArgs())
+        model = new_spec.model_cls(BaseModelArgs())
         model_parts = [model]
-        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
+        optimizers = new_spec.build_optimizers_fn(model_parts, None, None, None)
         assert optimizers.optimizers[0].__class__.__name__ == "Adam"
         batch = torch.randn(8, 8)
         model(batch).sum().backward()
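
The updated test pins down the optimizer-builder signature that a `TrainSpec` is now expected to provide: `(model_parts, job_config, parallel_dims, ft_manager)` instead of the old two-argument form. Spelled out as a hypothetical typing protocol (the protocol name is made up; the imports all appear in the diff above):

from typing import Protocol

import torch.nn as nn

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims


class BuildOptimizersFn(Protocol):
    # Hypothetical protocol; torchtitan may or may not define such an alias.
    def __call__(
        self,
        model_parts: list[nn.Module],
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        ft_manager: FTManager,
    ) -> OptimizersContainer: ...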

torchtitan/components/checkpoint.py

Lines changed: 6 additions & 4 deletions
@@ -26,8 +26,8 @@
 )
 from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.utils.data import DataLoader

+from torchtitan.components.dataloader import BaseDataLoader
 from torchtitan.components.ft import FTManager
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
@@ -180,17 +180,19 @@ class CheckpointManager:

     def __init__(
         self,
-        dataloader: DataLoader,
+        dataloader: BaseDataLoader | None,
         model_parts: list[nn.Module],
         optimizers: OptimizersContainer,
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
         job_config: JobConfig,
-        ft_manager: FTManager,
+        ft_manager: FTManager | None = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+        self.ft_manager = (
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
+        )

         if self.ft_manager:
             optimizers.init_cache_state_dict()
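
With `dataloader` typed as `BaseDataLoader | None` and `ft_manager` defaulting to `None`, lightweight callers can omit both. A hedged sketch of such a call site; the surrounding objects (model, optimizers, lr_schedulers, job_config, and the contents of `states`) are assumed to come from the usual training setup:

checkpointer = CheckpointManager(
    dataloader=None,        # nothing to checkpoint for the dataloader
    model_parts=[model],
    optimizers=optimizers,
    lr_schedulers=lr_schedulers,
    states={"train_state": train_state},   # illustrative key/value
    job_config=job_config,
    # ft_manager omitted: it defaults to None, and the constructor now guards
    # with `ft_manager and ft_manager.enabled` before touching it
)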

torchtitan/components/optimizer.py

Lines changed: 0 additions & 2 deletions
@@ -15,7 +15,6 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import DeviceMesh
 from torch.optim import Optimizer

 from torchtitan.components.ft import FTManager, has_torchft
@@ -244,7 +243,6 @@ def build_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
     parallel_dims: ParallelDims,
-    world_mesh: DeviceMesh,
     ft_manager: FTManager,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.

torchtitan/components/tokenizer.py

Lines changed: 1 addition & 3 deletions
@@ -7,17 +7,15 @@

 import json

-import logging
 import os
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union

 from tokenizers import AddedToken, Tokenizer
 from torchtitan.config_manager import JobConfig
+from torchtitan.tools.logging import logger
 from typing_extensions import override

-logger = logging.getLogger(__name__)
-

 class BaseTokenizer(ABC):
     # base tokenizer interface, for typing purpose mainly

torchtitan/components/validate.py

Lines changed: 8 additions & 8 deletions
@@ -50,14 +50,12 @@ def __init__(
         dp_rank: int,
         tokenizer: BaseTokenizer,
         parallel_dims: ParallelDims,
-        world_mesh: torch.distributed.DeviceMesh,
         loss_fn: LossFunction,
         validation_context: Generator[None, None, None],
         maybe_enable_amp: Generator[None, None, None],
     ):
         self.job_config = job_config
         self.parallel_dims = parallel_dims
-        self.world_mesh = world_mesh
         self.loss_fn = loss_fn
         self.validation_dataloader = build_hf_validation_dataloader(
             job_config=job_config,
@@ -78,6 +76,8 @@ def validate(
         model = model_parts[0]
         model.eval()

+        parallel_dims = self.parallel_dims
+
         accumulated_losses = []
         device_type = utils.device_type
         num_steps = 0
@@ -96,13 +96,13 @@ def validate(

             optional_context_parallel_ctx = (
                 dist_utils.create_context_parallel_ctx(
-                    cp_mesh=self.world_mesh["cp"],
+                    cp_mesh=parallel_dims.world_mesh["cp"],
                     cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                     cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                     cp_no_restore_buffers={inputs, labels},
                     cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                 )
-                if self.parallel_dims.cp_enabled
+                if parallel_dims.cp_enabled
                 else None
             )

@@ -119,8 +119,10 @@ def validate(
         # Compute average loss
         loss = torch.sum(torch.stack(accumulated_losses))
         loss /= num_steps
-        if self.parallel_dims.dp_cp_enabled:
-            global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
+        if parallel_dims.dp_cp_enabled:
+            global_avg_loss = dist_utils.dist_mean(
+                loss, parallel_dims.world_mesh["dp_cp"]
+            )
         else:
             global_avg_loss = loss

@@ -144,7 +146,6 @@ def build_validator(
     dp_rank: int,
     tokenizer: BaseTokenizer,
     parallel_dims: ParallelDims,
-    world_mesh: torch.distributed.DeviceMesh,
     loss_fn: LossFunction,
     validation_context: Generator[None, None, None],
     maybe_enable_amp: Generator[None, None, None],
@@ -156,7 +157,6 @@ def build_validator(
         dp_rank=dp_rank,
         tokenizer=tokenizer,
         parallel_dims=parallel_dims,
-        world_mesh=world_mesh,
         loss_fn=loss_fn,
         validation_context=validation_context,
         maybe_enable_amp=maybe_enable_amp,
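
The loss averaging above reduces the per-rank loss over the 1-D `dp_cp` submesh, which is functionally an all-reduce with mean. A hedged sketch of such a helper (the real `dist_utils.dist_mean` may be implemented differently):

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> torch.Tensor:
    # average a scalar tensor across all ranks of a 1-D (sub)mesh,
    # e.g. parallel_dims.world_mesh["dp_cp"]
    return funcol.all_reduce(x, reduceOp="avg", group=mesh)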
