Commit 8518306

Migrating flux checkpointing to hf api (#1377)
Migrates the checkpointing of the flux dataset to use Hugging Face's faster state_dict / load_state_dict API for IterableDatasets.
1 parent f0ce21b commit 8518306

8 files changed: +75, -44 lines

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 22 additions & 12 deletions
@@ -194,13 +194,13 @@ def __init__(
         self._all_samples: list[dict[str, Any]] = []

     def _get_data_iter(self):
-        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
-            return iter([])
+        if isinstance(self._data, Dataset):
+            if self._sample_idx == len(self._data):
+                return iter([])
+            else:
+                return iter(self._data.skip(self._sample_idx))

-        it = iter(self._data)
-        for _ in range(self._sample_idx):
-            next(it)
-        return it
+        return iter(self._data)

     def __iter__(self):
         dataset_iterator = self._get_data_iter()
@@ -223,8 +223,13 @@ def __iter__(self):
             else:
                 # Reset offset for the next iteration if infinite
                 self._sample_idx = 0
-                logger.info(f"Dataset {self.dataset_name} is being re-looped.")
+                logger.warning(f"Dataset {self.dataset_name} is being re-looped.")
                 dataset_iterator = self._get_data_iter()
+                if not isinstance(self._data, Dataset):
+                    if hasattr(self._data, "set_epoch") and hasattr(
+                        self._data, "epoch"
+                    ):
+                        self._data.set_epoch(self._data.epoch + 1)
                 continue

             # Use the dataset-specific preprocessor
@@ -244,7 +249,7 @@ def __iter__(self):

             # Classifier-free guidance: Replace some of the strings with empty strings.
             # Distinct random seed is initialized at the beginning of training for each FSDP rank.
-            dropout_prob = self.job_config.training.classifer_free_guidance_prob
+            dropout_prob = self.job_config.training.classifier_free_guidance_prob
             if dropout_prob > 0.0:
                 if torch.rand(1).item() < dropout_prob:
                     sample_dict["t5_tokens"] = self._t5_empty_token
@@ -258,12 +263,17 @@
             yield sample_dict, labels

     def load_state_dict(self, state_dict):
-        self._sample_idx = state_dict["sample_idx"]
+        if isinstance(self._data, Dataset):
+            self._sample_idx = state_dict["sample_idx"]
+        else:
+            assert "data" in state_dict
+            self._data.load_state_dict(state_dict["data"])

     def state_dict(self):
-        return {
-            "sample_idx": self._sample_idx,
-        }
+        if isinstance(self._data, Dataset):
+            return {"sample_idx": self._sample_idx}
+        else:
+            return {"data": self._data.state_dict()}


 def build_flux_dataloader(
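For reference, the state_dict() / load_state_dict() calls that the iterable branch now delegates to come from the Hugging Face datasets library (available on IterableDataset since roughly datasets v2.18); this is the "faster api" the commit message refers to, since resuming records a shard/offset instead of replaying samples one by one. A minimal sketch with a toy in-memory dataset (not the flux cc12m data):

```python
# Minimal sketch of the Hugging Face `datasets` checkpointing API the diff
# switches to.  Assumes datasets>=2.18; the toy data is illustrative only.
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(8))}).to_iterable_dataset(num_shards=4)

state = None
for idx, example in enumerate(ds):
    if idx == 2:
        state = ds.state_dict()  # cheap: records shard index + offset, no sample replay
        break

# Later (e.g. after a job restart) the same state can be loaded back and
# iteration continues from the recorded position, which is exactly what
# FluxDataset.state_dict() / load_state_dict() now wrap for iterable data.
ds.load_state_dict(state)
print([example["x"] for example in ds])  # remaining samples after the checkpoint
```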

torchtitan/experiments/flux/job_config.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@

 @dataclass
 class Training:
-    classifer_free_guidance_prob: float = 0.0
+    classifier_free_guidance_prob: float = 0.0
     """Classifier-free guidance with probability `p` to dropout each text encoding independently.
     If `n` text encoders are used, the unconditional model is trained in `p ^ n` of all steps.
     For example, if `n = 2` and `p = 0.447`, the unconditional model is trained in 20% of all steps"""
@@ -37,7 +37,7 @@ class Encoder:

 @dataclass
 class Eval:
-    enable_classifer_free_guidance: bool = False
+    enable_classifier_free_guidance: bool = False
     """Whether to use classifier-free guidance during sampling"""
     classifier_free_guidance_scale: float = 5.0
     """Classifier-free guidance scale when sampling"""

torchtitan/experiments/flux/sampling.py

Lines changed: 8 additions & 8 deletions
@@ -93,7 +93,7 @@ def generate_image(
     img_height = 16 * (job_config.training.img_size // 16)
     img_width = 16 * (job_config.training.img_size // 16)

-    enable_classifer_free_guidance = job_config.eval.enable_classifer_free_guidance
+    enable_classifier_free_guidance = job_config.eval.enable_classifier_free_guidance

     # Tokenize the prompt. Unsqueeze to add a batch dimension.
     clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0)
@@ -111,7 +111,7 @@ def generate_image(
         },
     )

-    if enable_classifer_free_guidance:
+    if enable_classifier_free_guidance:
         empty_clip_tokens = clip_tokenizer.encode("").unsqueeze(0)
         empty_t5_tokens = t5_tokenizer.encode("").unsqueeze(0)
         empty_batch = preprocess_data(
@@ -135,12 +135,12 @@ def generate_image(
         denoising_steps=job_config.eval.denoising_steps,
         clip_encodings=batch["clip_encodings"],
         t5_encodings=batch["t5_encodings"],
-        enable_classifer_free_guidance=enable_classifer_free_guidance,
+        enable_classifier_free_guidance=enable_classifier_free_guidance,
         empty_t5_encodings=(
-            empty_batch["t5_encodings"] if enable_classifer_free_guidance else None
+            empty_batch["t5_encodings"] if enable_classifier_free_guidance else None
         ),
         empty_clip_encodings=(
-            empty_batch["clip_encodings"] if enable_classifer_free_guidance else None
+            empty_batch["clip_encodings"] if enable_classifier_free_guidance else None
         ),
         classifier_free_guidance_scale=job_config.eval.classifier_free_guidance_scale,
     )
@@ -158,7 +158,7 @@ def denoise(
     denoising_steps: int,
     clip_encodings: torch.Tensor,
     t5_encodings: torch.Tensor,
-    enable_classifer_free_guidance: bool = False,
+    enable_classifier_free_guidance: bool = False,
     empty_t5_encodings: torch.Tensor | None = None,
     empty_clip_encodings: torch.Tensor | None = None,
     classifier_free_guidance_scale: float | None = None,
@@ -181,7 +181,7 @@ def denoise(
    ).to(latents)
    text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)

-    if enable_classifer_free_guidance:
+    if enable_classifier_free_guidance:
        latents = torch.cat([latents, latents], dim=0)
        t5_encodings = torch.cat([empty_t5_encodings, t5_encodings], dim=0)
        clip_encodings = torch.cat([empty_clip_encodings, clip_encodings], dim=0)
@@ -200,7 +200,7 @@ def denoise(
            y=clip_encodings,
            timesteps=t_vec,
        )
-        if enable_classifer_free_guidance:
+        if enable_classifier_free_guidance:
            pred_u, pred_c = pred.chunk(2)
            pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)

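Beyond the rename, this is the classifier-free guidance path: unconditional and conditional inputs are batched together and the two predictions are then blended. A standalone sketch of that blending step, with made-up tensor shapes and no real flux model:

```python
# Standalone sketch of the guidance blend in `denoise` above; the shapes
# are made up and `cfg_blend` is a hypothetical helper, not torchtitan API.
import torch

def cfg_blend(pred: torch.Tensor, scale: float) -> torch.Tensor:
    # `pred` stacks the unconditional and conditional halves on dim 0,
    # mirroring the torch.cat([latents, latents], dim=0) trick in the diff.
    pred_u, pred_c = pred.chunk(2)
    return pred_u + scale * (pred_c - pred_u)

batched_pred = torch.randn(2 * 4, 64, 64)    # 2 * batch, made-up latent shape
guided = cfg_blend(batched_pred, scale=5.0)  # matches classifier_free_guidance_scale
print(guided.shape)                          # torch.Size([4, 64, 64])
```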

torchtitan/experiments/flux/tests/test_generate_image.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def test_generate_image(self):
                "./outputs",
                "--training.seed",
                "0",
-                "--training.classifer_free_guidance_prob",
+                "--training.classifier_free_guidance_prob",
                "0.447",
                "--encoder.t5_encoder",
                "google/t5-v1_1-base",
@@ -59,7 +59,7 @@ def test_generate_image(self):
                # eval params
                "--eval.denoising_steps",
                str(num_steps),
-                "--eval.enable_classifer_free_guidance",
+                "--eval.enable_classifier_free_guidance",
                "--eval.classifier_free_guidance_scale",
                str(classifier_free_guidance_scale),
                "--eval.save_img_folder",

torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py

Lines changed: 32 additions & 11 deletions
@@ -4,22 +4,47 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import unittest
+
 import torch

+from datasets import load_dataset
+
 from torchtitan.config_manager import ConfigManager
-from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
+from torchtitan.experiments.flux.dataset.flux_dataset import (
+    _cc12m_wds_data_processor,
+    build_flux_dataloader,
+    DATASETS,
+    TextToImageDatasetConfig,
+)
+

+class TestFluxDataLoader(unittest.TestCase):
+    def setUp(self):
+        DATASETS["cc12m-test-iterable"] = TextToImageDatasetConfig(
+            path="torchtitan/experiments/flux/tests/assets/cc12m_test",
+            loader=lambda path: load_dataset(
+                path, split="train", data_files={"train": "*tar"}
+            ).to_iterable_dataset(num_shards=4),
+            data_processor=_cc12m_wds_data_processor,
+        )
+
+    def tearDown(self):
+        del DATASETS["cc12m-test-iterable"]

-class TestFluxDataLoader:
     def test_load_dataset(self):
         # The test checks for the correct tensor shapes during the first num_steps
         # The next num_steps ensure the loaded from checkpoint dataloader generates tokens and labels correctly
-        for world_size in [2, 4]:
+        for world_size in [2]:
             for rank in range(world_size):
-                dataset_name = "cc12m-test"
-                batch_size = 4
+                dataset_name = "cc12m-test-iterable"
+                batch_size = 1
+
+                num_steps = 15

-                num_steps = 10
+                # TODO: if num_steps * batch_size * world_size is larger than the number of samples
+                # in the dataset, then the test will fail, due to huggingface's
+                # non-resumption when checkpointing after the first epoch

                 path = "torchtitan.experiments.flux.job_config"
                 config_manager = ConfigManager()
@@ -32,16 +57,12 @@ def test_load_dataset(self):
                        dataset_name,
                        "--training.local_batch_size",
                        str(batch_size),
-                        "--training.seed",
-                        "0",
-                        "--training.classifer_free_guidance_prob",
+                        "--training.classifier_free_guidance_prob",
                        "0.447",
                        "--encoder.t5_encoder",
                        "google/t5-v1_1-xxl",
                        "--encoder.clip_encoder",
                        "openai/clip-vit-large-patch14",
-                        # "--encoder.max_t5_encoding_len",
-                        # "512",
                    ]
                )
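The new setUp registers an iterable variant of the cc12m test set via to_iterable_dataset(num_shards=4); the shard count matters because an iterable dataset is distributed across data-parallel ranks in whole shards, so it needs at least world_size of them. A small sketch of that sharding behaviour using Hugging Face's split_dataset_by_node helper (the diff does not show which helper the flux dataloader itself uses, so treat this as illustrative):

```python
# Illustrative only: shows why num_shards must be >= world_size when an
# IterableDataset is split across ranks.  Toy data, not the cc12m test set.
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

ds = Dataset.from_dict({"x": list(range(8))}).to_iterable_dataset(num_shards=4)

world_size = 2
for rank in range(world_size):
    shard = split_dataset_by_node(ds, rank=rank, world_size=world_size)
    # Each rank iterates 2 of the 4 shards, i.e. a disjoint slice of the data.
    print(rank, [ex["x"] for ex in shard])
```

The TODO in the test reflects a related constraint: checkpoints taken after the dataset re-loops into a second epoch do not resume cleanly, so num_steps * batch_size * world_size is kept below the dataset size.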

torchtitan/experiments/flux/train_configs/debug_model.toml

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@ max_norm = 2.0 # grad norm clipping
 steps = 10
 compile = false
 dataset = "cc12m-test"
-classifer_free_guidance_prob = 0.447
+classifier_free_guidance_prob = 0.447
 img_size = 256

 [encoder]
@@ -48,8 +48,8 @@ max_t5_encoding_len = 256
 autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image

 [eval]
-enable_classifer_free_guidance = true
-classifer_free_guidance_scale = 5.0
+enable_classifier_free_guidance = true
+classifier_free_guidance_scale = 5.0
 denoising_steps = 4
 save_img_folder = "img"
 eval_freq = 5

torchtitan/experiments/flux/train_configs/flux_dev_model.toml

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 30_000
 compile = false
 dataset = "cc12m-wds"
-classifer_free_guidance_prob = 0.447
+classifier_free_guidance_prob = 0.447
 img_size = 256

 [encoder]
@@ -47,8 +47,8 @@ max_t5_encoding_len = 512
 autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image

 [eval]
-enable_classifer_free_guidance = true
-classifer_free_guidance_scale = 5.0
+enable_classifier_free_guidance = true
+classifier_free_guidance_scale = 5.0
 denoising_steps = 50
 save_img_folder = "img"
 eval_freq = 1000

torchtitan/experiments/flux/train_configs/flux_schnell_model.toml

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 30_000
 compile = false
 dataset = "cc12m-wds"
-classifer_free_guidance_prob = 0.447
+classifier_free_guidance_prob = 0.447
 img_size = 256

 [encoder]
@@ -47,8 +47,8 @@ max_t5_encoding_len = 256
 autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image

 [eval]
-enable_classifer_free_guidance = true
-classifer_free_guidance_scale = 5.0
+enable_classifier_free_guidance = true
+classifier_free_guidance_scale = 5.0
 denoising_steps = 50
 save_img_folder = "img"
 eval_freq = 1000
