Skip to content

Commit 753bc9d

Browse files
authored
add_example_diffusers_test (#1372)
1 parent b896582 commit 753bc9d

33 files changed

+2353
-45
lines changed

docs/diffusers/imgs/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
### Image Credits
2+
3+
The images in this folder are taken from the [Hugging Face Diffusers repository](https://github.com/huggingface/diffusers/tree/main/docs/source/en/imgs) and are subject to the Apache 2.0 license of the Diffusers project.
102 KB
Loading
13.7 KB
Loading

examples/diffusers/cogvideox_factory/README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
> 我们的开发和验证基于Ascend Atlas 800T A2硬件,相关环境如下:
66
> | mindspore | ascend driver | firmware | cann toolkit/kernel |
77
> |:----------:|:--------------:|:-----------:|:------------------:|
8-
> | 2.5 | 24.1.RC2 | 7.5.0.1.129 | 8.0.0.beta1 |
8+
> | 2.6.0 | 24.1.RC2 | 7.3.0.1.231 | 8.1.RC1 |
9+
> | 2.7.0 | 24.1.RC2 | 7.3.0.1.231 | 8.2.RC1 |
910
1011
<table align="center">
1112
<tr>
@@ -409,3 +410,7 @@ NODE_RANK="0"
409410
当前训练脚本并不完全支持原仓代码的所有训练参数,详情参见[`args.py`](./scripts/args.py)中的`check_args()`。
410411

411412
其中一个主要的限制来自于CogVideoX模型中的[3D Causal VAE不支持静态图](https://gist.github.com/townwish4git/b6cd0d213b396eaedfb69b3abcd742da),这导致我们**不支持静态图模式下VAE参与训练**,因此在静态图模式下必须提前进行数据预处理以获取VAE-latents/text-encoder-embeddings cache。
413+
414+
415+
### 注意
416+
训练结束后若出现 `Exception ignored: OSError [Errno 9] Bad file descriptor`,仅为 Python 关闭时的提示,不影响训练结果;使用 Python 3.11,该提示不再出现。

examples/diffusers/cogvideox_factory/cogvideox/models/autoencoder_kl_cogvideox_sp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from mindone.diffusers.models.layers_compat import pad
3232
from mindone.diffusers.models.modeling_outputs import AutoencoderKLOutput
3333
from mindone.diffusers.models.modeling_utils import ModelMixin
34-
from mindone.diffusers.models.normalization import GroupNorm
3534
from mindone.diffusers.models.upsampling import CogVideoXUpsample3D
3635
from mindone.diffusers.utils import logging
3736

@@ -40,7 +39,7 @@
4039
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4140

4241

43-
class GroupNorm_SP(GroupNorm):
42+
class GroupNorm_SP(mint.nn.GroupNorm):
4443
def set_frame_group_size(self, frame_group_size):
4544
self.frame_group_size = frame_group_size
4645

examples/diffusers/cogvideox_factory/scripts/train_text_to_video_lora.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ AMP_LEVEL=O2
3232
DATA_ROOT="preprocessed-dataset"
3333
CAPTION_COLUMN="prompts.txt"
3434
VIDEO_COLUMN="videos.txt"
35-
MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5b"
35+
MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5B"
3636
H=768
3737
W=1360
3838
F=77

examples/diffusers/cogvideox_factory/scripts/train_text_to_video_sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ DEEPSPEED_ZERO_STAGE=3
4040
DATA_ROOT="preprocessed-dataset"
4141
CAPTION_COLUMN="prompts.txt"
4242
VIDEO_COLUMN="videos.txt"
43-
MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5b"
43+
MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5B"
4444
H=768
4545
W=1360
4646
F=77

examples/diffusers/cogview/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cd mindone
2929
pip install -e .
3030
# NOTE: transformers requires >=4.46.0
3131
32-
cd examples/cogview
32+
cd examples/diffusers/cogview
3333
```
3434

3535

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
sys.path.append("..")
22+
from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
23+
24+
ExamplesTests._launch_args = ["python"]
25+
26+
logging.basicConfig(level=logging.DEBUG)
27+
28+
logger = logging.getLogger()
29+
stream_handler = logging.StreamHandler(sys.stdout)
30+
logger.addHandler(stream_handler)
31+
32+
33+
class ControlNet(ExamplesTests):
    """Checkpoint-retention tests for the ``train_controlnet.py`` example script."""

    @staticmethod
    def _checkpoint_entries(directory):
        """Return the set of entries in *directory* whose name contains "checkpoint"."""
        return {entry for entry in os.listdir(directory) if "checkpoint" in entry}

    def test_controlnet_checkpointing_checkpoints_total_limit(self):
        """Run 6 steps, saving every 2 with a limit of 2; only the last two checkpoints remain."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cli_args = f"""
                examples/diffusers/controlnet/train_controlnet.py
                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
                --revision refs/pr/4
                --dataset_name=hf-internal-testing/fill10
                --output_dir={tmpdir}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
                """.split()

            run_command([*self._launch_args, *cli_args])

            remaining = self._checkpoint_entries(tmpdir)
            self.assertEqual(remaining, {"checkpoint-4", "checkpoint-6"})

    def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        """Train without a limit, then resume with a limit of 2 and check the pruning."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # First run: no --checkpoints_total_limit, so every checkpoint is kept.
            initial_args = f"""
                examples/diffusers/controlnet/train_controlnet.py
                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
                --revision refs/pr/4
                --dataset_name=hf-internal-testing/fill10
                --output_dir={tmpdir}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
                --max_train_steps=6
                --checkpointing_steps=2
                """.split()

            run_command([*self._launch_args, *initial_args])

            after_first_run = self._checkpoint_entries(tmpdir)
            self.assertEqual(after_first_run, {"checkpoint-2", "checkpoint-4", "checkpoint-6"})

            # Second run: resume from checkpoint-6 up to step 8, now with a limit of 2.
            resumed_args = f"""
                examples/diffusers/controlnet/train_controlnet.py
                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
                --revision refs/pr/4
                --dataset_name=hf-internal-testing/fill10
                --output_dir={tmpdir}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-6
                --checkpoints_total_limit=2
                """.split()

            run_command([*self._launch_args, *resumed_args])

            after_resume = self._checkpoint_entries(tmpdir)
            self.assertEqual(after_resume, {"checkpoint-6", "checkpoint-8"})
100+
101+
102+
class ControlNetSDXL(ExamplesTests):
    """Smoke test for the ``train_controlnet_sdxl.py`` example script."""

    def test_controlnet_sdxl(self):
        """Run 4 training steps and check that the safetensors weights file is written."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cli_args = f"""
                examples/diffusers/controlnet/train_controlnet_sdxl.py
                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --revision refs/pr/2
                --dataset_name=hf-internal-testing/fill10
                --output_dir={tmpdir}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command([*self._launch_args, *cli_args])

            # The test expects the weights file directly under the output directory.
            weights_file = os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")
            self.assertTrue(os.path.isfile(weights_file))
122+
123+
124+
class ControlNetflux(ExamplesTests):
    """Smoke test for the ``train_controlnet_flux.py`` example script."""

    def test_controlnet_flux(self):
        """Run 4 training steps and check that the safetensors weights file is written."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cli_args = f"""
                examples/diffusers/controlnet/train_controlnet_flux.py
                --pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
                --output_dir={tmpdir}
                --dataset_name=hf-internal-testing/fill10
                --conditioning_image_column=conditioning_image
                --image_column=image
                --caption_column=text
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                --num_double_layers=1
                --num_single_layers=1
                """.split()

            run_command([*self._launch_args, *cli_args])

            # The test expects the weights file directly under the output directory.
            weights_file = os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")
            self.assertTrue(os.path.isfile(weights_file))

examples/diffusers/controlnet/train_controlnet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,8 +879,8 @@ def __len__(self):
879879
if is_master(args):
880880
logger.info(f"Resuming from checkpoint {path}")
881881
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
882-
input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
883-
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
882+
input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
883+
ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
884884
global_step = int(path.split("-")[1])
885885

886886
initial_global_step = global_step
@@ -939,8 +939,7 @@ def __len__(self):
939939
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
940940
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
941941
os.makedirs(save_path, exist_ok=True)
942-
output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
943-
ms.save_checkpoint(unet, output_model_file)
942+
unet.save_pretrained(os.path.join(save_path, "unet"))
944943
logger.info(f"Saved state to {save_path}")
945944

946945
if args.validation_prompt is not None and global_step % args.validation_steps == 0:

0 commit comments

Comments
 (0)