@@ -315,9 +315,9 @@ def _get_llama3_prompt_embeds(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_4: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         num_images_per_prompt: int = 1,
@@ -339,118 +339,104 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
-        prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            prompt_4=prompt_4,
-            device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
-        )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-            negative_prompt_3 = negative_prompt_3 or negative_prompt
-            negative_prompt_4 = negative_prompt_4 or negative_prompt
+        device = device or self._execution_device
 
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-            negative_prompt_3 = (
-                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
-            )
-            negative_prompt_4 = (
-                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+        if pooled_prompt_embeds is None:
+            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
             )
 
-        if prompt is not None and type(prompt) is not type(negative_prompt):
-            raise TypeError(
-                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                f" {type(prompt)}."
-            )
-        elif batch_size != len(negative_prompt):
-            raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
-            )
-
-        negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
-            prompt=negative_prompt,
-            prompt_2=negative_prompt_2,
-            prompt_3=negative_prompt_3,
-            prompt_4=negative_prompt_4,
-            device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
-            prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
             )
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-    def _encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        max_sequence_length: int = 128,
-    ):
-        device = device or self._execution_device
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
         if pooled_prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
-            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
-                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
-            )
             pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
                 self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
             )
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt_2 = negative_prompt_2 or ""
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+            negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+                self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
+            )
+
+        if pooled_prompt_embeds is None:
             pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_pooled_prompt_embeds = torch.cat(
+                [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
+            )
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         if prompt_embeds is None:
             prompt_3 = prompt_3 or prompt
             prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_3 = negative_prompt_3 or ""
+            negative_prompt_3 = (
+                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+            )
+            negative_t5_prompt_embeds = self._get_t5_prompt_embeds(
+                negative_prompt_3, max_sequence_length, device, dtype
+            )
 
+        if prompt_embeds is None:
             prompt_4 = prompt_4 or prompt
             prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
-
-            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
             llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
 
-            _, seq_len, _ = t5_prompt_embeds.shape
-            t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-            _, _, seq_len, dim = llama3_prompt_embeds.shape
-            llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
-            llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_4 = negative_prompt_4 or ""
+            negative_prompt_4 = (
+                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+            )
+            negative_llama3_prompt_embeds = self._get_llama3_prompt_embeds(
+                negative_prompt_4, max_sequence_length, device, dtype
+            )
 
+        if prompt_embeds is None:
             prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
 
-        return prompt_embeds, pooled_prompt_embeds
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_embeds = [negative_t5_prompt_embeds, negative_llama3_prompt_embeds]
+
+        _, seq_len, _ = prompt_embeds[0].shape
+        prompt_embeds[0] = prompt_embeds[0].repeat(1, num_images_per_prompt, 1)
+        prompt_embeds[0] = prompt_embeds[0].view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        _, _, seq_len, dim = prompt_embeds[1].shape
+        prompt_embeds[1] = prompt_embeds[1].repeat(1, 1, num_images_per_prompt, 1)
+        prompt_embeds[1] = prompt_embeds[1].view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+
+        if do_classifier_free_guidance:
+            _, seq_len, _ = negative_prompt_embeds[0].shape
+            negative_prompt_embeds[0] = negative_prompt_embeds[0].repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds[0] = negative_prompt_embeds[0].view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            _, _, seq_len, dim = negative_prompt_embeds[1].shape
+            negative_prompt_embeds[1] = negative_prompt_embeds[1].repeat(1, 1, num_images_per_prompt, 1)
+            negative_prompt_embeds[1] = negative_prompt_embeds[1].view(
+                -1, batch_size * num_images_per_prompt, seq_len, dim
+            )
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
     def enable_vae_slicing(self):
         r"""
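With this refactor, a single `encode_prompt` call produces both the positive and negative embeddings, and `prompt_2`/`prompt_3`/`prompt_4` fall back to `prompt` when omitted. A hedged usage sketch follows; `pipe` is assumed to be an already-loaded pipeline exposing this method, and only parameters visible in the diff are shown (the exact public signature, including how `do_classifier_free_guidance` is supplied, is not part of this hunk):

```python
# Hypothetical caller; `pipe` is an assumption, not defined in the diff.
(
    prompt_embeds,                  # list: [t5_prompt_embeds, llama3_prompt_embeds]
    negative_prompt_embeds,         # same structure when guidance is enabled
    pooled_prompt_embeds,           # concatenated CLIP pooled embeddings
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a cat wearing a space suit",
    negative_prompt="blurry, low quality",
    num_images_per_prompt=2,
)
```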