@@ -15,7 +15,6 @@

import torch
import torch.nn as nn
-import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@

@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
-    r"""
-    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
    def __init__(
        self,
        dim: int,
@@ -59,21 +47,13 @@ def __init__(
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
-
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )

@@ -82,23 +62,15 @@ def __init__(

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
-        attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = hidden_states + attn_output
+        # 1. Attention
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
+        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output

+        # 2. Feed Forward
        norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        ff_output = self.ff(norm_hidden_states)
-        ff_output = gate_mlp.unsqueeze(1) * ff_output
-
-        hidden_states = hidden_states + ff_output
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output

        return hidden_states
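The `SD3SingleTransformerBlock` refactor above only merges the gated residual updates into single expressions; `scale_mlp[:, None]` and `scale_mlp.unsqueeze(1)` are equivalent, so the numerics are unchanged. A minimal, hedged usage sketch of the block (the import path is inferred from the relative imports at the top of this file, and the sizes are illustrative, not an SD3 configuration):

```python
import torch

# Assumed module path for the class shown in this diff.
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

dim, heads, head_dim = 64, 4, 16  # toy sizes
block = SD3SingleTransformerBlock(dim=dim, num_attention_heads=heads, attention_head_dim=head_dim)

hidden_states = torch.randn(2, 32, dim)  # (batch, sequence, channels)
temb = torch.randn(2, dim)               # conditioning embedding consumed by AdaLayerNormZero

out = block(hidden_states, temb)
print(out.shape)  # torch.Size([2, 32, 64])
```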
@@ -107,26 +79,40 @@ class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

    Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The indices of the transformer blocks that use dual-stream attention.
+        qk_norm (`Optional[str]`, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
    """

    _supports_gradient_checkpointing = True
+    _no_split_modules = ["JointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
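The `@register_to_config` decorator above is what makes the simplification in the next hunk possible: it records the `__init__` keyword arguments on `self.config` before the method body runs, so reading them back through `self.config.<name>` inside `__init__` is redundant and the local arguments can be used directly. A hedged, minimal sketch of that mechanism with a toy class (not part of diffusers), assuming `diffusers.configuration_utils` exposes `ConfigMixin` and `register_to_config` as the relative imports at the top of this file suggest:

```python
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyModel(nn.Module, ConfigMixin):
    config_name = "config.json"  # required by ConfigMixin

    @register_to_config
    def __init__(self, hidden_dim: int = 8, num_layers: int = 2):
        super().__init__()
        # The decorator has already stored hidden_dim/num_layers on self.config,
        # so the local arguments and the config hold the same values here.
        self.blocks = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers))


model = ToyModel()
print(model.config.hidden_dim, model.config.num_layers)  # 8 2
```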
@@ -149,36 +135,33 @@ def __init__(
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
            ]
        )
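Taken together, the constructor now consumes the registered kwargs directly. A hedged end-to-end sketch that instantiates the model with a deliberately tiny configuration (keyword names follow the docstring above; the values are illustrative only and much smaller than the documented SD3 defaults):

```python
from diffusers import SD3Transformer2DModel  # top-level export assumed

# Tiny illustrative configuration; the documented defaults are far larger
# (sample_size=128, num_layers=18, num_attention_heads=18, attention_head_dim=64, ...).
model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,      # inner_dim = 4 * 8 = 32
    joint_attention_dim=64,
    caption_projection_dim=32,  # kept equal to inner_dim, matching the real configs
    pooled_projection_dim=48,
    out_channels=16,
    pos_embed_max_size=96,
    dual_attention_layers=(),   # indices of blocks that use dual-stream attention
    qk_norm=None,
)
print(sum(p.numel() for p in model.parameters()))
```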