@@ -23,12 +23,11 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -432,7 +431,7 @@ def __call__( |
         _, audio_seq_len, _ = encoder_hidden_states.shape
         dim_head = attn.inner_dim // attn.heads
         dim_head_kv = attn.kv_inner_dim // attn.heads
-        
+
         # For audio cross-attention, reshape such that the seq_len runs over only the spatial dims
         hidden_states = hidden_states.reshape(batch_size * grid_size_t, -1, hidden_dim)  # [B * N_t, S, C]
 
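The reshape in the hunk above folds the temporal grid into the batch dimension so that each video frame cross-attends to its audio tokens independently, over spatial positions only. A minimal sketch of the pattern, with illustrative sizes (none of these values come from the model):

import torch

# Illustrative sizes; these are assumptions, not values used by the model
batch_size, grid_size_t, spatial_seq, hidden_dim = 2, 4, 16, 64

# Video tokens arrive flattened over frames and spatial positions: [B, N_t * S, C]
hidden_states = torch.randn(batch_size, grid_size_t * spatial_seq, hidden_dim)

# Fold the frame axis into the batch so attention sees only spatial tokens:
# [B, N_t * S, C] -> [B * N_t, S, C]
per_frame = hidden_states.reshape(batch_size * grid_size_t, -1, hidden_dim)
assert per_frame.shape == (batch_size * grid_size_t, spatial_seq, hidden_dim)

# ... per-frame cross-attention against the audio keys/values would run here ...

# Undo the fold to recover the original token layout
restored = per_frame.reshape(batch_size, grid_size_t * spatial_seq, hidden_dim)
assert torch.equal(restored, hidden_states)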
@@ -1056,7 +1055,7 @@ def forward( |
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
-        
+
         # 3. Prepare audio embedding using the audio adapter
         audio_cond = encoder_hidden_states_audio.to(device=hidden_states.device, dtype=hidden_states.dtype)
         audio_cond_first_frame = audio_cond[:, :1, ...]
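For context around this hunk: the image embeddings, when present, are prepended to the text embeddings along the sequence axis, and the audio conditioning is moved to the video tokens' device and dtype before its first frame is sliced off. A standalone sketch with assumed shapes (batch, token counts, and channel sizes are illustrative only):

import torch

# Assumed shapes for illustration; the real model derives these from its config
hidden_states = torch.randn(2, 128, 64)                  # video tokens [B, S, C]
encoder_hidden_states = torch.randn(2, 77, 64)           # text tokens  [B, S_txt, C]
encoder_hidden_states_image = torch.randn(2, 16, 64)     # image tokens [B, S_img, C]
encoder_hidden_states_audio = torch.randn(2, 8, 32, 64)  # audio cond   [B, T, S_a, C]

# Image tokens go in front of the text tokens along the sequence dimension
if encoder_hidden_states_image is not None:
    encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
assert encoder_hidden_states.shape == (2, 16 + 77, 64)

# Match the audio conditioning to the video tokens' device/dtype,
# then split off the first frame's conditioning
audio_cond = encoder_hidden_states_audio.to(device=hidden_states.device, dtype=hidden_states.dtype)
audio_cond_first_frame = audio_cond[:, :1, ...]  # [B, 1, S_a, C]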