|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import torch |
| 8 | + |
7 | 9 | from torchtitan.config_manager import ConfigManager
|
8 | 10 | from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
|
9 |
| -from torchtitan.tools.profiling import ( |
10 |
| - maybe_enable_memory_snapshot, |
11 |
| - maybe_enable_profiling, |
12 |
| -) |
13 | 11 |
|
14 | 12 |
|
15 | 13 | class TestFluxDataLoader:
|
16 | 14 | def test_load_dataset(self):
|
17 |
| - for dataset_name in ["cc12m-test"]: |
18 |
| - self._test_flux_dataloader(dataset_name) |
| 15 | + # The first num_steps iterations check that each batch has the expected tensor shapes. |
| 16 | + # The next num_steps iterations check that a dataloader restored from a checkpoint yields the same tokens and labels as the original dataloader. |
| 17 | + for world_size in [2, 4]: |
| 18 | + for rank in range(world_size): |
| 19 | + dataset_name = "cc12m-test" |
| 20 | + batch_size = 4 |
| 21 | + |
| 22 | + num_steps = 10 |
| 23 | + |
| 24 | + path = "torchtitan.experiments.flux.job_config" |
| 25 | + config_manager = ConfigManager() |
| 26 | + config = config_manager.parse_args( |
| 27 | + [ |
| 28 | + f"--experimental.custom_args_module={path}", |
| 29 | + "--training.img_size", |
| 30 | + str(256), |
| 31 | + "--training.dataset", |
| 32 | + dataset_name, |
| 33 | + "--training.local_batch_size", |
| 34 | + str(batch_size), |
| 35 | + "--training.seed", |
| 36 | + "0", |
| 37 | + "--training.classifer_free_guidance_prob", |
| 38 | + "0.447", |
| 39 | + "--encoder.t5_encoder", |
| 40 | + "google/t5-v1_1-xxl", |
| 41 | + "--encoder.clip_encoder", |
| 42 | + "openai/clip-vit-large-patch14", |
| 43 | + # "--encoder.max_t5_encoding_len", |
| 44 | + # "512", |
| 45 | + ] |
| 46 | + ) |
19 | 47 |
|
20 |
| - def _test_flux_dataloader(self, dataset_name): |
21 |
| - batch_size = 4 |
22 |
| - world_size = 4 |
23 |
| - rank = 0 |
| 48 | + dl = build_flux_dataloader( |
| 49 | + dp_world_size=world_size, |
| 50 | + dp_rank=rank, |
| 51 | + job_config=config, |
| 52 | + tokenizer=None, |
| 53 | + infinite=True, |
| 54 | + ) |
24 | 55 |
|
25 |
| - num_steps = 10 |
| 56 | + it = iter(dl) |
26 | 57 |
|
27 |
| - path = "torchtitan.experiments.flux.job_config" |
28 |
| - config_manager = ConfigManager() |
29 |
| - config = config_manager.parse_args( |
30 |
| - [ |
31 |
| - f"--experimental.custom_args_module={path}", |
32 |
| - # Profiling options |
33 |
| - # "--profiling.enable_profiling", |
34 |
| - # "--profiling.profile_freq", |
35 |
| - # "5", |
36 |
| - # "--profiling.enable_memory_snapshot", |
37 |
| - # "--profiling.save_memory_snapshot_folder", |
38 |
| - # "memory_snapshot_flux", |
39 |
| - "--training.img_size", |
40 |
| - str(256), |
41 |
| - "--training.dataset", |
42 |
| - dataset_name, |
43 |
| - "--training.local_batch_size", |
44 |
| - str(batch_size), |
45 |
| - "--training.seed", |
46 |
| - "0", |
47 |
| - "--training.classifer_free_guidance_prob", |
48 |
| - "0.447", |
49 |
| - "--encoder.t5_encoder", |
50 |
| - "google/t5-v1_1-small", |
51 |
| - "--encoder.clip_encoder", |
52 |
| - "openai/clip-vit-large-patch14", |
53 |
| - "--encoder.max_t5_encoding_len", |
54 |
| - "512", |
55 |
| - ] |
56 |
| - ) |
| 58 | + for i in range(0, num_steps): |
| 59 | + input_data, labels = next(it) |
57 | 60 |
|
58 |
| - with maybe_enable_profiling( |
59 |
| - config, global_step=0 |
60 |
| - ) as torch_profiler, maybe_enable_memory_snapshot( |
61 |
| - config, global_step=0 |
62 |
| - ) as memory_profiler: |
63 |
| - dl = self._build_dataloader( |
64 |
| - config, |
65 |
| - world_size, |
66 |
| - rank, |
67 |
| - ) |
68 |
| - dl = iter(dl) |
| 61 | + assert len(input_data) == 2 # (clip_tokens, t5_tokens) |
| 62 | + assert labels.shape == (batch_size, 3, 256, 256) |
| 63 | + assert input_data["clip_tokens"].shape == ( |
| 64 | + batch_size, |
| 65 | + 1, |
| 66 | + 77, |
| 67 | + ) |
| 68 | + assert input_data["t5_tokens"].shape == ( |
| 69 | + batch_size, |
| 70 | + 1, |
| 71 | + 256, |
| 72 | + ) |
69 | 73 |
|
70 |
| - for i in range(0, num_steps): |
71 |
| - input_data, labels = next(dl) |
72 |
| - if torch_profiler: |
73 |
| - torch_profiler.step() |
74 |
| - if memory_profiler: |
75 |
| - memory_profiler.step() |
| 74 | + state = dl.state_dict() |
76 | 75 |
|
77 |
| - assert len(input_data) == 2 # (clip_encodings, t5_encodings) |
78 |
| - assert labels.shape == (batch_size, 3, 256, 256) |
79 |
| - # assert input_data["clip_tokens"].shape[0] == batch_size |
80 |
| - # assert input_data["t5_tokens"].shape == (batch_size, 512, 512) |
| 76 | + # Create a new dataloader, restore the checkpointed state, and check that it yields the same data as the original dataloader |
| 77 | + dl_resumed = build_flux_dataloader( |
| 78 | + dp_world_size=world_size, |
| 79 | + dp_rank=rank, |
| 80 | + job_config=config, |
| 81 | + tokenizer=None, |
| 82 | + infinite=True, |
| 83 | + ) |
| 84 | + dl_resumed.load_state_dict(state) |
| 85 | + it_resumed = iter(dl_resumed) |
81 | 86 |
|
82 |
| - if torch_profiler: |
83 |
| - torch_profiler.step() |
84 |
| - if memory_profiler: |
85 |
| - memory_profiler.step(exit_ctx=True) |
| 87 | + for i in range(num_steps): |
| 88 | + # Set torch manual seed before each dataloader iteration to ensure consistent randomness |
| 89 | + # across dataloaders for testing purposes. |
| 90 | + torch.manual_seed(i) |
| 91 | + expected_input_ids, expected_labels = next(it) |
| 92 | + torch.manual_seed(i) |
| 93 | + input_ids, labels = next(it_resumed) |
86 | 94 |
|
87 |
| - def _build_dataloader( |
88 |
| - self, |
89 |
| - job_config, |
90 |
| - world_size, |
91 |
| - rank, |
92 |
| - ): |
93 |
| - return build_flux_dataloader( |
94 |
| - dp_world_size=world_size, |
95 |
| - dp_rank=rank, |
96 |
| - job_config=job_config, |
97 |
| - tokenizer=None, |
98 |
| - infinite=True, |
99 |
| - ) |
| 95 | + assert torch.equal( |
| 96 | + input_ids["clip_tokens"], expected_input_ids["clip_tokens"] |
| 97 | + ) |
| 98 | + assert torch.equal( |
| 99 | + input_ids["t5_tokens"], expected_input_ids["t5_tokens"] |
| 100 | + ) |
| 101 | + assert torch.equal(labels, expected_labels) |
0 commit comments