Commit 4420f27

supports loading multiple model files & update doc (#115)
1 parent 63d9b89 commit 4420f27

File tree: 4 files changed (+300, -24 lines)

4 files changed

+300
-24
lines changed

diffsynth_engine/pipelines/base.py

Lines changed: 14 additions & 8 deletions
```diff
@@ -91,15 +91,21 @@ def unload_loras(self):
 
     @staticmethod
     def load_model_checkpoint(
-        checkpoint_path: str, device: str = "cpu", dtype: torch.dtype = torch.float16
+        checkpoint_path: str | List[str], device: str = "cpu", dtype: torch.dtype = torch.float16
     ) -> Dict[str, torch.Tensor]:
-        if not os.path.isfile(checkpoint_path):
-            FileNotFoundError(f"{checkpoint_path} is not a file")
-        if checkpoint_path.endswith(".safetensors"):
-            return load_file(checkpoint_path, device=device)
-        if checkpoint_path.endswith(".gguf"):
-            return load_gguf_checkpoint(checkpoint_path, device=device, dtype=dtype)
-        raise ValueError(f"{checkpoint_path} is not a .safetensors or .gguf file")
+        if isinstance(checkpoint_path, str):
+            checkpoint_path = [checkpoint_path]
+        state_dict = {}
+        for path in checkpoint_path:
+            if not os.path.isfile(path):
+                raise FileNotFoundError(f"{path} is not a file")
+            elif path.endswith(".safetensors"):
+                state_dict.update(**load_file(path, device=device))
+            elif path.endswith(".gguf"):
+                state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
+            else:
+                raise ValueError(f"{path} is not a .safetensors or .gguf file")
+        return state_dict
 
     @staticmethod
     def validate_image_size(
```
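A minimal usage sketch of the new signature. The `BasePipeline` class name and the shard file names below are assumptions for illustration, not taken from this diff:

```python
import torch

# Assumption: the static method lives on the pipeline base class in pipelines/base.py.
from diffsynth_engine.pipelines.base import BasePipeline

# Hypothetical shard names; any mix of .safetensors / .gguf files is accepted.
shards = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

# All shards are merged into a single state dict.
state_dict = BasePipeline.load_model_checkpoint(shards, device="cpu", dtype=torch.float16)
print(f"loaded {len(state_dict)} tensors")
```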

diffsynth_engine/utils/download.py

Lines changed: 20 additions & 15 deletions
```diff
@@ -2,10 +2,11 @@
 import shutil
 import tqdm
 import tempfile
-from typing import Optional
+from typing import List, Optional
 from pathlib import Path
 from urllib.parse import urlparse
 import requests
+import glob
 
 from modelscope import snapshot_download
 from modelscope.hub.api import HubApi
@@ -23,11 +24,11 @@
 def fetch_model(
     model_uri: str,
     revision: Optional[str] = None,
-    path: Optional[str] = None,
+    path: Optional[str | List[str]] = None,
     access_token: Optional[str] = None,
     source: str = "modelscope",
-    fetch_safetensors: bool = True,
-) -> str:
+    fetch_safetensors: bool = True,  # TODO: supports other formats like GGUF
+) -> str | List[str]:
     if source == "modelscope":
         return fetch_modelscope_model(model_uri, revision, path, access_token, fetch_safetensors)
     if source == "civitai":
@@ -38,7 +39,7 @@ def fetch_model(
 def fetch_modelscope_model(
     model_id: str,
     revision: Optional[str] = None,
-    path: Optional[str] = None,
+    path: Optional[str | List[str]] = None,
     access_token: Optional[str] = None,
     fetch_safetensors: bool = True,
 ) -> str:
@@ -52,12 +53,15 @@ def fetch_modelscope_model(
     directory = os.path.join(DIFFSYNTH_CACHE, "modelscope", model_id, revision if revision else "__version")
     dirpath = snapshot_download(model_id, revision=revision, local_dir=directory, allow_patterns=path)
 
-    if path is not None:
-        path = os.path.join(dirpath, path)
+    if isinstance(path, str):
+        path = glob.glob(os.path.join(dirpath, path))
+        path = path[0] if len(path) == 1 else path
+    elif isinstance(path, list):
+        path = [os.path.join(dirpath, p) for p in path]
     else:
         path = dirpath
 
-    if os.path.isdir(path) and fetch_safetensors:
+    if isinstance(path, str) and os.path.isdir(path) and fetch_safetensors:
         return _fetch_safetensors(path)
     return path
 
@@ -122,16 +126,17 @@ def ensure_directory_exists(filename: str):
     Path(filename).parent.mkdir(parents=True, exist_ok=True)
 
 
-def _fetch_safetensors(dirpath: str) -> str:
+def _fetch_safetensors(dirpath: str) -> str | List[str]:
     all_safetensors = []
     for filename in os.listdir(dirpath):
         if filename.endswith(".safetensors"):
             all_safetensors.append(os.path.join(dirpath, filename))
-    if len(all_safetensors) == 1:
-        logger.info(f"Fetch safetensors file {all_safetensors[0]}")
-        return all_safetensors[0]
-    elif len(all_safetensors) == 0:
+    if len(all_safetensors) == 0:
         logger.error(f"No safetensors file found in {dirpath}")
+        return dirpath
+    elif len(all_safetensors) == 1:
+        all_safetensors = all_safetensors[0]
+        logger.info(f"Fetch safetensors file {all_safetensors}")
     else:
-        logger.error(f"Multiple safetensors files found in {dirpath}, please specify the file name")
-        return dirpath
+        logger.info(f"Fetch safetensors files {all_safetensors}")
+    return all_safetensors
```
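A brief sketch of the changed behavior, with an illustrative repo ID: when a downloaded snapshot contains several `.safetensors` files and no explicit `path` is given, `_fetch_safetensors` now returns every matching file instead of logging an error and returning the directory.

```python
from diffsynth_engine import fetch_model

# Illustrative repo ID: a sharded repository with several .safetensors files.
paths = fetch_model("Wan-AI/Wan2.1-T2V-14B")

# `paths` is a single string when exactly one file matches,
# otherwise a list of shard paths (new behavior in this commit).
print(paths if isinstance(paths, str) else f"{len(paths)} shards downloaded")
```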

docs/tutorial.md

Lines changed: 241 additions & 1 deletion
The previous placeholder content (`# ToDo`) is replaced with the following user guide:

# DiffSynth-Engine User Guide

## Installation

Before using DiffSynth-Engine, please ensure your device meets the following requirements:

* NVIDIA GPU with CUDA Compute Capability 8.6+ (e.g., RTX 50 Series, RTX 40 Series, RTX 30 Series, see [NVIDIA documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities) for details) or Apple Silicon M-series chips.

Python environment requirement: Python 3.10+.

Use `pip3` to install DiffSynth-Engine from PyPI:

```shell
pip3 install diffsynth-engine
```

DiffSynth-Engine also supports installation from source, which provides access to the latest features but might come with stability issues. We recommend installing the stable version via `pip3`.

```shell
git clone https://github.com/modelscope/diffsynth-engine.git && cd diffsynth-engine
pip3 install -e .
```

## Model Download

DiffSynth-Engine supports loading models from the [ModelScope Model Hub](https://www.modelscope.cn/aigc/models) by model ID. For example, on the [MajicFlus model page](https://www.modelscope.cn/models/MAILAND/majicflus_v1/summary?version=v1.0), we can find the model ID and the corresponding model filename in the image below.

![Image](https://github.com/user-attachments/assets/a6f71768-487d-4376-8974-fe6563f2896c)

Next, download the MajicFlus model with the following code.

```python
from diffsynth_engine import fetch_model

model_path = fetch_model("MAILAND/majicflus_v1", path="majicflus_v134.safetensors")
```

![Image](https://github.com/user-attachments/assets/596c3383-23b3-4372-a7ce-3c4e1c1ba81a)

For sharded models, specify multiple files using the `path` parameter.

```python
from diffsynth_engine import fetch_model

model_path = fetch_model("Wan-AI/Wan2.1-T2V-14B", path=[
    "diffusion_pytorch_model-00001-of-00006.safetensors",
    "diffusion_pytorch_model-00002-of-00006.safetensors",
    "diffusion_pytorch_model-00003-of-00006.safetensors",
    "diffusion_pytorch_model-00004-of-00006.safetensors",
    "diffusion_pytorch_model-00005-of-00006.safetensors",
    "diffusion_pytorch_model-00006-of-00006.safetensors",
])
```

The `path` parameter also supports wildcards for matching multiple files.

```python
from diffsynth_engine import fetch_model

model_path = fetch_model("Wan-AI/Wan2.1-T2V-14B", path="diffusion_pytorch_model*.safetensors")
```

The `model_path` value returned by `fetch_model` is the path to the downloaded file(s).
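A short sketch of how to tell the two return forms apart, assuming the multi-file download behavior described above:

```python
from diffsynth_engine import fetch_model

model_path = fetch_model("Wan-AI/Wan2.1-T2V-14B", path="diffusion_pytorch_model*.safetensors")

# One matched file -> str, several matched shards -> list of file paths.
if isinstance(model_path, list):
    print(f"Downloaded {len(model_path)} shards")
else:
    print(f"Downloaded a single file: {model_path}")
```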

## Model Types

Diffusion models come in a wide variety of architectures. Each model is loaded and run for inference by a corresponding pipeline. The model types we currently support include:

| Model Architecture | Example | Pipeline |
| :----------------- | :----------------------------------------------------------- | :-------------------- |
| SD1.5 | [DreamShaper](https://www.modelscope.cn/models/MusePublic/DreamShaper_SD_1_5) | `SDImagePipeline` |
| SDXL | [RealVisXL](https://www.modelscope.cn/models/MusePublic/42_ckpt_SD_XL) | `SDXLImagePipeline` |
| FLUX | [MajicFlus](https://www.modelscope.cn/models/MAILAND/majicflus_v1/summary?version=v1.0) | `FluxImagePipeline` |
| Wan2.1 | [Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | `WanVideoPipeline` |
| SD1.5 LoRA | [Detail Tweaker](https://www.modelscope.cn/models/MusePublic/Detail_Tweaker_LoRA_xijietiaozheng_LoRA_SD_1_5) | `SDImagePipeline` |
| SDXL LoRA | [Aesthetic Anime](https://www.modelscope.cn/models/MusePublic/100_lora_SD_XL) | `SDXLImagePipeline` |
| FLUX LoRA | [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) | `FluxImagePipeline` |
| Wan2.1 LoRA | [Highres-fix](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1) | `WanVideoPipeline` |

Among these, SD1.5, SDXL, and FLUX are base models for image generation, while Wan2.1 is a base model for video generation. Base models can generate content independently. SD1.5 LoRA, SDXL LoRA, FLUX LoRA, and Wan2.1 LoRA are [LoRA](https://arxiv.org/abs/2106.09685) models. LoRA models are trained as "additional branches" on top of base models to enhance specific capabilities. They must be combined with a base model to be used for generation.

We will continuously update DiffSynth-Engine to support more models.

## Model Inference

After the model is downloaded, load it with the corresponding pipeline and perform inference.

### Image Generation

The following code calls `FluxImagePipeline` to load the [MajicFlus](https://www.modelscope.cn/models/MAILAND/majicflus_v1/summary?version=v1.0) model and generate an image. To load other types of models, replace `FluxImagePipeline` in the code with the corresponding pipeline.

```python
from diffsynth_engine import fetch_model, FluxImagePipeline

model_path = fetch_model("MAILAND/majicflus_v1", path="majicflus_v134.safetensors")
pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
image = pipe(prompt="a cat")
image.save("image.png")
```

Please note that if some necessary modules, like text encoders, are missing from a model repository, the pipeline will automatically download the required files.

#### Detailed Parameters

In the image generation pipeline `pipe`, we can use the following parameters for fine-grained control (a combined example follows the list):

* `prompt`: The prompt, used to describe the content of the generated image, e.g., "a cat".
* `negative_prompt`: The negative prompt, used to describe content you do not want in the image, e.g., "ugly".
* `cfg_scale`: The guidance scale for [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598). A larger value usually results in stronger correlation between the text and the image but reduces the diversity of the generated content.
* `clip_skip`: The number of layers to skip in the [CLIP](https://arxiv.org/abs/2103.00020) text encoder. The more layers skipped, the lower the text-image correlation, but this can lead to interesting variations in the generated content.
* `input_image`: Input image, used for image-to-image generation.
* `mask_image`: Mask image, used for image inpainting.
* `denoising_strength`: The denoising strength. When set to 1, a full generation process is performed. When set to a value between 0 and 1, some information from the input image is preserved.
* `height`: Image height.
* `width`: Image width.
* `num_inference_steps`: The number of inference steps. Generally, more steps lead to longer computation time but higher image quality.
* `tiled`: Whether to enable tiled processing for the VAE. This option is disabled by default. Enabling it can reduce VRAM usage.
* `tile_size`: The window size for tiled VAE processing.
* `tile_stride`: The stride for tiled VAE processing.
* `seed`: The random seed. A fixed seed ensures reproducible results.
* `progress_bar_cmd`: The progress bar module. [`tqdm`](https://github.com/tqdm/tqdm) is enabled by default. To disable the progress bar, set it to `lambda x: x`.
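A combined sketch using several of these parameters with the same MajicFlus checkpoint as above; the specific values are illustrative, not tuned recommendations.

```python
from diffsynth_engine import fetch_model, FluxImagePipeline

model_path = fetch_model("MAILAND/majicflus_v1", path="majicflus_v134.safetensors")
pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
image = pipe(
    prompt="a cat sitting on a windowsill",
    negative_prompt="ugly, blurry",
    cfg_scale=3.5,           # illustrative value
    num_inference_steps=30,  # illustrative value
    width=1024,
    height=1024,
    seed=42,                 # fixed seed for reproducibility
)
image.save("image.png")
```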

#### Loading LoRA

We support loading a LoRA on top of the base model. For example, the following code loads a [Cheongsam LoRA](https://www.modelscope.cn/models/DonRat/MAJICFLUS_SuperChinesestyleheongsam) based on the [MajicFlus](https://www.modelscope.cn/models/MAILAND/majicflus_v1/summary?version=v1.0) model to generate images of cheongsams, which the base model might struggle to create.

```python
from diffsynth_engine import fetch_model, FluxImagePipeline

model_path = fetch_model("MAILAND/majicflus_v1", path="majicflus_v134.safetensors")
lora_path = fetch_model("DonRat/MAJICFLUS_SuperChinesestyleheongsam", path="麦橘超国风旗袍.safetensors")

pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
pipe.load_lora(path=lora_path, scale=1.0)
image = pipe(prompt="a girl, qipao")
image.save("image.png")
```

The `scale` parameter in the code controls the degree of influence the LoRA model has on the base model. A value of 1.0 is usually sufficient. When set to a value greater than 1, the LoRA's effect will be stronger, but this may cause artifacts or degradation in the image content. Please adjust this parameter with caution.

#### VRAM Optimization

DiffSynth-Engine supports various levels of VRAM optimization, allowing models to run on GPUs with low VRAM. For example, at `bfloat16` precision and with no optimization options enabled, the FLUX model requires 35.84GB of VRAM for inference. By adding the parameter `offload_mode="cpu_offload"`, the VRAM requirement drops to 22.83GB. Furthermore, using `offload_mode="sequential_cpu_offload"` reduces the requirement to just 4.30GB, although this comes with an increase in inference time.

```python
from diffsynth_engine import fetch_model, FluxImagePipeline

model_path = fetch_model("MAILAND/majicflus_v1", path="majicflus_v134.safetensors")
pipe = FluxImagePipeline.from_pretrained(model_path, offload_mode="sequential_cpu_offload")
image = pipe(prompt="a cat")
image.save("image.png")
```

### Video Generation

DiffSynth-Engine also supports video generation. The following code loads the [Wan Video Generation Model](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) and generates a video.

```python
from diffsynth_engine.pipelines.wan_video import WanVideoPipeline, WanModelConfig
from diffsynth_engine.utils.video import save_video
from diffsynth_engine import fetch_model

config = WanModelConfig(
    model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
    vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
    t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
)
pipe = WanVideoPipeline.from_pretrained(config, device="cuda")
# The prompt translates to: "A lively puppy runs quickly on a green lawn. The puppy has brownish-yellow fur,
# its two ears are perked up, and it looks focused and cheerful. Sunlight shines on it,
# making its fur look especially soft and shiny."
video = pipe(prompt="一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。")
save_video(video, "video.mp4")
```

#### Detailed Parameters

In the video generation pipeline `pipe`, we can use the following parameters for fine-grained control (a combined example follows the list):

* `prompt`: The prompt, used to describe the content of the generated video, e.g., "a cat".
* `negative_prompt`: The negative prompt, used to describe content you do not want in the video, e.g., "ugly".
* `cfg_scale`: The guidance scale for [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598). A larger value usually results in stronger correlation between the text and the video but reduces the diversity of the generated content.
* `input_image`: Input image, only effective in image-to-video models, such as [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P).
* `input_video`: Input video, used for video-to-video generation.
* `denoising_strength`: The denoising strength. When set to 1, a full generation process is performed. When set to a value between 0 and 1, some information from the input video is preserved.
* `height`: Video frame height.
* `width`: Video frame width.
* `num_frames`: Number of video frames.
* `num_inference_steps`: The number of inference steps. Generally, more steps lead to longer computation time but higher video quality.
* `tiled`: Whether to enable tiled processing for the VAE. This option is disabled by default. Enabling it can reduce VRAM usage.
* `tile_size`: The window size for tiled VAE processing.
* `tile_stride`: The stride for tiled VAE processing.
* `seed`: The random seed. A fixed seed ensures reproducible results.
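A combined sketch using several of these parameters with the same Wan2.1 pipeline as above; the resolution, frame count, and step values are illustrative, not recommendations.

```python
from diffsynth_engine.pipelines.wan_video import WanVideoPipeline, WanModelConfig
from diffsynth_engine.utils.video import save_video
from diffsynth_engine import fetch_model

config = WanModelConfig(
    model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
    vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
    t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
)
pipe = WanVideoPipeline.from_pretrained(config, device="cuda")
video = pipe(
    prompt="a puppy running on a green lawn",
    negative_prompt="ugly, distorted",
    width=832,               # illustrative resolution
    height=480,
    num_frames=81,           # illustrative frame count
    num_inference_steps=30,  # illustrative value
    seed=42,                 # fixed seed for reproducibility
)
save_video(video, "video.mp4")
```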

#### Loading LoRA

We support loading a LoRA on top of the base model. For example, the following code loads a [High-Resolution Fix LoRA](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1) on top of the [Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model to improve the generation quality at high resolutions.

```python
from diffsynth_engine.pipelines.wan_video import WanVideoPipeline, WanModelConfig
from diffsynth_engine.utils.video import save_video
from diffsynth_engine import fetch_model

config = WanModelConfig(
    model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
    vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
    t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
)
lora_path = fetch_model("DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1", path="model.safetensors")
pipe = WanVideoPipeline.from_pretrained(config, device="cuda")
pipe.load_lora(path=lora_path, scale=1.0)
# The prompt translates to: "A lively puppy runs quickly on a green lawn. The puppy has brownish-yellow fur,
# its two ears are perked up, and it looks focused and cheerful. Sunlight shines on it,
# making its fur look especially soft and shiny."
video = pipe(prompt="一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。")
save_video(video, "video.mp4")
```

The `scale` parameter in the code controls the degree of influence the LoRA model has on the base model. A value of 1.0 is usually sufficient. When set to a value greater than 1, the LoRA's effect will be stronger, but this may cause artifacts or degradation in the generated video. Please adjust this parameter with caution.

#### Multi-GPU Parallelism

We support multi-GPU parallel inference of the Wan2.1 model for faster video generation. Add the parameters `parallelism=4` (the number of GPUs to use) and `use_cfg_parallel=True` to the code to enable parallelism.

```python
from diffsynth_engine.pipelines.wan_video import WanVideoPipeline, WanModelConfig
from diffsynth_engine.utils.video import save_video
from diffsynth_engine import fetch_model

config = WanModelConfig(
    model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
    vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
    t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
)
pipe = WanVideoPipeline.from_pretrained(config, device="cuda", parallelism=4, use_cfg_parallel=True)
# The prompt translates to: "A lively puppy runs quickly on a green lawn. The puppy has brownish-yellow fur,
# its two ears are perked up, and it looks focused and cheerful. Sunlight shines on it,
# making its fur look especially soft and shiny."
video = pipe(prompt="一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。")
save_video(video, "video.mp4")
```
