
Commit 315e357

Merge branch 'main' into integrations/first-block-cache-2
2 parents 1f33ca2 + df1d7b0 commit 315e357

File tree

13 files changed (+982, -20 lines)


docs/source/en/api/pipelines/wan.md

Lines changed: 44 additions & 4 deletions
@@ -133,6 +133,46 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```
 
+### Video to Video Generation
+
+```python
+import torch
+from diffusers.utils import load_video, export_to_video
+from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
+
+# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(
+    model_id, subfolder="vae", torch_dtype=torch.float32
+)
+pipe = WanVideoToVideoPipeline.from_pretrained(
+    model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
+pipe.scheduler = UniPCMultistepScheduler.from_config(
+    pipe.scheduler.config, flow_shift=flow_shift
+)
+# change to pipe.to("cuda") if you have sufficient VRAM
+pipe.enable_model_cpu_offload()
+
+prompt = "A robot standing on a mountain top. The sun is setting in the background"
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+video = load_video(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
+)
+output = pipe(
+    video=video,
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    height=480,
+    width=512,
+    guidance_scale=7.0,
+    strength=0.7,
+).frames[0]
+
+export_to_video(output, "wan-v2v.mp4", fps=16)
+```
+
 ## Memory Optimizations for Wan 2.1
 
 Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
@@ -323,7 +363,7 @@ import numpy as np
 from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
 from diffusers.hooks.group_offloading import apply_group_offloading
 from diffusers.utils import export_to_video, load_image
-from transformers import UMT5EncoderModel, CLIPVisionMode
+from transformers import UMT5EncoderModel, CLIPVisionModel
 
 model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
 image_encoder = CLIPVisionModel.from_pretrained(
@@ -356,7 +396,7 @@ prompt = (
     "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
     "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
 )
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 num_frames = 33
 
 output = pipe(
@@ -372,7 +412,7 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```
 
-### Using a Custom Scheduler
+## Using a Custom Scheduler
 
 Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
 
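Editor's note: the scheduler-swap snippet that the sentence above introduces falls outside this hunk. A minimal sketch of what such a swap can look like, assuming one picks `FlowMatchEulerDiscreteScheduler` with an illustrative `shift` value (neither is taken from this commit):

```python
import torch
from diffusers import WanPipeline, FlowMatchEulerDiscreteScheduler

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)
# replace the default UniPCMultistepScheduler, reusing its existing config
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, shift=5.0
)
```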
@@ -403,7 +443,7 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
 pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
 ```
 
-## Recommendations for Inference:
+## Recommendations for Inference
 - Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
 - `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
 - For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
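Editor's note: a minimal sketch of the first two recommendations applied together; the model id and `num_frames` value are illustrative, and the VAE is kept in `float32` while the rest of the pipeline runs in `bfloat16`:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# keep the VAE in float32 for better decoding quality
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

num_frames = 81
assert (num_frames - 1) % 4 == 0  # num_frames must satisfy this constraint
```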

docs/source/en/installation.md

Lines changed: 7 additions & 5 deletions
@@ -161,10 +161,10 @@ Your Python environment will find the `main` version of 🤗 Diffusers on the ne
 
 Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
 
-Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
+Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
 
 ```shell
-export HF_HUB_OFFLINE=True
+export HF_HUB_OFFLINE=1
 ```
 
 For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
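Editor's note: the `cache_dir` alternative mentioned in the context line above looks roughly like this sketch (model id and path are illustrative):

```python
from diffusers import DiffusionPipeline

# download into an explicit directory instead of the default HF cache location
pipe = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    cache_dir="./diffusers_cache",
)
```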
@@ -179,14 +179,16 @@ Telemetry is only sent when loading models and pipelines from the Hub,
 and it is not collected if you're loading local files.
 
 We understand that not everyone wants to share additional information,and we respect your privacy.
-You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
+You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
 
 On Linux/MacOS:
+
 ```bash
-export DISABLE_TELEMETRY=YES
+export HF_HUB_DISABLE_TELEMETRY=1
 ```
 
 On Windows:
+
 ```bash
-set DISABLE_TELEMETRY=YES
+set HF_HUB_DISABLE_TELEMETRY=1
 ```

examples/community/lpw_stable_diffusion_xl.py

Lines changed: 17 additions & 2 deletions
@@ -1773,7 +1773,7 @@ def denoising_value_valid(dnv):
                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                     f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                     " `pipeline.unet` or your `mask_image` or `image` input."
                 )
             elif num_channels_unet != 4:
@@ -1924,7 +1924,22 @@ def denoising_value_valid(dnv):
             self.upcast_vae()
             latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        # unscale/denormalize the latents
+        # denormalize with the mean and std if available and not None
+        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+        if has_latents_mean and has_latents_std:
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+            )
+            latents_std = (
+                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+            )
+            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+        else:
+            latents = latents / self.vae.config.scaling_factor
+
+        image = self.vae.decode(latents, return_dict=False)[0]
 
         # cast back to fp16 if needed
        if needs_upcasting:
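Editor's note: the added branch undoes the mean/std normalization that img2img-style pipelines typically apply after encoding when the VAE config carries `latents_mean`/`latents_std`. A small self-contained round-trip sketch with hypothetical values standing in for the config fields:

```python
import torch

# hypothetical stand-ins for vae.config.latents_mean / latents_std / scaling_factor
latents_mean = torch.full((1, 4, 1, 1), 0.5)
latents_std = torch.full((1, 4, 1, 1), 2.0)
scaling_factor = 0.13025
latents = torch.randn(1, 4, 64, 64)

# encode-side normalization (the counterpart of the decode path in this commit)
normalized = (latents - latents_mean) * scaling_factor / latents_std
# decode-side denormalization, matching the block added above
recovered = normalized * latents_std / scaling_factor + latents_mean
assert torch.allclose(recovered, latents, atol=1e-5)
```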

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -511,6 +511,7 @@
         "VQDiffusionPipeline",
         "WanImageToVideoPipeline",
         "WanPipeline",
+        "WanVideoToVideoPipeline",
         "WuerstchenCombinedPipeline",
         "WuerstchenDecoderPipeline",
         "WuerstchenPriorPipeline",
@@ -1066,6 +1067,7 @@
         VQDiffusionPipeline,
         WanImageToVideoPipeline,
         WanPipeline,
+        WanVideoToVideoPipeline,
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -356,7 +356,7 @@
         "WuerstchenDecoderPipeline",
         "WuerstchenPriorPipeline",
     ]
-    _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
+    _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()
@@ -709,7 +709,7 @@
         UniDiffuserPipeline,
         UniDiffuserTextDecoder,
     )
-    from .wan import WanImageToVideoPipeline, WanPipeline
+    from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
     from .wuerstchen import (
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

Lines changed: 7 additions & 5 deletions
@@ -487,19 +487,21 @@ def prepare_latents(
     ) -> torch.Tensor:
         height = height // self.vae_spatial_compression_ratio
         width = width // self.vae_spatial_compression_ratio
-        num_frames = (
-            (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
-        )
+        num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
 
         shape = (batch_size, num_channels_latents, num_frames, height, width)
         mask_shape = (batch_size, 1, num_frames, height, width)
 
         if latents is not None:
-            conditioning_mask = latents.new_zeros(shape)
+            conditioning_mask = latents.new_zeros(mask_shape)
             conditioning_mask[:, :, 0] = 1.0
             conditioning_mask = self._pack_latents(
                 conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-            )
+            ).squeeze(-1)
+            if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
+                raise ValueError(
+                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
+                )
             return latents.to(device=device, dtype=dtype), conditioning_mask
 
         if isinstance(generator, list):
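Editor's note: a sketch of the shape arithmetic behind the new check, assuming typical LTX defaults (spatial compression 32, temporal compression 8, transformer patch sizes 1, 128 latent channels); these values are assumptions for illustration, not taken from this diff:

```python
# hypothetical numbers to illustrate the expected packed-latents shape
batch_size, num_channels_latents = 1, 128
vae_spatial_compression_ratio, vae_temporal_compression_ratio = 32, 8

height, width, num_frames = 512, 704, 161
latent_height = height // vae_spatial_compression_ratio                 # 16
latent_width = width // vae_spatial_compression_ratio                   # 22
latent_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1  # 21

# with patch sizes of 1, packing flattens (frames, height, width) into one token axis
num_tokens = latent_frames * latent_height * latent_width               # 7392
expected_shape = (batch_size, num_tokens, num_channels_latents)
print(expected_shape)  # (1, 7392, 128) — what user-provided `latents` must match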

src/diffusers/pipelines/wan/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,7 @@
 else:
     _import_structure["pipeline_wan"] = ["WanPipeline"]
     _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
-
+    _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
@@ -35,6 +35,7 @@
     else:
         from .pipeline_wan import WanPipeline
         from .pipeline_wan_i2v import WanImageToVideoPipeline
+        from .pipeline_wan_video2video import WanVideoToVideoPipeline
 
 else:
     import sys

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 7 additions & 0 deletions
@@ -458,6 +458,13 @@ def __call__(
             callback_on_step_end_tensor_inputs,
         )
 
+        if num_frames % self.vae_scale_factor_temporal != 1:
+            logger.warning(
+                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+            )
+            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+        num_frames = max(num_frames, 1)
+
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
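Editor's note: a quick illustration of the rounding added above, assuming a temporal scale factor of 4 (the value is an assumption here, not read from this diff):

```python
vae_scale_factor_temporal = 4

for requested in (16, 17, 50, 81):
    if requested % vae_scale_factor_temporal != 1:
        rounded = requested // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
    else:
        rounded = requested
    rounded = max(rounded, 1)
    print(requested, "->", rounded)
# 16 -> 17, 17 -> 17, 50 -> 49, 81 -> 81
```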

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 7 additions & 0 deletions
@@ -559,6 +559,13 @@ def __call__(
             callback_on_step_end_tensor_inputs,
         )
 
+        if num_frames % self.vae_scale_factor_temporal != 1:
+            logger.warning(
+                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+            )
+            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+        num_frames = max(num_frames, 1)
+
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
