Skip to content

Commit 14a3f0c

Browse files
committed
make style and make quality
1 parent 1891d8d commit 14a3f0c

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/diffusers/models/transformers/transformer_infinitetalk.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,11 @@
2323

2424
from ...configuration_utils import ConfigMixin, register_to_config
2525
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
26-
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
26+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
2727
from ...utils.torch_utils import maybe_allow_in_graph
2828
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
2929
from ..attention_dispatch import dispatch_attention_fn
3030
from ..cache_utils import CacheMixin
31-
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
3231
from ..modeling_outputs import Transformer2DModelOutput
3332
from ..modeling_utils import ModelMixin
3433
from ..normalization import FP32LayerNorm
@@ -432,7 +431,7 @@ def __call__(
432431
_, audio_seq_len, _ = encoder_hidden_states.shape
433432
dim_head = attn.inner_dim // attn.heads
434433
dim_head_kv = attn.kv_inner_dim // attn.heads
435-
434+
436435
# For audio cross-attention, reshape such that the seq_len runs over only the spatial dims
437436
hidden_states = hidden_states.reshape(batch_size * grid_size_t, -1, hidden_dim) # [B * N_t, S, C]
438437

@@ -1056,7 +1055,7 @@ def forward(
10561055

10571056
if encoder_hidden_states_image is not None:
10581057
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
1059-
1058+
10601059
# 3. Prepare audio embedding using the audio adapter
10611060
audio_cond = encoder_hidden_states_audio.to(device=hidden_states.device, dtype=hidden_states.dtype)
10621061
audio_cond_first_frame = audio_cond[:, :1, ...]

0 commit comments

Comments
 (0)