
Commit 201da97

Merge branch 'main' into custom-code-updates
2 parents 4423097 + f36ba9f commit 201da97

33 files changed: +1588 -112 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
 
     def __init__(
         self,
+        args,
         instance_data_root,
         instance_prompt,
         class_prompt,
@@ -980,10 +981,8 @@ def __init__(
         class_num=None,
         size=1024,
         repeats=1,
-        center_crop=False,
     ):
         self.size = size
-        self.center_crop = center_crop
 
         self.instance_prompt = instance_prompt
         self.custom_instance_prompts = None
@@ -1058,7 +1057,7 @@ def __init__(
             if interpolation is None:
                 raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
             train_resize = transforms.Resize(size, interpolation=interpolation)
-            train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+            train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)
             train_flip = transforms.RandomHorizontalFlip(p=1.0)
             train_transforms = transforms.Compose(
                 [
@@ -1075,11 +1074,11 @@ def __init__(
                     # flip
                     image = train_flip(image)
                 if args.center_crop:
-                    y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
-                    x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                    y1 = max(0, int(round((image.height - self.size) / 2.0)))
+                    x1 = max(0, int(round((image.width - self.size) / 2.0)))
                     image = train_crop(image)
                 else:
-                    y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                    y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
                     image = crop(image, y1, x1, h, w)
                 image = train_transforms(image)
                 self.pixel_values.append(image)
@@ -1102,7 +1101,7 @@ def __init__(
         self.image_transforms = transforms.Compose(
             [
                 transforms.Resize(size, interpolation=interpolation),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -1827,6 +1826,7 @@ def load_model_hook(models, input_dir):
 
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
+        args=args,
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1836,7 +1836,6 @@ def load_model_hook(models, input_dir):
        class_num=args.num_class_images,
        size=args.resolution,
        repeats=args.repeats,
-       center_crop=args.center_crop,
    )
 
    train_dataloader = torch.utils.data.DataLoader(
```
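The substantive part of this diff is the crop handling: the dataset now receives the full `args` namespace, and the crop offsets are derived from the dataset's own `size` (together with `args.center_crop`) rather than from `args.resolution`, so they stay consistent with the transforms even if the two values differ. A minimal, standalone sketch of that logic using plain torchvision/PIL (values are illustrative, not taken from the training script):

```python
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop

size = 1024          # DreamBoothDataset(size=...), i.e. self.size
center_crop = False  # args.center_crop
image = Image.new("RGB", (1280, 1024))  # stand-in for an instance image

train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
if center_crop:
    # offsets computed from the dataset's size, matching the CenterCrop output
    y1 = max(0, int(round((image.height - size) / 2.0)))
    x1 = max(0, int(round((image.width - size) / 2.0)))
    image = train_crop(image)
else:
    # sample a random crop window of the same size and apply it explicitly
    y1, x1, h, w = train_crop.get_params(image, (size, size))
    image = crop(image, y1, x1, h, w)

print(image.size)  # (1024, 1024)
```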

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -366,6 +366,8 @@
         [
             "StableDiffusionXLAutoBlocks",
             "StableDiffusionXLModularPipeline",
+            "WanAutoBlocks",
+            "WanModularPipeline",
         ]
     )
     _import_structure["pipelines"].extend(
@@ -999,6 +1001,8 @@
         from .modular_pipelines import (
             StableDiffusionXLAutoBlocks,
             StableDiffusionXLModularPipeline,
+            WanAutoBlocks,
+            WanModularPipeline,
         )
         from .pipelines import (
             AllegroPipeline,
```
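With the import structure extended, the Wan modular-pipeline entry points become importable from the package root, mirroring the existing SDXL names. A minimal sketch (assumes a diffusers build that contains this commit):

```python
from diffusers import (
    StableDiffusionXLAutoBlocks,
    StableDiffusionXLModularPipeline,
    WanAutoBlocks,
    WanModularPipeline,
)

# The classes are resolved lazily through _import_structure, so the import itself is cheap.
print(WanAutoBlocks.__name__, WanModularPipeline.__name__)
```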

src/diffusers/hooks/_helpers.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -107,6 +107,7 @@ def _register(cls):
 def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+    from ..models.transformers.transformer_wan import WanAttnProcessor2_0
 
     # AttnProcessor2_0
     AttentionProcessorRegistry.register(
@@ -124,6 +125,14 @@ def _register_attention_processors_metadata():
         ),
     )
 
+    # WanAttnProcessor2_0
+    AttentionProcessorRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     from ..models.attention import BasicTransformerBlock
@@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 
 _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # fmt: on
```
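The registration added here follows the registry's existing pattern. As a hedged sketch, a hypothetical processor could be registered the same way; these are private helpers, so the import path is an assumption based on this file's location, and `MyAttnProcessor` exists only for illustration:

```python
from diffusers.hooks._helpers import (
    AttentionProcessorMetadata,
    AttentionProcessorRegistry,
    _skip_attention___ret___hidden_states,
)


class MyAttnProcessor:  # hypothetical processor class, illustration only
    pass


# Register metadata telling the hook system how to skip this processor's output.
AttentionProcessorRegistry.register(
    model_class=MyAttnProcessor,
    metadata=AttentionProcessorMetadata(
        skip_processor_output_fn=_skip_attention___ret___hidden_states,
    ),
)
```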

src/diffusers/hooks/layer_skip.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -91,10 +91,19 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
         if func is torch.nn.functional.scaled_dot_product_attention:
+            query = kwargs.get("query", None)
+            key = kwargs.get("key", None)
             value = kwargs.get("value", None)
-            if value is None:
-                value = args[2]
-            return value
+            query = query if query is not None else args[0]
+            key = key if key is not None else args[1]
+            value = value if value is not None else args[2]
+            # If the Q sequence length does not match the KV sequence length, methods like
+            # Perturbed Attention Guidance cannot be used (because the caller expects
+            # the same sequence length as Q, but if we return V here, it will not match).
+            # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
+            # the overall effect would be that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+            if query.shape[2] == value.shape[2]:
+                return value
         return func(*args, **kwargs)
 
 
```
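The guard added here matters because the skip path returns V in place of the attention output, and that is only shape-compatible when the query and value sequence lengths agree. A standalone sketch of the check, assuming the usual `[batch, heads, seq, dim]` layout:

```python
import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 77, 64)   # Q sequence length 77
key = torch.randn(1, 8, 128, 64)    # KV sequence length 128
value = torch.randn(1, 8, 128, 64)

if query.shape[2] == value.shape[2]:
    out = value  # safe to skip attention entirely and return V
else:
    out = F.scaled_dot_product_attention(query, key, value)  # fall back to real SDPA

print(out.shape)  # torch.Size([1, 8, 77, 64]), always follows Q's sequence length
```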
src/diffusers/models/attention_dispatch.py

Lines changed: 81 additions & 18 deletions
```diff
@@ -38,26 +38,37 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
-
-if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 else:
-    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
     flash_attn_func = None
     flash_attn_varlen_func = None
 
 
-if is_flash_attn_3_available():
+if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
 
-if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
+if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
         sageattn_qk_int8_pv_fp8_cuda,
@@ -67,9 +78,6 @@
         sageattn_varlen,
     )
 else:
-    logger.warning(
-        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
-    )
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
     sageattn_qk_int8_pv_fp16_triton = None
@@ -78,39 +86,39 @@
     sageattn_varlen = None
 
 
-if is_torch_version(">=", "2.5.0"):
+if _CAN_USE_FLEX_ATTN:
     # We cannot import the flex_attention function from the package directly because it is expected (from the
     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
     # compiled function.
     import torch.nn.attention.flex_attention as flex_attention
 
 
-if is_torch_npu_available():
+if _CAN_USE_NPU_ATTN:
     from torch_npu import npu_fusion_attention
 else:
     npu_fusion_attention = None
 
 
-if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+if _CAN_USE_XLA_ATTN:
     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
 else:
     xla_flash_attention = None
 
 
-if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
+if _CAN_USE_XFORMERS_ATTN:
     import xformers.ops as xops
 else:
-    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
     xops = None
 
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-
 _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
```
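The per-backend import-time warnings are replaced by module-level capability flags and version constants, which can be inspected directly by callers. A small sketch (these are private names, so treat the import path as an assumption); the remaining hunks of this file continue below:

```python
from diffusers.models.attention_dispatch import (
    _CAN_USE_FLASH_ATTN,
    _CAN_USE_SAGE_ATTN,
    _REQUIRED_FLASH_VERSION,
    _REQUIRED_SAGE_VERSION,
)

# Instead of warning at import time, calling code can decide what to do when a backend is unavailable.
if not _CAN_USE_FLASH_ATTN:
    print(f"flash-attn>={_REQUIRED_FLASH_VERSION} not found; the flash backend cannot be selected.")
if not _CAN_USE_SAGE_ATTN:
    print(f"sageattention>={_REQUIRED_SAGE_VERSION} not found; the sage backends cannot be selected.")
```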
```diff
@@ -179,13 +187,16 @@ def list_backends(cls):
 
 
 @contextlib.contextmanager
-def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
     if backend not in _AttentionBackendRegistry._backends:
         raise ValueError(f"Backend {backend} is not registered.")
 
+    backend = AttentionBackendName(backend)
+    _check_attention_backend_requirements(backend)
+
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
 
@@ -226,9 +237,10 @@ def dispatch_attention_fn(
         "dropout_p": dropout_p,
         "is_causal": is_causal,
         "scale": scale,
-        "enable_gqa": enable_gqa,
         **attention_kwargs,
     }
+    if is_torch_version(">=", "2.5.0"):
+        kwargs["enable_gqa"] = enable_gqa
 
     if _AttentionBackendRegistry._checks_enabled:
         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
@@ -305,6 +317,57 @@ def _check_shape(
 # ===== Helper functions =====
 
 
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
+    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+        if not _CAN_USE_FLASH_ATTN:
+            raise RuntimeError(
+                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+            )
+
+    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+        if not _CAN_USE_FLASH_ATTN_3:
+            raise RuntimeError(
+                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+            )
+
+    elif backend in [
+        AttentionBackendName.SAGE,
+        AttentionBackendName.SAGE_VARLEN,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+    ]:
+        if not _CAN_USE_SAGE_ATTN:
+            raise RuntimeError(
+                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.FLEX:
+        if not _CAN_USE_FLEX_ATTN:
+            raise RuntimeError(
+                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_NPU:
+        if not _CAN_USE_NPU_ATTN:
+            raise RuntimeError(
+                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_XLA:
+        if not _CAN_USE_XLA_ATTN:
+            raise RuntimeError(
+                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.XFORMERS:
+        if not _CAN_USE_XFORMERS_ATTN:
+            raise RuntimeError(
+                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+            )
+
+
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
```
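Taken together, the `attention_backend` changes mean a plain string is now accepted and backend requirements are validated up front instead of only warning at import time. A hedged usage sketch (the `"flash"` string is assumed to be the value of `AttentionBackendName.FLASH`, and the model forward call is elided):

```python
from diffusers.models.attention_dispatch import attention_backend

try:
    # The string is coerced to AttentionBackendName and its requirements are checked,
    # raising RuntimeError if e.g. flash-attn is missing or older than the required version.
    with attention_backend("flash"):
        ...  # run the attention-heavy forward pass here
except RuntimeError as err:
    print(f"Backend unavailable: {err}")
```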

src/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -622,19 +622,21 @@ def set_attention_backend(self, backend: str) -> None:
         attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
 
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
 
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
-
         backend = AttentionBackendName(backend)
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        _check_attention_backend_requirements(backend)
 
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -651,6 +653,8 @@ def reset_attention_backend(self) -> None:
         from .attention import AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
 
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
```
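On the model side, `set_attention_backend` now performs the same requirement check and logs that the feature is experimental. A hedged sketch of the call path (the checkpoint id is a placeholder, and `"flash"` assumes that string is a registered backend value):

```python
from diffusers import UNet2DConditionModel

# Placeholder checkpoint id; substitute a real repository containing a UNet.
model = UNet2DConditionModel.from_pretrained("org/some-model", subfolder="unet")

try:
    model.set_attention_backend("flash")  # warns that the API is experimental, then checks flash-attn
except RuntimeError as err:
    print(f"Requested backend not available: {err}")

model.reset_attention_backend()  # revert to the default attention implementation
```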

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -165,7 +165,7 @@ class conditioning with `class_embed_type` equal to `None`.
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
     _skip_layerwise_casting_patterns = ["norm"]
     _repeated_blocks = ["BasicTransformerBlock"]
 
```
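`_no_split_modules` is consulted when a sharded device placement is computed for the model, so adding `"UpBlock2D"` keeps those blocks from being split across devices. A minimal check that the attribute now includes it:

```python
from diffusers import UNet2DConditionModel

# Class-level attribute; no weights need to be downloaded for this check.
print("UpBlock2D" in UNet2DConditionModel._no_split_modules)  # True with this change
```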
src/diffusers/modular_pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -40,6 +40,7 @@
         "InsertableDict",
     ]
     _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
+    _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
     _import_structure["components_manager"] = ["ComponentsManager"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -71,6 +72,7 @@
             StableDiffusionXLAutoBlocks,
             StableDiffusionXLModularPipeline,
         )
+        from .wan import WanAutoBlocks, WanModularPipeline
 else:
     import sys
 
```