
Commit 66394bf

Chroma Follow Up (#11725)
* update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * updte * update * update * update
1 parent 62cce30 commit 66394bf

9 files changed: +1356 additions, -29 deletions


src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -353,6 +353,7 @@
         "AuraFlowPipeline",
         "BlipDiffusionControlNetPipeline",
         "BlipDiffusionPipeline",
+        "ChromaImg2ImgPipeline",
         "ChromaPipeline",
         "CLIPImageProjection",
         "CogVideoXFunControlPipeline",
@@ -945,6 +946,7 @@
         AudioLDM2UNet2DConditionModel,
         AudioLDMPipeline,
         AuraFlowPipeline,
+        ChromaImg2ImgPipeline,
         ChromaPipeline,
         CLIPImageProjection,
         CogVideoXFunControlPipeline,
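These two hunks register the new pipeline in both the lazy _import_structure map and the eager TYPE_CHECKING import path, so it resolves from the package root. A minimal sanity check, assuming a diffusers build that includes this commit:

# Both export paths above make this top-level import work.
from diffusers import ChromaImg2ImgPipeline, ChromaPipeline

# The class is defined in the new img2img module added by this commit.
print(ChromaImg2ImgPipeline.__module__)
# diffusers.pipelines.chroma.pipeline_chroma_img2img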

src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 2 deletions
@@ -2543,7 +2543,9 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2776,7 +2778,9 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
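The only change here is forwarding attention_mask into F.scaled_dot_product_attention; attn_mask=None reproduces the previous unmasked behavior. A minimal sketch of how the masked call behaves, with illustrative shapes not taken from the diff:

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# A boolean mask broadcastable to (batch, heads, q_len, kv_len);
# True keeps a position, False removes it from the softmax.
attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.bool)
attention_mask[..., -2:] = False  # e.g. drop two padded key positions

out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Passing attn_mask=None instead gives the old, unmasked result.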

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 15 additions & 5 deletions
@@ -250,15 +250,21 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
         joint_attention_kwargs = joint_attention_kwargs or {}
+
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

@@ -312,6 +318,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@ def forward(
             encoder_hidden_states, emb=temb_txt
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

@@ -570,6 +581,7 @@ def forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@ def forward(
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                 )

             else:
@@ -672,6 +680,7 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

@@ -704,6 +713,7 @@ def forward(
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
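Each block expands a per-token mask into a pairwise one before handing it to the attention processor: the outer product of a (batch, seq_len) mask with itself yields a (batch, 1, seq_len, seq_len) tensor that keeps position (q, k) only when both the query token q and the key token k are valid. A small sketch of that expansion, with illustrative values:

import torch

# Per-token padding mask: 1 = real token, 0 = padding.
padding_mask = torch.tensor([[1, 1, 1, 0, 0]])  # (batch=1, seq_len=5)

# Same outer-product trick as in the blocks above.
pairwise = padding_mask[:, None, None, :] * padding_mask[:, None, :, None]

print(pairwise.shape)   # torch.Size([1, 1, 5, 5])
print(pairwise[0, 0])   # ones in the top-left 3x3 block, zeros elsewhere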

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -148,7 +148,7 @@
         "AudioLDM2UNet2DConditionModel",
     ]
     _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
-    _import_structure["chroma"] = ["ChromaPipeline"]
+    _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
     _import_structure["cogvideo"] = [
         "CogVideoXPipeline",
         "CogVideoXImageToVideoPipeline",
@@ -537,7 +537,7 @@
         )
         from .aura_flow import AuraFlowPipeline
         from .blip_diffusion import BlipDiffusionPipeline
-        from .chroma import ChromaPipeline
+        from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
         from .cogvideo import (
             CogVideoXFunControlPipeline,
             CogVideoXImageToVideoPipeline,

src/diffusers/pipelines/chroma/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_chroma"] = ["ChromaPipeline"]
+    _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +32,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_chroma import ChromaPipeline
+        from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
 else:
     import sys
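With the subpackage exports wired up, the image-to-image variant loads like any other diffusers pipeline. A hedged usage sketch: the checkpoint id is a placeholder, and the image/strength arguments follow the usual diffusers img2img convention rather than anything shown in this diff:

import torch
from diffusers import ChromaImg2ImgPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id; substitute a real Chroma checkpoint.
pipe = ChromaImg2ImgPipeline.from_pretrained(
    "your-org/chroma-checkpoint", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

init_image = load_image("input.png")
result = pipe(
    prompt="a photo of a cat wearing a spacesuit",
    image=init_image,  # starting image for img2img
    strength=0.7,      # how far the output may deviate from the input
).images[0]
result.save("output.png")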
