
Commit 1c5be28

refactor to make model folder structure consistent (#1298)
This PR does the following refactoring:

1. Remove `from_model_args` from `ModelProtocol`; add `__init__` and `init_weights` to `ModelProtocol`.
2. Make the structure consistent for each model folder:
   - model
     - args.py
     - model.py
   - infra
     - parallelize.py
     - pipeline.py (optional)

Will publish guidelines on adding new models soon.
1 parent 820504e commit 1c5be28
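At call sites, the change means a model is constructed directly from its args and initialized explicitly. A minimal before/after sketch pieced together from the diffs below (the standalone `init_weights()` call is illustrative; trainer code decides when to invoke it):

# Before: construct via a classmethod factory required by ModelProtocol.
model = model_cls.from_model_args(model_args)

# After: pass the args straight to the constructor, then materialize
# weights explicitly (e.g. once parameters are off the meta device).
model = model_cls(model_args)
model.init_weights()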

File tree: 24 files changed (+177, -182 lines)


README.md

Lines changed: 3 additions & 3 deletions
@@ -77,9 +77,9 @@ We report [performance](benchmarks/llama3_h100_202412_torchtitan.md) on up to 51

 You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
 * [torchtitan/train.py](torchtitan/train.py) - the main training loop and high-level setup code
-* [torchtitan/models/llama3/model.py](torchtitan/models/llama3/model.py) - the Llama 3.1 model definition
-* [torchtitan/models/llama3/parallelize_llama.py](torchtitan/models/llama3/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
-* [torchtitan/models/llama3/pipeline_llama.py](torchtitan/models/llama3/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
+* [torchtitan/models/llama3/model/model.py](torchtitan/models/llama3/model/model.py) - the Llama 3.1 model definition
+* [torchtitan/models/llama3/infra/parallelize.py](torchtitan/models/llama3/infra/parallelize.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
+* [torchtitan/models/llama3/infra/pipeline.py](torchtitan/models/llama3/infra/pipeline.py) - helpers for applying Pipeline Parallel to the model
 * [torchtitan/components/checkpoint.py](torchtitan/components/checkpoint.py) - utils for saving/loading distributed checkpoints
 * [torchtitan/components/quantization/float8.py](torchtitan/components/quantization/float8.py) - utils for applying Float8 techniques

scripts/estimate/estimation.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def estimate_memory(job_config: JobConfig):
         f"Building {train_spec.name} {job_config.model.flavor} with {model_args}"
     )
     with torch.device("meta"):
-        model = model_cls.from_model_args(model_args)
+        model = model_cls(model_args)

     # Build the collection of model converters. No-op if `model.converters` empty
     model_converters = build_model_converters(job_config, parallel_dims)
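The meta-device construction above works unchanged with a plain constructor, which is why dropping the `from_model_args` factory is a one-line change at each call site. A self-contained sketch of the meta-device pattern, using a hypothetical TinyModel rather than a real torchtitan class:

import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a real model class."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)


# Constructing under a meta device records parameter shapes without
# allocating real storage, so even huge models are cheap to build.
with torch.device("meta"):
    model = TinyModel(hidden=16)

# Materialize (uninitialized) parameters on a real device before use.
model = model.to_empty(device="cpu")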

scripts/generate/test_generate.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def test_generate(
     init_device = "meta" if world_size > 1 else device
     with torch.device(init_device):
         logger.info(f"Init model on init_device: {init_device}")
-        model = model_cls.from_model_args(model_args)
+        model = model_cls(model_args)

     world_mesh = None
     # Init distributed env

tests/unit_tests/test_train_spec.py

Lines changed: 11 additions & 5 deletions
@@ -26,10 +26,16 @@
 )


-class FakeModel(ModelProtocol):
-    @classmethod
-    def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
-        return nn.Linear(8, 8)
+class FakeModel(nn.Module, ModelProtocol):
+    def __init__(self, model_args: BaseModelArgs) -> None:
+        super().__init__()
+        self.linear = nn.Linear(8, 8)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
+
+    def init_weights(self, buffer_device: torch.device | None = None) -> None:
+        nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)


 def fake_build_optimizers(

@@ -117,7 +123,7 @@ def my_build_optimizer_fn(

     apply_to_train_specs(register_optimizer_hook_to_spec)

-    model = new_spec.cls.from_model_args(BaseModelArgs())
+    model = new_spec.cls(BaseModelArgs())
     model_parts = [model]
     optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
     assert optimizers.optimizers[0].__class__.__name__ == "Adam"
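As a quick illustration (not part of the test file), FakeModel above can now be exercised like any nn.Module under the new protocol:

model = FakeModel(BaseModelArgs())
model.init_weights()              # explicit, protocol-mandated initialization
out = model(torch.randn(2, 8))    # forward pass through the 8x8 linear layer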

torchtitan/experiments/flux/__init__.py

Lines changed: 6 additions & 5 deletions
@@ -9,13 +9,14 @@

 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
-from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
-from torchtitan.experiments.flux.loss import build_mse_loss
-from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
-from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

-from .model.model import FluxModel, FluxModelArgs
+from .dataset.flux_dataset import build_flux_dataloader
+from .infra.parallelize import parallelize_flux
+from .loss import build_mse_loss
+from .model.args import FluxModelArgs
+from .model.autoencoder import AutoEncoderParams
+from .model.model import FluxModel

 __all__ = [
     "FluxModelArgs",
torchtitan/experiments/flux/model/args.py (new file)

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+from torch import nn
+
+from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
+
+from torchtitan.protocols.train_spec import BaseModelArgs
+from torchtitan.tools.logging import logger
+
+
+@dataclass
+class FluxModelArgs(BaseModelArgs):
+    in_channels: int = 64
+    out_channels: int = 64
+    vec_in_dim: int = 768
+    context_in_dim: int = 512
+    hidden_size: int = 3072
+    mlp_ratio: float = 4.0
+    num_heads: int = 24
+    depth: int = 19
+    depth_single_blocks: int = 38
+    axes_dim: tuple = (16, 56, 56)
+    theta: int = 10_000
+    qkv_bias: bool = True
+    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
+
+    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
+        # TODO(jianiw): Add the number of flops for the autoencoder
+        nparams = sum(p.numel() for p in model.parameters())
+        logger.warning("FLUX model haven't implement get_nparams_and_flops() function")
+        return nparams, 1
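Because FluxModelArgs is a plain dataclass, model flavors can be expressed as field overrides. A hedged example (the values are illustrative, not the repo's actual flavor presets; for Flux-style rotary position encoding, sum(axes_dim) is expected to equal hidden_size // num_heads):

from torchtitan.experiments.flux.model.args import FluxModelArgs

# Defaults correspond to the full-size configuration in the diff above.
args = FluxModelArgs()

# A smaller debug variant via overrides: 768 // 6 == 128 == 16 + 56 + 56.
debug_args = FluxModelArgs(
    hidden_size=768,
    num_heads=6,
    depth=2,
    depth_single_blocks=4,
)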

torchtitan/experiments/flux/model/model.py

Lines changed: 3 additions & 43 deletions
@@ -4,13 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass, field
-
 import torch
-
 from torch import nn, Tensor

-from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
 from torchtitan.experiments.flux.model.layers import (
     DoubleStreamBlock,
     EmbedND,

@@ -20,31 +16,9 @@
     timestep_embedding,
 )

-from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
-from torchtitan.tools.logging import logger
-
-
-@dataclass
-class FluxModelArgs(BaseModelArgs):
-    in_channels: int = 64
-    out_channels: int = 64
-    vec_in_dim: int = 768
-    context_in_dim: int = 512
-    hidden_size: int = 3072
-    mlp_ratio: float = 4.0
-    num_heads: int = 24
-    depth: int = 19
-    depth_single_blocks: int = 38
-    axes_dim: tuple = (16, 56, 56)
-    theta: int = 10_000
-    qkv_bias: bool = True
-    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
-
-    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
-        # TODO(jianiw): Add the number of flops for the autoencoder
-        nparams = sum(p.numel() for p in model.parameters())
-        logger.warning("FLUX model haven't implement get_nparams_and_flops() function")
-        return nparams, 1
+from torchtitan.protocols.train_spec import ModelProtocol
+
+from .args import FluxModelArgs


 class FluxModel(nn.Module, ModelProtocol):

@@ -159,17 +133,3 @@ def forward(

         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
-
-    @classmethod
-    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
-        """
-        Initialize a Flux model from a FluxModelArgs object.
-
-        Args:
-            model_args (FluxModelArgs): Model configuration arguments.
-
-        Returns:
-            FluxModel: FluxModel model.
-
-        """
-        return cls(model_args)

torchtitan/experiments/flux/sampling.py

Lines changed: 5 additions & 5 deletions
@@ -16,18 +16,18 @@

 from torchtitan.components.tokenizer import Tokenizer
 from torchtitan.config_manager import JobConfig
-from torchtitan.experiments.flux.model.autoencoder import AutoEncoder
+from torchtitan.tools.logging import logger

-from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
-from torchtitan.experiments.flux.model.model import FluxModel
-from torchtitan.experiments.flux.utils import (
+from .model.autoencoder import AutoEncoder
+from .model.hf_embedder import FluxEmbedder
+from .model.model import FluxModel
+from .utils import (
     create_position_encoding_for_latents,
     generate_noise_latent,
     pack_latents,
     preprocess_data,
     unpack_latents,
 )
-from torchtitan.tools.logging import logger


 # ----------------------------------------

torchtitan/experiments/flux/tests/test_generate_image.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def test_generate_image(self):
         classifier_free_guidance_scale = 5.0

         # Contracting JobConfig
-        path = "torchtitan.experiments.flux.flux_argparser"
+        path = "torchtitan.experiments.flux.job_config"
         config_manager = ConfigManager()
         config = config_manager.parse_args(
             [

0 commit comments
