 import numpy as np
 import torch
 from transformers import (
+    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTokenizer,
+    CLIPVisionModelWithProjection,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 from ...models.transformers import FluxTransformer2DModel
@@ -171,7 +173,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
     r"""
     The Flux pipeline for text-to-image generation.
 
@@ -198,8 +200,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
     def __init__(
@@ -214,6 +216,8 @@ def __init__(
         controlnet: Union[
             FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
         ],
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -228,6 +232,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
@@ -413,14 +419,62 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -455,10 +509,33 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -597,6 +674,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
@@ -612,6 +692,12 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -679,6 +765,17 @@ def __call__(
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
+                IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional negative image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated negative image embeddings for IP-Adapter. It should be a list whose length equals the
+                number of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`.
+                If not provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -727,8 +824,12 @@ def __call__(
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -752,6 +853,7 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
@@ -766,6 +868,21 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
 
         # 3. Prepare control image
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -899,12 +1016,43 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
 
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
@@ -960,6 +1108,25 @@ def __call__(
                     controlnet_blocks_repeat=controlnet_blocks_repeat,
                 )[0]
 
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        controlnet_block_samples=controlnet_block_samples,
+                        controlnet_single_block_samples=controlnet_single_block_samples,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        controlnet_blocks_repeat=controlnet_blocks_repeat,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
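
Usage note (not part of the diff): a minimal sketch of the path this commit enables, adapted from the existing `FluxPipeline` IP-Adapter example. The checkpoint IDs (`InstantX/FLUX.1-dev-Controlnet-Canny`, `XLabs-AI/flux-ip-adapter`) and the local file names are assumptions for illustration, not something this change pins down.

```python
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoints -- any Flux ControlNet / Flux IP-Adapter pair should work.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# load_ip_adapter() comes from the newly added FluxIPAdapterMixin base class;
# it also fills in the new optional image_encoder / feature_extractor components.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

control_image = load_image("canny_edges.png")  # hypothetical local inputs
ip_image = load_image("style_reference.png")

image = pipe(
    prompt="a futuristic city at dusk",
    negative_prompt="blurry, low quality",  # with true_cfg_scale > 1, runs the new negative pass
    true_cfg_scale=3.5,
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    ip_adapter_image=ip_image,
    num_inference_steps=28,
).images[0]
image.save("flux_controlnet_ip_adapter.png")
```

With `true_cfg_scale > 1` and a negative prompt set, the loop above runs a second transformer pass per step and combines the two predictions as `neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)`, i.e. real classifier-free guidance on top of Flux's distilled guidance embedding. When only a positive `ip_adapter_image` is supplied, the pipeline substitutes a zero image on the negative side (and vice versa), so both passes see IP-Adapter conditioning.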