
Commit ad4e8be ("update"), 1 parent be84828

File tree: 7 files changed, +393 / -314 lines


src/diffusers/models/attention_processor.py

Lines changed: 57 additions & 191 deletions
@@ -46,6 +46,63 @@
     XLA_AVAILABLE = False


+class AttnProcessorMixin:
+    """Mixin providing shared projection and attention helpers for attention processors."""
+
+    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using standard separate projection matrices."""
+        # Standard separate projections
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+            encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+        return query, key, value, encoder_projections
+
+    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using fused QKV projection matrices."""
+        # Fused QKV projection
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
+            encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+        return query, key, value, encoder_projections
+
+    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Public method to get projections based on whether we're using fused mode or not."""
+        if self.is_fused and hasattr(attn, "to_qkv"):
+            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
+
+        return self._get_projections(attn, hidden_states, encoder_hidden_states)
+
+    def attention_fn(self, query, key, value, scale=None, attention_mask=None):
+        """Computes the attention. Can be overridden by hardware-specific implementations."""
+        return F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
+        )
+
+
 class Attention(nn.Module, AttentionModuleMixin):
     default_processor_class = AttnProcessorSDPA
     _available_processors = []
@@ -1292,99 +1349,6 @@ def __call__(
         return hidden_states


-class AuraFlowAttnProcessorSDPA:
-    """Attention processor used typically in processing Aura Flow."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
-            raise ImportError(
-                "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        *args,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size = hidden_states.shape[0]
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # Reshape.
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim)
-        key = key.view(batch_size, -1, attn.heads, head_dim)
-        value = value.view(batch_size, -1, attn.heads, head_dim)
-
-        # Apply QK norm.
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Concatenate the projections.
-        if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            )
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            )
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
-
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
-
-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
-
-        # Attention.
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # Split the attention outputs.
-        if encoder_hidden_states is not None:
-            hidden_states, encoder_hidden_states = (
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-            )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        if encoder_hidden_states is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if encoder_hidden_states is not None:
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
 class FusedAuraFlowAttnProcessorSDPA:
     """Attention processor used typically in processing Aura Flow with fused projections."""

@@ -2335,104 +2299,6 @@ def __call__(
         return hidden_states


-class HunyuanAttnProcessorSDPA:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        from .embeddings import apply_rotary_emb
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            if not attn.is_cross_attention:
-                key = apply_rotary_emb(key, image_rotary_emb)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
 class FusedHunyuanAttnProcessorSDPA:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
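
For orientation, here is a minimal sketch of how a processor built on the new AttnProcessorMixin could look. It is not part of this commit: the SimpleAttnProcessor class, its is_fused flag, and the reshape/output steps are illustrative assumptions layered on the mixin's get_projections and attention_fn methods shown above.

# Hypothetical usage sketch (not from this commit); SimpleAttnProcessor and its
# is_fused flag are assumptions built around the mixin's methods shown above.
from diffusers.models.attention_processor import AttnProcessorMixin


class SimpleAttnProcessor(AttnProcessorMixin):
    def __init__(self, is_fused: bool = False):
        # get_projections() consults self.is_fused before trying attn.to_qkv.
        self.is_fused = is_fused

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size = hidden_states.shape[0]

        # Shared projection path: fused QKV when available, separate to_q/to_k/to_v otherwise.
        query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)

        # Reshape to (batch, heads, seq_len, head_dim) for scaled dot-product attention.
        head_dim = key.shape[-1] // attn.heads
        query, key, value = (
            t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value)
        )

        # Overridable attention kernel provided by the mixin (SDPA by default).
        hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection and dropout, mirroring the existing processors.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

Per its docstring, attention_fn is the intended override point for hardware-specific kernels, so a variant processor could swap the attention call while reusing the projection logic unchanged.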

src/diffusers/models/auto_model.py

Lines changed: 31 additions & 9 deletions
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import importlib
 import os
 from typing import Optional, Union

-from huggingface_hub.utils import validate_hf_hub_args
+from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args

 from ..configuration_utils import ConfigMixin
+from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates


 class AutoModel(ConfigMixin):
@@ -153,17 +153,39 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
             "token": token,
             "local_files_only": local_files_only,
             "revision": revision,
-            "subfolder": subfolder,
         }

-        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
-        orig_class_name = config["_class_name"]
+        library = None
+        orig_class_name = None
+        from diffusers import pipelines

-        library = importlib.import_module("diffusers")
+        # Always attempt to fetch model_index.json first
+        try:
+            cls.config_name = "model_index.json"
+            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

-        model_cls = getattr(library, orig_class_name, None)
-        if model_cls is None:
-            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
+            if subfolder is not None and subfolder in config:
+                library, orig_class_name = config[subfolder]
+
+        except (OSError, EntryNotFoundError) as e:
+            logger.debug(e)
+
+        # Unable to load from model_index.json so fallback to loading from config
+        if library is None and orig_class_name is None:
+            cls.config_name = "config.json"
+            load_config_kwargs.update({"subfolder": subfolder})
+
+            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+            orig_class_name = config["_class_name"]
+            library = "diffusers"
+
+        model_cls, _ = get_class_obj_and_candidates(
+            library_name=library,
+            class_name=orig_class_name,
+            importable_classes=ALL_IMPORTABLE_CLASSES,
+            pipelines=pipelines,
+            is_pipeline_module=hasattr(pipelines, library),
+        )

         kwargs = {**load_config_kwargs, **kwargs}
         return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
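
For context, a rough sketch of the resolution order this change gives AutoModel.from_pretrained, assuming a pipeline repository whose model_index.json maps each subfolder to a (library, class) pair, e.g. "unet": ["diffusers", "UNet2DConditionModel"]. The repo ids below are placeholders, not taken from this commit.

# Hypothetical usage (not from this commit); repo ids are placeholders.
from diffusers import AutoModel

# Pipeline repo: model_index.json is fetched first, and the "unet" entry
# (e.g. ["diffusers", "UNet2DConditionModel"]) supplies both the library and
# the class name, so components defined outside diffusers can resolve as well.
unet = AutoModel.from_pretrained("org/some-pipeline-repo", subfolder="unet")

# Bare model repo: no model_index.json, so loading falls back to config.json
# in the (optional) subfolder and assumes the class comes from diffusers.
model = AutoModel.from_pretrained("org/some-model-repo")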
