Skip to content

Commit f36ba9f

Browse files
authored
[modular diffusers] Wan (#11913)
* update
1 parent 1c50a5f commit f36ba9f

File tree

13 files changed

+1333
-3
lines changed

13 files changed

+1333
-3
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,8 @@
366366
[
367367
"StableDiffusionXLAutoBlocks",
368368
"StableDiffusionXLModularPipeline",
369+
"WanAutoBlocks",
370+
"WanModularPipeline",
369371
]
370372
)
371373
_import_structure["pipelines"].extend(
@@ -999,6 +1001,8 @@
9991001
from .modular_pipelines import (
10001002
StableDiffusionXLAutoBlocks,
10011003
StableDiffusionXLModularPipeline,
1004+
WanAutoBlocks,
1005+
WanModularPipeline,
10021006
)
10031007
from .pipelines import (
10041008
AllegroPipeline,

src/diffusers/hooks/_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _register(cls):
107107
def _register_attention_processors_metadata():
108108
from ..models.attention_processor import AttnProcessor2_0
109109
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
110+
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
110111

111112
# AttnProcessor2_0
112113
AttentionProcessorRegistry.register(
@@ -124,6 +125,14 @@ def _register_attention_processors_metadata():
124125
),
125126
)
126127

128+
# WanAttnProcessor2_0
129+
AttentionProcessorRegistry.register(
130+
model_class=WanAttnProcessor2_0,
131+
metadata=AttentionProcessorMetadata(
132+
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
133+
),
134+
)
135+
127136

128137
def _register_transformer_blocks_metadata():
129138
from ..models.attention import BasicTransformerBlock
@@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
261270

262271
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
263272
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
273+
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
264274
# fmt: on

src/diffusers/hooks/layer_skip.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,19 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
9191
if kwargs is None:
9292
kwargs = {}
9393
if func is torch.nn.functional.scaled_dot_product_attention:
94+
query = kwargs.get("query", None)
95+
key = kwargs.get("key", None)
9496
value = kwargs.get("value", None)
95-
if value is None:
96-
value = args[2]
97-
return value
97+
query = query if query is not None else args[0]
98+
key = key if key is not None else args[1]
99+
value = value if value is not None else args[2]
100+
# If the Q sequence length does not match KV sequence length, methods like
101+
# Perturbed Attention Guidance cannot be used (because the caller expects
102+
# the same sequence length as Q, but if we return V here, it will not match).
103+
# When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
104+
# the overall effect would then be that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
105+
if query.shape[2] == value.shape[2]:
106+
return value
98107
return func(*args, **kwargs)
99108

100109

src/diffusers/modular_pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"InsertableDict",
4141
]
4242
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
43+
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
4344
_import_structure["components_manager"] = ["ComponentsManager"]
4445

4546
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -71,6 +72,7 @@
7172
StableDiffusionXLAutoBlocks,
7273
StableDiffusionXLModularPipeline,
7374
)
75+
from .wan import WanAutoBlocks, WanModularPipeline
7476
else:
7577
import sys
7678

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@
6060
MODULAR_PIPELINE_MAPPING = OrderedDict(
6161
[
6262
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
63+
("wan", "WanModularPipeline"),
6364
]
6465
)
6566

6667
MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
6768
[
6869
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
70+
("WanModularPipeline", "WanAutoBlocks"),
6971
]
7072
)
7173

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
"""Lazy-import shim for the Wan modular pipeline subpackage.

Mirrors the diffusers-wide pattern: public names are registered in
``_import_structure`` and only materialized on first attribute access via
``_LazyModule``; when torch/transformers are missing, dummy placeholder
objects are exposed instead so that importing the package never hard-fails.
"""

from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # Required backends are missing: expose dummies that raise a helpful
    # error on use instead of failing at import time.
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["encoders"] = ["WanTextEncoderStep"]
    # NOTE: each public name is listed exactly once; a duplicate entry here
    # would be registered twice with _LazyModule.
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "TEXT2VIDEO_BLOCKS",
        "WanAutoBeforeDenoiseStep",
        "WanAutoBlocks",
        "WanAutoDecodeStep",
        "WanAutoDenoiseStep",
    ]
    _import_structure["modular_pipeline"] = ["WanModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: static type checkers (and DIFFUSERS_SLOW_IMPORT=1) see the
    # real symbols; must stay in sync with _import_structure above.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .encoders import WanTextEncoderStep
        from .modular_blocks import (
            ALL_BLOCKS,
            AUTO_BLOCKS,
            TEXT2VIDEO_BLOCKS,
            WanAutoBeforeDenoiseStep,
            WanAutoBlocks,
            WanAutoDecodeStep,
            WanAutoDenoiseStep,
        )
        from .modular_pipeline import WanModularPipeline
else:
    import sys

    # Replace this module object with a _LazyModule that resolves attributes
    # from _import_structure on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach dummy placeholders (if any) so missing-backend names still exist.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

0 commit comments

Comments
 (0)