
refactor ParallelDims and CheckpointManager #1384


Merged: 1 commit, Jul 14, 2025
34 changes: 16 additions & 18 deletions scripts/estimate/estimation.py
@@ -39,6 +39,12 @@ def estimate_memory(job_config: JobConfig):
     job_config.training.compile = False
     job_config.parallelism.enable_compiled_autograd = False
 
+    # init fake pg
+    store = FakeStore()
+    torch.distributed.init_process_group(
+        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
+    )
+
     parallelism_config = job_config.parallelism
     parallel_dims = ParallelDims(
         dp_shard=parallelism_config.data_parallel_shard_degree,
@@ -48,8 +54,9 @@ def estimate_memory(job_config: JobConfig):
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
-        enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
+    # ParallelDims.build_mesh has to happen outside of the FakeTensorMode
+    _ = parallel_dims.world_mesh
 
     # only FSDP and HSDP are supported
     if (
@@ -68,28 +75,21 @@ def estimate_memory(job_config: JobConfig):
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
 
-    # init fake pg
-    store = FakeStore()
-    torch.distributed.init_process_group(
-        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
-    )
-
     train_spec = get_train_spec(job_config.model.name)
 
-    # build meshes
-    world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
     # build tokenizer
     tokenizer = train_spec.build_tokenizer_fn(job_config)
 
+    loss_parallel_enabled = (
+        parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+    )
     train_context = dist_utils.get_train_context(
-        parallel_dims.loss_parallel_enabled,
+        loss_parallel_enabled,
        job_config.parallelism.enable_compiled_autograd,
     )
 
     # build model (using meta init)
-    model_cls = train_spec.cls
-    model_args = train_spec.config[job_config.model.flavor]
+    model_args = train_spec.model_args[job_config.model.flavor]
     model_args.update_from_config(job_config, tokenizer)
 
     with (
@@ -101,14 +101,14 @@ def estimate_memory(job_config: JobConfig):
             f"Building {train_spec.name} {job_config.model.flavor} with {model_args}"
         )
         with torch.device("meta"):
-            model = model_cls(model_args)
+            model = train_spec.model_cls(model_args)
 
         # Build the collection of model converters. No-op if `model.converters` empty
         model_converters = build_model_converters(job_config, parallel_dims)
         model_converters.convert(model)
 
         # apply PT-D DP/TP parallelisms and activation checkpointing
-        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
+        train_spec.parallelize_fn(model, parallel_dims, job_config)
 
         model.to_empty(device="cuda")
         if not active_fake_mode():
@@ -117,9 +117,7 @@ def estimate_memory(job_config: JobConfig):
 
         # build optimizer after applying parallelisms to the model
         ft_manager = init_ft_manager(job_config)
-        optimizers = build_optimizers(
-            [model], job_config, parallel_dims, world_mesh, ft_manager
-        )
+        optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)
         lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2

11 changes: 4 additions & 7 deletions scripts/generate/test_generate.py
@@ -106,14 +106,13 @@ def test_generate(
     # Tokenizer setup
     tokenizer = train_spec.build_tokenizer_fn(config)
 
-    model_cls = train_spec.cls
-    model_args = train_spec.config[config.model.flavor]
+    model_args = train_spec.model_args[config.model.flavor]
     model_args.update_from_config(config, tokenizer)
 
     init_device = "meta" if world_size > 1 else device
     with torch.device(init_device):
         logger.info(f"Init model on init_device: {init_device}")
-        model = model_cls(model_args)
+        model = train_spec.model_cls(model_args)
 
     world_mesh = None
     # Init distributed env
@@ -127,14 +126,12 @@ def test_generate(
             pp=1,
             ep=1,
             world_size=world_size,
-            enable_loss_parallel=False,
         )
-        # Build world mesh for parallelism
-        world_mesh = parallel_dims.build_mesh(device_type=device_type)
+        world_mesh = parallel_dims.world_mesh
Contributor:

nit: We don't need this line here to "build world mesh" explicitly, right? In line 134, parallel_dims.world_mesh["tp"] will call build_mesh internally.

Contributor (author):

It will be used in line 136 to set determinism.
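
As context for the lazy build this thread refers to, here is a minimal sketch of the caching-property pattern (an assumption for illustration, not the actual ParallelDims code): the mesh is constructed on first access to `.world_mesh` and reused afterwards, which is also why estimation.py touches the property once (`_ = parallel_dims.world_mesh`) to force the build at a chosen point.

```python
# Illustrative sketch only; the class and helper names are hypothetical.
class LazyMeshOwner:
    """Builds an expensive mesh object once, on first access to .world_mesh."""

    def __init__(self) -> None:
        self._world_mesh = None

    @property
    def world_mesh(self):
        if self._world_mesh is None:
            # Expensive construction happens exactly once, at first access.
            self._world_mesh = self._build_mesh()
        return self._world_mesh

    def _build_mesh(self):
        # Stand-in for DeviceMesh construction.
        return {"tp": "tp-submesh", "dp_cp": "dp_cp-submesh"}


owner = LazyMeshOwner()
_ = owner.world_mesh           # forces the build at a chosen point
assert owner.world_mesh is _   # later accesses reuse the cached mesh
```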


        # apply_tp (with Sequence Parallel) on unevenly sharded
        # sequences would require https://github.com/pytorch/torchtitan/pull/686
-        apply_tp_minus_sp(model, world_mesh["tp"])
+        apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])
 
     dist_utils.set_determinism(world_mesh, device, seed, deterministic)

1 change: 0 additions & 1 deletion tests/unit_tests/test_model_converter.py
@@ -23,7 +23,6 @@ def build_parallel_dims(job_config, world_size):
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
-        enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
     return parallel_dims

33 changes: 22 additions & 11 deletions tests/unit_tests/test_train_spec.py
@@ -9,12 +9,14 @@
 import pytest
 import torch
 import torch.nn as nn
+from torchtitan.components.ft import FTManager
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
 from torchtitan.protocols.train_spec import (
     apply_to_train_specs,
@@ -39,7 +41,10 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
 
 
 def fake_build_optimizers(
-    model_parts: list[nn.Module], job_config: JobConfig
+    model_parts: list[nn.Module],
+    job_config: JobConfig,
+    parallel_dims: ParallelDims,
+    ft_manager: FTManager,
 ) -> OptimizersContainer:
     optimizer_kwargs = {
         "lr": 0.1,
@@ -57,11 +62,11 @@
 
 class TestTrainSpec:
     def test_register_train_spec(self):
-        fake_config = {"fake": None}
+        fake_config = {"fake": BaseModelArgs()}
         spec = TrainSpec(
             name="fake",
-            cls=FakeModel,
-            config=fake_config,
+            model_cls=FakeModel,
+            model_args=fake_config,
             parallelize_fn=parallelize_llama,
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=build_optimizers,
@@ -78,11 +83,11 @@ def test_register_train_spec(self):
         new_spec = get_train_spec("fake2")
 
     def test_optim_hook(self):
-        fake_config = {"fake": None}
+        fake_config = {"fake": BaseModelArgs()}
         spec = TrainSpec(
             name="fake2",
-            cls=FakeModel,
-            config=fake_config,
+            model_cls=FakeModel,
+            model_args=fake_config,
             parallelize_fn=parallelize_llama,
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=fake_build_optimizers,
@@ -111,21 +116,27 @@ def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
             original_build_optimizers_fn = spec.build_optimizers_fn
 
             def my_build_optimizer_fn(
-                model_parts: list[nn.Module], job_config: JobConfig
+                model_parts: list[nn.Module],
+                job_config: JobConfig,
+                parallel_dims: ParallelDims,
+                ft_manager: FTManager,
             ) -> OptimizersContainer:
-                optimizers = original_build_optimizers_fn(model_parts, job_config)
+                optimizers = original_build_optimizers_fn(
+                    model_parts, job_config, parallel_dims, ft_manager
+                )
                 optimizers.register_step_post_hook(
                     partial(my_hook, model_parts=model_parts)
                 )
                 return optimizers
 
             spec.build_optimizers_fn = my_build_optimizer_fn
             return spec
 
         apply_to_train_specs(register_optimizer_hook_to_spec)
 
-        model = new_spec.cls(BaseModelArgs())
+        model = new_spec.model_cls(BaseModelArgs())
         model_parts = [model]
-        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
+        optimizers = new_spec.build_optimizers_fn(model_parts, None, None, None)
         assert optimizers.optimizers[0].__class__.__name__ == "Adam"
         batch = torch.randn(8, 8)
         model(batch).sum().backward()

10 changes: 6 additions & 4 deletions torchtitan/components/checkpoint.py
@@ -26,8 +26,8 @@
 )
 from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.utils.data import DataLoader
 
+from torchtitan.components.dataloader import BaseDataLoader
 from torchtitan.components.ft import FTManager
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
@@ -180,17 +180,19 @@ class CheckpointManager:
 
     def __init__(
         self,
-        dataloader: DataLoader,
+        dataloader: BaseDataLoader | None,
Contributor (author):

@ebsmothers I kept this field required and its value optional -- the code still works. I didn't make it completely optional with a None default because that would require more if/else in this file.

I think it won't look too bad when I specify dataloader=None in forge_engine.py. Let me know if that's OK with you.

Contributor:

Thanks for the heads up, I think this is a reasonable compromise.
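
To illustrate the compromise, here is a hypothetical caller sketch (not the actual forge_engine.py code; the helper name and argument objects are assumptions): with the new signature a component that has no dataloader still passes dataloader=None explicitly, while ft_manager can simply be omitted since it now defaults to None.

```python
# Hypothetical usage sketch; argument objects are assumed to be built elsewhere.
from torchtitan.components.checkpoint import CheckpointManager


def build_checkpointer(model, optimizers, lr_schedulers, train_state, job_config):
    return CheckpointManager(
        dataloader=None,  # the field stays required, but the value may be None
        model_parts=[model],
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        # ft_manager is omitted: it now defaults to None
    )
```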

         model_parts: list[nn.Module],
         optimizers: OptimizersContainer,
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
         job_config: JobConfig,
-        ft_manager: FTManager,
+        ft_manager: FTManager | None = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+        self.ft_manager = (
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
+        )
 
         if self.ft_manager:
             optimizers.init_cache_state_dict()

2 changes: 0 additions & 2 deletions torchtitan/components/optimizer.py
@@ -15,7 +15,6 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import DeviceMesh
 from torch.optim import Optimizer
 
 from torchtitan.components.ft import FTManager, has_torchft
@@ -244,7 +243,6 @@ def build_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
     parallel_dims: ParallelDims,
-    world_mesh: DeviceMesh,
     ft_manager: FTManager,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.

4 changes: 1 addition & 3 deletions torchtitan/components/tokenizer.py
@@ -7,17 +7,15 @@
 
 import json
 
-import logging
 import os
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union
 
 from tokenizers import AddedToken, Tokenizer
 from torchtitan.config_manager import JobConfig
+from torchtitan.tools.logging import logger
 from typing_extensions import override
 
-logger = logging.getLogger(__name__)
-
 
 class BaseTokenizer(ABC):
     # base tokenizer interface, for typing purpose mainly

16 changes: 8 additions & 8 deletions torchtitan/components/validate.py
@@ -50,14 +50,12 @@ def __init__(
         dp_rank: int,
         tokenizer: BaseTokenizer,
         parallel_dims: ParallelDims,
-        world_mesh: torch.distributed.DeviceMesh,
         loss_fn: LossFunction,
         validation_context: Generator[None, None, None],
         maybe_enable_amp: Generator[None, None, None],
     ):
         self.job_config = job_config
         self.parallel_dims = parallel_dims
-        self.world_mesh = world_mesh
         self.loss_fn = loss_fn
         self.validation_dataloader = build_hf_validation_dataloader(
             job_config=job_config,
@@ -78,6 +76,8 @@ def validate(
         model = model_parts[0]
         model.eval()
 
+        parallel_dims = self.parallel_dims
+
         accumulated_losses = []
         device_type = utils.device_type
         num_steps = 0
@@ -96,13 +96,13 @@
 
             optional_context_parallel_ctx = (
                 dist_utils.create_context_parallel_ctx(
-                    cp_mesh=self.world_mesh["cp"],
+                    cp_mesh=parallel_dims.world_mesh["cp"],
                     cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                     cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                     cp_no_restore_buffers={inputs, labels},
                     cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                 )
-                if self.parallel_dims.cp_enabled
+                if parallel_dims.cp_enabled
                 else None
             )
 
@@ -119,8 +119,10 @@
         # Compute average loss
         loss = torch.sum(torch.stack(accumulated_losses))
         loss /= num_steps
-        if self.parallel_dims.dp_cp_enabled:
-            global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
+        if parallel_dims.dp_cp_enabled:
+            global_avg_loss = dist_utils.dist_mean(
+                loss, parallel_dims.world_mesh["dp_cp"]
+            )
         else:
             global_avg_loss = loss
 
@@ -144,7 +146,6 @@ def build_validator(
     dp_rank: int,
     tokenizer: BaseTokenizer,
     parallel_dims: ParallelDims,
-    world_mesh: torch.distributed.DeviceMesh,
     loss_fn: LossFunction,
     validation_context: Generator[None, None, None],
     maybe_enable_amp: Generator[None, None, None],
@@ -156,7 +157,6 @@ def build_validator(
         dp_rank=dp_rank,
         tokenizer=tokenizer,
         parallel_dims=parallel_dims,
-        world_mesh=world_mesh,
         loss_fn=loss_fn,
         validation_context=validation_context,
         maybe_enable_amp=maybe_enable_amp,