Skip to content

Commit 2eb15d1

Browse files
committed
add_example_diffusers_test
1 parent b17f5c5 commit 2eb15d1

25 files changed

+2602
-37
lines changed
102 KB
Loading
13.7 KB
Loading

examples/diffusers/cogview/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cd mindone
2929
pip install -e .
3030
# NOTE: transformers requires >=4.46.0
3131
32-
cd examples/cogview
32+
cd examples/diffusers/cogview
3333
```
3434

3535

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
sys.path.append("..")
22+
from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
23+
24+
ExamplesTests._launch_args = ["python"]
25+
26+
logging.basicConfig(level=logging.DEBUG)
27+
28+
logger = logging.getLogger()
29+
stream_handler = logging.StreamHandler(sys.stdout)
30+
logger.addHandler(stream_handler)
31+
32+
33+
class ControlNet(ExamplesTests):
34+
def test_controlnet_checkpointing_checkpoints_total_limit(self):
35+
with tempfile.TemporaryDirectory() as tmpdir:
36+
test_args = f"""
37+
examples/diffusers/controlnet/train_controlnet.py
38+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
39+
--revision refs/pr/4
40+
--dataset_name=hf-internal-testing/fill10
41+
--output_dir={tmpdir}
42+
--resolution=64
43+
--train_batch_size=1
44+
--gradient_accumulation_steps=1
45+
--max_train_steps=6
46+
--checkpoints_total_limit=2
47+
--checkpointing_steps=2
48+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
49+
""".split()
50+
51+
run_command(self._launch_args + test_args)
52+
53+
self.assertEqual(
54+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
55+
{"checkpoint-4", "checkpoint-6"},
56+
)
57+
58+
def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
59+
with tempfile.TemporaryDirectory() as tmpdir:
60+
test_args = f"""
61+
examples/diffusers/controlnet/train_controlnet.py
62+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
63+
--revision refs/pr/4
64+
--dataset_name=hf-internal-testing/fill10
65+
--output_dir={tmpdir}
66+
--resolution=64
67+
--train_batch_size=1
68+
--gradient_accumulation_steps=1
69+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
70+
--max_train_steps=6
71+
--checkpointing_steps=2
72+
""".split()
73+
74+
run_command(self._launch_args + test_args)
75+
76+
self.assertEqual(
77+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
78+
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
79+
)
80+
81+
resume_run_args = f"""
82+
examples/diffusers/controlnet/train_controlnet.py
83+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
84+
--revision refs/pr/4
85+
--dataset_name=hf-internal-testing/fill10
86+
--output_dir={tmpdir}
87+
--resolution=64
88+
--train_batch_size=1
89+
--gradient_accumulation_steps=1
90+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
91+
--max_train_steps=8
92+
--checkpointing_steps=2
93+
--resume_from_checkpoint=checkpoint-6
94+
--checkpoints_total_limit=2
95+
""".split()
96+
97+
run_command(self._launch_args + resume_run_args)
98+
99+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
100+
101+
102+
class ControlNetSDXL(ExamplesTests):
103+
def test_controlnet_sdxl(self):
104+
with tempfile.TemporaryDirectory() as tmpdir:
105+
test_args = f"""
106+
examples/diffusers/controlnet/train_controlnet_sdxl.py
107+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
108+
--revision refs/pr/2
109+
--dataset_name=hf-internal-testing/fill10
110+
--output_dir={tmpdir}
111+
--resolution=64
112+
--train_batch_size=1
113+
--gradient_accumulation_steps=1
114+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
115+
--max_train_steps=4
116+
--checkpointing_steps=2
117+
""".split()
118+
119+
run_command(self._launch_args + test_args)
120+
121+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
122+
123+
124+
class ControlNetSD3(ExamplesTests):
125+
def test_controlnet_sd3(self):
126+
with tempfile.TemporaryDirectory() as tmpdir:
127+
test_args = f"""
128+
examples/diffusers/controlnet/train_controlnet_sd3.py
129+
--pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe
130+
--dataset_name=hf-internal-testing/fill10
131+
--output_dir={tmpdir}
132+
--resolution=64
133+
--train_batch_size=1
134+
--gradient_accumulation_steps=1
135+
--controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3
136+
--max_train_steps=4
137+
--checkpointing_steps=2
138+
""".split()
139+
140+
run_command(self._launch_args + test_args)
141+
142+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
143+
144+
145+
class ControlNetSD35(ExamplesTests):
146+
def test_controlnet_sd3(self):
147+
with tempfile.TemporaryDirectory() as tmpdir:
148+
test_args = f"""
149+
examples/diffusers/controlnet/train_controlnet_sd3.py
150+
--pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe
151+
--dataset_name=hf-internal-testing/fill10
152+
--output_dir={tmpdir}
153+
--resolution=64
154+
--train_batch_size=1
155+
--gradient_accumulation_steps=1
156+
--controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35
157+
--max_train_steps=4
158+
--checkpointing_steps=2
159+
""".split()
160+
161+
run_command(self._launch_args + test_args)
162+
163+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
164+
165+
166+
class ControlNetflux(ExamplesTests):
167+
def test_controlnet_flux(self):
168+
with tempfile.TemporaryDirectory() as tmpdir:
169+
test_args = f"""
170+
examples/diffusers/controlnet/train_controlnet_flux.py
171+
--pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
172+
--output_dir={tmpdir}
173+
--dataset_name=hf-internal-testing/fill10
174+
--conditioning_image_column=conditioning_image
175+
--image_column=image
176+
--caption_column=text
177+
--resolution=64
178+
--train_batch_size=1
179+
--gradient_accumulation_steps=1
180+
--max_train_steps=4
181+
--checkpointing_steps=2
182+
--num_double_layers=1
183+
--num_single_layers=1
184+
""".split()
185+
186+
run_command(self._launch_args + test_args)
187+
188+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))

examples/diffusers/controlnet/train_controlnet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,8 +879,8 @@ def __len__(self):
879879
if is_master(args):
880880
logger.info(f"Resuming from checkpoint {path}")
881881
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
882-
input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
883-
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
882+
input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
883+
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
884884
global_step = int(path.split("-")[1])
885885

886886
initial_global_step = global_step
@@ -939,8 +939,7 @@ def __len__(self):
939939
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
940940
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
941941
os.makedirs(save_path, exist_ok=True)
942-
output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
943-
ms.save_checkpoint(unet, output_model_file)
942+
unet.save_pretrained(os.path.join(save_path, "unet"))
944943
logger.info(f"Saved state to {save_path}")
945944

946945
if args.validation_prompt is not None and global_step % args.validation_steps == 0:

examples/diffusers/controlnet/train_controlnet_sdxl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,8 +990,8 @@ def __len__(self):
990990
if is_master(args):
991991
logger.info(f"Resuming from checkpoint {path}")
992992
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
993-
input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
994-
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
993+
input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
994+
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
995995
global_step = int(path.split("-")[1])
996996

997997
initial_global_step = global_step
@@ -1050,8 +1050,7 @@ def __len__(self):
10501050
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
10511051
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
10521052
os.makedirs(save_path, exist_ok=True)
1053-
output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
1054-
ms.save_checkpoint(unet, output_model_file)
1053+
unet.save_pretrained(os.path.join(save_path, "unet"))
10551054
logger.info(f"Saved state to {save_path}")
10561055

10571056
if args.validation_prompt is not None and global_step % args.validation_steps == 0:

0 commit comments

Comments
 (0)