Skip to content

Commit ada9cd1

Browse files
committed
add_example_diffusers_test
1 parent b17f5c5 commit ada9cd1

28 files changed

+2338
-40
lines changed
102 KB
Loading
13.7 KB
Loading

examples/diffusers/cogvideox_factory/cogvideox/models/autoencoder_kl_cogvideox_sp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from mindone.diffusers.models.layers_compat import pad
3232
from mindone.diffusers.models.modeling_outputs import AutoencoderKLOutput
3333
from mindone.diffusers.models.modeling_utils import ModelMixin
34-
from mindone.diffusers.models.normalization import GroupNorm
3534
from mindone.diffusers.models.upsampling import CogVideoXUpsample3D
3635
from mindone.diffusers.utils import logging
3736

@@ -40,7 +39,7 @@
4039
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4140

4241

43-
class GroupNorm_SP(GroupNorm):
42+
class GroupNorm_SP(mint.nn.GroupNorm):
4443
def set_frame_group_size(self, frame_group_size):
4544
self.frame_group_size = frame_group_size
4645

examples/diffusers/cogview/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cd mindone
2929
pip install -e .
3030
# NOTE: transformers requires >=4.46.0
3131
32-
cd examples/cogview
32+
cd examples/diffusers/cogview
3333
```
3434

3535

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
sys.path.append("..")
22+
from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
23+
24+
ExamplesTests._launch_args = ["python"]
25+
26+
logging.basicConfig(level=logging.DEBUG)
27+
28+
logger = logging.getLogger()
29+
stream_handler = logging.StreamHandler(sys.stdout)
30+
logger.addHandler(stream_handler)
31+
32+
33+
class ControlNet(ExamplesTests):
34+
def test_controlnet_checkpointing_checkpoints_total_limit(self):
35+
with tempfile.TemporaryDirectory() as tmpdir:
36+
test_args = f"""
37+
examples/diffusers/controlnet/train_controlnet.py
38+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
39+
--revision refs/pr/4
40+
--dataset_name=hf-internal-testing/fill10
41+
--output_dir={tmpdir}
42+
--resolution=64
43+
--train_batch_size=1
44+
--gradient_accumulation_steps=1
45+
--max_train_steps=6
46+
--checkpoints_total_limit=2
47+
--checkpointing_steps=2
48+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
49+
""".split()
50+
51+
run_command(self._launch_args + test_args)
52+
53+
self.assertEqual(
54+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
55+
{"checkpoint-4", "checkpoint-6"},
56+
)
57+
58+
def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
59+
with tempfile.TemporaryDirectory() as tmpdir:
60+
test_args = f"""
61+
examples/diffusers/controlnet/train_controlnet.py
62+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
63+
--revision refs/pr/4
64+
--dataset_name=hf-internal-testing/fill10
65+
--output_dir={tmpdir}
66+
--resolution=64
67+
--train_batch_size=1
68+
--gradient_accumulation_steps=1
69+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
70+
--max_train_steps=6
71+
--checkpointing_steps=2
72+
""".split()
73+
74+
run_command(self._launch_args + test_args)
75+
76+
self.assertEqual(
77+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
78+
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
79+
)
80+
81+
resume_run_args = f"""
82+
examples/diffusers/controlnet/train_controlnet.py
83+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
84+
--revision refs/pr/4
85+
--dataset_name=hf-internal-testing/fill10
86+
--output_dir={tmpdir}
87+
--resolution=64
88+
--train_batch_size=1
89+
--gradient_accumulation_steps=1
90+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
91+
--max_train_steps=8
92+
--checkpointing_steps=2
93+
--resume_from_checkpoint=checkpoint-6
94+
--checkpoints_total_limit=2
95+
""".split()
96+
97+
run_command(self._launch_args + resume_run_args)
98+
99+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
100+
101+
102+
class ControlNetSDXL(ExamplesTests):
103+
def test_controlnet_sdxl(self):
104+
with tempfile.TemporaryDirectory() as tmpdir:
105+
test_args = f"""
106+
examples/diffusers/controlnet/train_controlnet_sdxl.py
107+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
108+
--revision refs/pr/2
109+
--dataset_name=hf-internal-testing/fill10
110+
--output_dir={tmpdir}
111+
--resolution=64
112+
--train_batch_size=1
113+
--gradient_accumulation_steps=1
114+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
115+
--max_train_steps=4
116+
--checkpointing_steps=2
117+
""".split()
118+
119+
run_command(self._launch_args + test_args)
120+
121+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
122+
123+
124+
class ControlNetflux(ExamplesTests):
125+
def test_controlnet_flux(self):
126+
with tempfile.TemporaryDirectory() as tmpdir:
127+
test_args = f"""
128+
examples/diffusers/controlnet/train_controlnet_flux.py
129+
--pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
130+
--output_dir={tmpdir}
131+
--dataset_name=hf-internal-testing/fill10
132+
--conditioning_image_column=conditioning_image
133+
--image_column=image
134+
--caption_column=text
135+
--resolution=64
136+
--train_batch_size=1
137+
--gradient_accumulation_steps=1
138+
--max_train_steps=4
139+
--checkpointing_steps=2
140+
--num_double_layers=1
141+
--num_single_layers=1
142+
""".split()
143+
144+
run_command(self._launch_args + test_args)
145+
146+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))

examples/diffusers/controlnet/train_controlnet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,8 +879,8 @@ def __len__(self):
879879
if is_master(args):
880880
logger.info(f"Resuming from checkpoint {path}")
881881
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
882-
input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
883-
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
882+
input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
883+
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
884884
global_step = int(path.split("-")[1])
885885

886886
initial_global_step = global_step
@@ -939,8 +939,7 @@ def __len__(self):
939939
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
940940
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
941941
os.makedirs(save_path, exist_ok=True)
942-
output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
943-
ms.save_checkpoint(unet, output_model_file)
942+
unet.save_pretrained(os.path.join(save_path, "unet"))
944943
logger.info(f"Saved state to {save_path}")
945944

946945
if args.validation_prompt is not None and global_step % args.validation_steps == 0:

examples/diffusers/controlnet/train_controlnet_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from mindspore.dataset import GeneratorDataset, transforms, vision
3636

3737
from mindone.diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
38-
from mindone.diffusers.models.controlnet_flux import FluxControlNetModel
38+
from mindone.diffusers.models.controlnets.controlnet_flux import FluxControlNetModel
3939
from mindone.diffusers.models.layers_compat import set_amp_strategy
4040
from mindone.diffusers.optimization import get_scheduler
4141
from mindone.diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline

examples/diffusers/controlnet/train_controlnet_sdxl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,8 +990,8 @@ def __len__(self):
990990
if is_master(args):
991991
logger.info(f"Resuming from checkpoint {path}")
992992
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
993-
input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
994-
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
993+
input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
994+
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
995995
global_step = int(path.split("-")[1])
996996

997997
initial_global_step = global_step
@@ -1050,8 +1050,7 @@ def __len__(self):
10501050
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
10511051
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
10521052
os.makedirs(save_path, exist_ok=True)
1053-
output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
1054-
ms.save_checkpoint(unet, output_model_file)
1053+
unet.save_pretrained(os.path.join(save_path, "unet"))
10551054
logger.info(f"Saved state to {save_path}")
10561055

10571056
if args.validation_prompt is not None and global_step % args.validation_steps == 0:

0 commit comments

Comments
 (0)