
Commit 694f965

hlky and sayakpaul authored

Support IPAdapter for more Flux pipelines (#10708)

* Support IPAdapter for more Flux pipelines
* -copied from

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

1 parent 2d8a41c · commit 694f965

File tree: 8 files changed, +531 −22 lines changed

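With this change, `FluxControlNetPipeline` gains the `FluxIPAdapterMixin`, so IP-Adapter checkpoints can be loaded and used alongside ControlNet conditioning. A minimal usage sketch follows; the checkpoint IDs, file names, and image paths are illustrative assumptions, not part of this commit.

```python
# Hypothetical usage sketch of the new IP-Adapter support on
# FluxControlNetPipeline. Repo IDs, weight names, and image paths below
# are assumptions for illustration; substitute real ones.
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# FluxIPAdapterMixin provides load_ip_adapter / set_ip_adapter_scale.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",              # assumed checkpoint repo
    weight_name="ip_adapter.safetensors",    # assumed weight file
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

image = pipe(
    prompt="a photo of a cat",
    control_image=load_image("canny.png"),         # placeholder path
    ip_adapter_image=load_image("style_ref.png"),  # placeholder path
    negative_prompt="low quality",
    true_cfg_scale=4.0,  # > 1.0 enables the negative (true CFG) branch
    num_inference_steps=28,
).images[0]
```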

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
Lines changed: 0 additions & 1 deletion

@@ -438,7 +438,6 @@ def get_timesteps(self, num_inference_steps, strength, device):
 
         return timesteps, num_inference_steps - t_start
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs
     def check_inputs(
         self,
         prompt,

src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
Lines changed: 0 additions & 1 deletion

@@ -477,7 +477,6 @@ def get_timesteps(self, num_inference_steps, strength, device):
 
         return timesteps, num_inference_steps - t_start
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
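These two deletions correspond to the `-copied from` bullet in the commit message: `check_inputs` in `FluxImg2ImgPipeline`, the target of the `# Copied from` marker, presumably gains negative-prompt arguments elsewhere in this PR (in one of the files not shown in this excerpt), so the mechanical copy relationship no longer holds for these two pipelines and the marker is removed.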

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Lines changed: 171 additions & 4 deletions

@@ -18,14 +18,16 @@
 import numpy as np
 import torch
 from transformers import (
+    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTokenizer,
+    CLIPVisionModelWithProjection,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 from ...models.transformers import FluxTransformer2DModel
@@ -171,7 +173,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
     r"""
     The Flux pipeline for text-to-image generation.
@@ -198,8 +200,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
     def __init__(
@@ -214,6 +216,8 @@ def __init__(
         controlnet: Union[
             FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
         ],
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -228,6 +232,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
@@ -413,14 +419,62 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
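Since `prepare_ip_adapter_image_embeds` is `# Copied from` the base `FluxPipeline`, its contract carries over unchanged. A minimal sketch of that contract with synthetic tensors (the embedding width of 768, matching the CLIP ViT-L pooled projection, is an illustrative assumption):

```python
# Sketch (synthetic tensors) of the embeds contract enforced above:
# pass one list entry per loaded IP-Adapter, each of shape
# (batch_size, num_images, emb_dim).
import torch

batch_size, num_images_per_prompt = 1, 2
single_image_embeds = torch.randn(1, 1, 768)  # shape is illustrative

# One entry per IP-Adapter; a single adapter is assumed here.
ip_adapter_image_embeds = [single_image_embeds]

# Inside the pipeline each entry is tiled along dim 0 to
# batch_size * num_images_per_prompt rows before the denoising loop:
tiled = torch.cat([single_image_embeds] * (batch_size * num_images_per_prompt), dim=0)
print(tiled.shape)  # torch.Size([2, 1, 768])
```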
@@ -455,10 +509,33 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -597,6 +674,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
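A note on `true_cfg_scale`: FLUX.1's `guidance` input is an embedded (distilled) guidance scale, not real classifier-free guidance. Passing `true_cfg_scale > 1.0` together with a negative prompt switches on `do_true_cfg` (see the hunks below), which runs a second transformer forward pass per denoising step and mixes the two predictions, at roughly double the per-step cost.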
@@ -612,6 +692,12 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -679,6 +765,17 @@ def __call__(
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image:
+                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -727,8 +824,12 @@ def __call__(
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -752,6 +853,7 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
@@ -766,6 +868,21 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
 
         # 3. Prepare control image
         num_channels_latents = self.transformer.config.in_channels // 4
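The negative branch reuses `encode_prompt` unchanged and discards the returned `text_ids` (`_`), since those positional ids match the ones already computed for the positive prompt.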
@@ -899,12 +1016,43 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
 
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
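Note the defaulting logic at the top of this hunk: if IP-Adapter input is supplied for only one of the two branches, the other branch is filled with an all-zeros (black) placeholder image, so that under true CFG both the positive and negative transformer passes receive image embeddings of matching shape.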

@@ -960,6 +1108,25 @@ def __call__(
                     controlnet_blocks_repeat=controlnet_blocks_repeat,
                 )[0]
 
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        controlnet_block_samples=controlnet_block_samples,
+                        controlnet_single_block_samples=controlnet_single_block_samples,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        controlnet_blocks_repeat=controlnet_blocks_repeat,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
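The final added line is the standard classifier-free guidance combination, `noise_pred = neg + s * (pos - neg)`. A tiny numeric sketch (made-up values) of how `true_cfg_scale` acts:

```python
# Numeric sketch (made-up values) of the true-CFG combination above:
# true_cfg_scale == 1.0 reduces to the positive prediction; larger values
# extrapolate away from the negative-prompt prediction.
import torch

true_cfg_scale = 4.0
noise_pred = torch.tensor([1.0])      # positive (prompt-conditioned) prediction
neg_noise_pred = torch.tensor([0.2])  # negative-prompt prediction
guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
print(guided)  # tensor([3.4000])
```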
