 XLA_AVAILABLE = False
 
 
+class AttnProcessorMixin:
+    """Mixin providing shared projection helpers and a default attention computation for attention processors."""
+
+    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using standard separate projection matrices."""
+        # Standard separate projections
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+            encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+        return query, key, value, encoder_projections
+
+    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Get projections using fused QKV projection matrices."""
+        # Fused QKV projection
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # Handle encoder projections if present
+        encoder_projections = None
+        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
+            encoder_projections = (encoder_query, encoder_key, encoder_value)
+
+        return query, key, value, encoder_projections
+
+    def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
+        """Public method to get projections based on whether we're using fused mode or not."""
+        if self.is_fused and hasattr(attn, "to_qkv"):
+            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
+
+        return self._get_projections(attn, hidden_states, encoder_hidden_states)
+
+    def attention_fn(self, query, key, value, scale=None, attention_mask=None):
+        """Computes the attention. Can be overridden by hardware-specific implementations."""
+        return F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=scale, dropout_p=0.0, is_causal=False
+        )
+
+
 class Attention(nn.Module, AttentionModuleMixin):
     default_processor_class = AttnProcessorSDPA
     _available_processors = []
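The new `AttnProcessorMixin` centralizes the projection logic (separate or fused QKV) and the SDPA call so that concrete processors only handle reshaping and output projection. A minimal sketch of a processor built on it is shown below; `ExampleAttnProcessor` is hypothetical (not part of this commit), lives in the same module as the mixin, and assumes an `attn` module exposing `heads` and `to_out` as the existing processors in this file do. Note that `get_projections` reads `self.is_fused`, so a concrete processor is expected to define that attribute itself.

```python
class ExampleAttnProcessor(AttnProcessorMixin):
    """Hypothetical processor sketch built on AttnProcessorMixin (illustration only)."""

    def __init__(self, is_fused: bool = False):
        # get_projections() dispatches on this flag (and on hasattr(attn, "to_qkv")).
        self.is_fused = is_fused

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size = hidden_states.shape[0]

        # Shared projection logic from the mixin (separate or fused QKV).
        query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)

        # Reshape to (batch, heads, seq_len, head_dim) for scaled_dot_product_attention.
        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Shared attention call; hardware-specific subclasses can override attention_fn().
        hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection and dropout, mirroring the existing processors in this file.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
```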
@@ -1292,99 +1349,6 @@ def __call__(
         return hidden_states
 
 
-class AuraFlowAttnProcessorSDPA:
-    """Attention processor used typically in processing Aura Flow."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
-            raise ImportError(
-                "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        *args,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size = hidden_states.shape[0]
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # Reshape.
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim)
-        key = key.view(batch_size, -1, attn.heads, head_dim)
-        value = value.view(batch_size, -1, attn.heads, head_dim)
-
-        # Apply QK norm.
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Concatenate the projections.
-        if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            )
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            )
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
-
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
-
-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
-
-        # Attention.
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # Split the attention outputs.
-        if encoder_hidden_states is not None:
-            hidden_states, encoder_hidden_states = (
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-            )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        if encoder_hidden_states is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if encoder_hidden_states is not None:
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
 class FusedAuraFlowAttnProcessorSDPA:
     """Attention processor used typically in processing Aura Flow with fused projections."""
 
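The removed `AuraFlowAttnProcessorSDPA` prepended the `add_q_proj`/`add_k_proj`/`add_v_proj` context projections to the sample projections and ran a single joint SDPA call. A rough sketch of how that concatenation could be expressed on top of the mixin's `encoder_projections` tuple follows; it is illustrative only, and one behavioral difference is worth noting: when `encoder_hidden_states` is passed, the mixin derives `key`/`value` from `encoder_hidden_states`, whereas the removed processor computed them from `hidden_states`, so this fragment is not a drop-in equivalent.

```python
# Illustrative fragment assumed to run inside a mixin-based processor's __call__,
# with attn, batch_size, hidden_states and encoder_hidden_states already in scope.
query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states)

head_dim = key.shape[-1] // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

if encoder_projections is not None:
    # Reshape the context projections and prepend them along the sequence dimension,
    # following the concatenation in the removed AuraFlow processor.
    encoder_query, encoder_key, encoder_value = encoder_projections
    encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim)
    encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim)
    encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim)
    query = torch.cat([encoder_query, query], dim=1)
    key = torch.cat([encoder_key, key], dim=1)
    value = torch.cat([encoder_value, value], dim=1)

# Move heads to dim 1 and run the shared attention call with the module's scale.
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
hidden_states = self.attention_fn(query, key, value, scale=attn.scale)
```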
@@ -2335,104 +2299,6 @@ def __call__(
         return hidden_states
 
 
-class HunyuanAttnProcessorSDPA:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        from .embeddings import apply_rotary_emb
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            if not attn.is_cross_attention:
-                key = apply_rotary_emb(key, image_rotary_emb)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
 class FusedHunyuanAttnProcessorSDPA:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
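Model-specific steps such as HunyuanDiT's QK normalization and rotary embeddings sit between the mixin's projection step and its attention call. A hedged sketch of that middle section follows, mirroring the removed `HunyuanAttnProcessorSDPA`; it assumes `query`/`key`/`value` are already in `(batch, heads, seq_len, head_dim)` layout, that `attn`, `attention_mask`, and `image_rotary_emb` are in scope, and it reuses the existing `apply_rotary_emb` helper from `.embeddings`.

```python
from .embeddings import apply_rotary_emb  # same helper the removed processor imported

# QK normalization, as in the removed Hunyuan processor.
if attn.norm_q is not None:
    query = attn.norm_q(query)
if attn.norm_k is not None:
    key = attn.norm_k(key)

# Apply RoPE to the query, and to the key only for self-attention.
if image_rotary_emb is not None:
    query = apply_rotary_emb(query, image_rotary_emb)
    if not attn.is_cross_attention:
        key = apply_rotary_emb(key, image_rotary_emb)

# Hand off to the mixin's shared attention computation.
hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
```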