Commit 4eb87e9

update
1 parent: 45eb74f

File tree: 2 files changed (+48, -65 lines)

  src/diffusers/models/attention_processor.py
  src/diffusers/models/transformers/transformer_sd3.py
src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -1408,7 +1408,7 @@ class JointAttnProcessor2_0:
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
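For context (not part of the commit): the renamed message comes from the availability guard in `JointAttnProcessor2_0.__init__`, which only trips when `torch.nn.functional.scaled_dot_product_attention` is missing, i.e. on PyTorch < 2.0. A minimal sketch of that behavior, assuming a local diffusers install:

import torch.nn.functional as F

from diffusers.models.attention_processor import JointAttnProcessor2_0

if hasattr(F, "scaled_dot_product_attention"):
    processor = JointAttnProcessor2_0()  # fine on PyTorch >= 2.0
else:
    # On older PyTorch the constructor now raises an ImportError that names
    # JointAttnProcessor2_0 instead of AttnProcessor2_0.
    print("PyTorch < 2.0: JointAttnProcessor2_0() would raise ImportError")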

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 47 additions & 64 deletions
@@ -15,7 +15,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@
 
 @maybe_allow_in_graph
 class SD3SingleTransformerBlock(nn.Module):
-    r"""
-    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -59,21 +47,13 @@ def __init__(
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
         self.attn = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
             eps=1e-6,
         )
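Passing `JointAttnProcessor2_0()` inline is safe because the processor performs the same SDPA availability check in its own `__init__` (see the attention_processor.py change above), so the removed `hasattr`/`ValueError` branch was redundant. A hedged construction sketch, with illustrative dimensions chosen so that heads × head_dim equals dim:

from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

# Illustrative sizes, not from the commit: 24 heads of 64 channels -> dim 1536.
# On PyTorch < 2.0 this now fails inside JointAttnProcessor2_0.__init__ with an
# ImportError instead of the ValueError branch removed here.
block = SD3SingleTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)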

@@ -82,23 +62,15 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
-        attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = hidden_states + attn_output
+        # 1. Attention
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
+        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output
 
+        # 2. Feed Forward
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
         ff_output = self.ff(norm_hidden_states)
-        ff_output = gate_mlp.unsqueeze(1) * ff_output
-
-        hidden_states = hidden_states + ff_output
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
 
         return hidden_states
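The modulation rewrite is behavior-preserving: indexing with `[:, None]` and calling `.unsqueeze(1)` insert the same broadcast axis over the sequence dimension. A quick standalone check with illustrative shapes:

import torch

batch, seq, dim = 2, 8, 16
norm_hidden_states = torch.randn(batch, seq, dim)
scale_mlp = torch.randn(batch, dim)
shift_mlp = torch.randn(batch, dim)

old = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
new = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
assert torch.equal(old, new)  # identical values; only the spelling changed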

@@ -107,26 +79,40 @@ class SD3Transformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
 ):
     """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
 
     Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`Optional[str]`, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["JointTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config
@@ -149,36 +135,33 @@ def __init__(
         qk_norm: Optional[str] = None,
     ):
         super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
             embed_dim=self.inner_dim,
             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
         )
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
 
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
         self.transformer_blocks = nn.ModuleList(
             [
                 JointTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
                 )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
             ]
         )
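A hedged usage sketch of the refactored constructor, assuming a diffusers install that includes this commit. The values below are scaled down from the documented defaults so the model builds quickly; `caption_projection_dim` is kept equal to `num_attention_heads * attention_head_dim`, mirroring the shipped SD3 configs:

from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=64,
    num_attention_heads=4,
    joint_attention_dim=4096,
    caption_projection_dim=256,   # = num_attention_heads * attention_head_dim
    pooled_projection_dim=2048,
    out_channels=16,
    pos_embed_max_size=96,
)

# @register_to_config still records every constructor argument, so reading the local
# arguments in __init__ (as this commit does) yields the same values as self.config.
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim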
