
Commit 5d4cc9a

unit test for flux_dataset dataloader checkpointing (#1346)
Adds a unit test for resuming the flux dataset dataloader from a checkpoint. The test creates a new dataloader from a checkpoint, then ensures that the labels and tokens generated next are the same in both dataloaders, starting from the checkpoint.
1 parent 05d3f7c commit 5d4cc9a
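
For context, the checkpoint/resume pattern this test exercises is: iterate a dataloader, capture its state_dict(), build a fresh dataloader, restore the state with load_state_dict(), and assert that both yield identical batches from that point on. Below is a minimal, self-contained sketch of that pattern; ToyStatefulLoader is a made-up stand-in used only for illustration, not the torchtitan flux dataloader (the real test uses build_flux_dataloader, and additionally re-seeds torch before each draw because batch construction there involves randomness).

# Minimal sketch of the checkpoint-and-resume check performed by the test.
# ToyStatefulLoader is a hypothetical stand-in, not the torchtitan flux dataloader.
import torch


class ToyStatefulLoader:
    """Infinite loader that yields deterministic batches and tracks its position."""

    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size
        self._step = 0

    def __iter__(self):
        while True:
            batch = torch.full((self.batch_size,), self._step)
            self._step += 1
            yield batch

    def state_dict(self):
        # Everything needed to resume from the current position.
        return {"step": self._step}

    def load_state_dict(self, state):
        self._step = state["step"]


dl = ToyStatefulLoader()
it = iter(dl)
for _ in range(10):
    next(it)  # advance the original loader by num_steps batches

state = dl.state_dict()  # "checkpoint" taken after 10 steps

dl_resumed = ToyStatefulLoader()
dl_resumed.load_state_dict(state)  # restore from the checkpoint
it_resumed = iter(dl_resumed)

# From here on, both loaders should yield identical batches, step for step.
for _ in range(10):
    assert torch.equal(next(it), next(it_resumed))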

File tree

1 file changed: +81 −79 lines changed

torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py

Lines changed: 81 additions & 79 deletions
@@ -4,96 +4,98 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
+
 from torchtitan.config_manager import ConfigManager
 from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
-from torchtitan.tools.profiling import (
-    maybe_enable_memory_snapshot,
-    maybe_enable_profiling,
-)
 
 
 class TestFluxDataLoader:
     def test_load_dataset(self):
-        for dataset_name in ["cc12m-test"]:
-            self._test_flux_dataloader(dataset_name)
+        # The test checks for the correct tensor shapes during the first num_steps.
+        # The next num_steps ensure that the dataloader loaded from the checkpoint generates tokens and labels correctly.
+        for world_size in [2, 4]:
+            for rank in range(world_size):
+                dataset_name = "cc12m-test"
+                batch_size = 4
+
+                num_steps = 10
+
+                path = "torchtitan.experiments.flux.job_config"
+                config_manager = ConfigManager()
+                config = config_manager.parse_args(
+                    [
+                        f"--experimental.custom_args_module={path}",
+                        "--training.img_size",
+                        str(256),
+                        "--training.dataset",
+                        dataset_name,
+                        "--training.local_batch_size",
+                        str(batch_size),
+                        "--training.seed",
+                        "0",
+                        "--training.classifer_free_guidance_prob",
+                        "0.447",
+                        "--encoder.t5_encoder",
+                        "google/t5-v1_1-xxl",
+                        "--encoder.clip_encoder",
+                        "openai/clip-vit-large-patch14",
+                        # "--encoder.max_t5_encoding_len",
+                        # "512",
+                    ]
+                )
 
-    def _test_flux_dataloader(self, dataset_name):
-        batch_size = 4
-        world_size = 4
-        rank = 0
+                dl = build_flux_dataloader(
+                    dp_world_size=world_size,
+                    dp_rank=rank,
+                    job_config=config,
+                    tokenizer=None,
+                    infinite=True,
+                )
 
-        num_steps = 10
+                it = iter(dl)
 
-        path = "torchtitan.experiments.flux.job_config"
-        config_manager = ConfigManager()
-        config = config_manager.parse_args(
-            [
-                f"--experimental.custom_args_module={path}",
-                # Profiling options
-                # "--profiling.enable_profiling",
-                # "--profiling.profile_freq",
-                # "5",
-                # "--profiling.enable_memory_snapshot",
-                # "--profiling.save_memory_snapshot_folder",
-                # "memory_snapshot_flux",
-                "--training.img_size",
-                str(256),
-                "--training.dataset",
-                dataset_name,
-                "--training.local_batch_size",
-                str(batch_size),
-                "--training.seed",
-                "0",
-                "--training.classifer_free_guidance_prob",
-                "0.447",
-                "--encoder.t5_encoder",
-                "google/t5-v1_1-small",
-                "--encoder.clip_encoder",
-                "openai/clip-vit-large-patch14",
-                "--encoder.max_t5_encoding_len",
-                "512",
-            ]
-        )
+                for i in range(0, num_steps):
+                    input_data, labels = next(it)
 
-        with maybe_enable_profiling(
-            config, global_step=0
-        ) as torch_profiler, maybe_enable_memory_snapshot(
-            config, global_step=0
-        ) as memory_profiler:
-            dl = self._build_dataloader(
-                config,
-                world_size,
-                rank,
-            )
-            dl = iter(dl)
+                    assert len(input_data) == 2  # (clip_encodings, t5_encodings)
+                    assert labels.shape == (batch_size, 3, 256, 256)
+                    assert input_data["clip_tokens"].shape == (
+                        batch_size,
+                        1,
+                        77,
+                    )
+                    assert input_data["t5_tokens"].shape == (
+                        batch_size,
+                        1,
+                        256,
+                    )
 
-            for i in range(0, num_steps):
-                input_data, labels = next(dl)
-                if torch_profiler:
-                    torch_profiler.step()
-                if memory_profiler:
-                    memory_profiler.step()
+                state = dl.state_dict()
 
-                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
-                assert labels.shape == (batch_size, 3, 256, 256)
-                # assert input_data["clip_tokens"].shape[0] == batch_size
-                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)
+                # Create a new dataloader, restore the checkpoint, and check that the next data yielded is the same as above
+                dl_resumed = build_flux_dataloader(
+                    dp_world_size=world_size,
+                    dp_rank=rank,
+                    job_config=config,
+                    tokenizer=None,
+                    infinite=True,
+                )
+                dl_resumed.load_state_dict(state)
+                it_resumed = iter(dl_resumed)
 
-            if torch_profiler:
-                torch_profiler.step()
-            if memory_profiler:
-                memory_profiler.step(exit_ctx=True)
+                for i in range(num_steps):
+                    # Set the torch manual seed before each dataloader iteration to ensure consistent randomness
+                    # across dataloaders for testing purposes.
+                    torch.manual_seed(i)
+                    expected_input_ids, expected_labels = next(it)
+                    torch.manual_seed(i)
+                    input_ids, labels = next(it_resumed)
 
-    def _build_dataloader(
-        self,
-        job_config,
-        world_size,
-        rank,
-    ):
-        return build_flux_dataloader(
-            dp_world_size=world_size,
-            dp_rank=rank,
-            job_config=job_config,
-            tokenizer=None,
-            infinite=True,
-        )
+                    assert torch.equal(
+                        input_ids["clip_tokens"], expected_input_ids["clip_tokens"]
+                    )
+                    assert torch.equal(
+                        input_ids["t5_tokens"], expected_input_ids["t5_tokens"]
+                    )
+                    assert torch.equal(labels, expected_labels)
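
One detail worth calling out from the new test loop: because drawing a batch from the flux dataloader involves torch-level randomness (presumably including the classifier-free guidance dropout configured above), the comparison only holds if both dataloaders consume the same random state, so the test resets the manual seed immediately before each paired next() call. A stripped-down version of that pairing, assuming two iterators it_a and it_b that should produce the same tensor stream:

import torch


def assert_same_stream(it_a, it_b, num_steps):
    # Re-seed before each draw so any torch randomness inside the loaders
    # is identical for both, then compare the yielded batches.
    for i in range(num_steps):
        torch.manual_seed(i)
        batch_a = next(it_a)
        torch.manual_seed(i)
        batch_b = next(it_b)
        assert torch.equal(batch_a, batch_b)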
