Commit 6fe984b

add

committed · 1 parent 0ef2935 · commit 6fe984b
File tree: 1 file changed (+75 −89 lines)


src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 75 additions & 89 deletions
@@ -315,9 +315,9 @@ def _get_llama3_prompt_embeds(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_4: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         num_images_per_prompt: int = 1,
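
This first hunk makes the secondary prompts optional: `prompt_2` (second CLIP encoder), `prompt_3` (T5), and `prompt_4` (Llama-3) now default to `None` and, per the fallback logic visible in the next hunk (`prompt_2 = prompt_2 or prompt`, etc.), are replaced by `prompt` when omitted. A minimal before/after sketch of a call site; the `pipe` variable and prompt string are illustrative, and it assumes the remaining parameters keep the defaults shown above:

    prompt = "a photo of an astronaut riding a horse"

    # Before this commit: all four prompts had to be supplied explicitly.
    pipe.encode_prompt(prompt=prompt, prompt_2=prompt, prompt_3=prompt, prompt_4=prompt)

    # After: the secondary prompts may be omitted and fall back to `prompt`.
    pipe.encode_prompt(prompt=prompt)
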
@@ -339,118 +339,104 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
-        prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            prompt_4=prompt_4,
-            device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
-        )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-            negative_prompt_3 = negative_prompt_3 or negative_prompt
-            negative_prompt_4 = negative_prompt_4 or negative_prompt
+        device = device or self._execution_device
 
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-            negative_prompt_3 = (
-                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
-            )
-            negative_prompt_4 = (
-                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+        if pooled_prompt_embeds is None:
+            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
             )
 
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_3=negative_prompt_3,
-                prompt_4=negative_prompt_4,
-                device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                max_sequence_length=max_sequence_length,
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
             )
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-    def _encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        max_sequence_length: int = 128,
-    ):
-        device = device or self._execution_device
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
         if pooled_prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
-            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
-                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
-            )
             pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
                 self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
             )
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt_2 = negative_prompt_2 or ""
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+            negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+                self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
+            )
+
+        if pooled_prompt_embeds is None:
             pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_pooled_prompt_embeds = torch.cat(
+                [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
+            )
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         if prompt_embeds is None:
             prompt_3 = prompt_3 or prompt
             prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_3 = negative_prompt_3 or ""
+            negative_prompt_3 = (
+                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+            )
+            negative_t5_prompt_embeds = self._get_t5_prompt_embeds(
+                negative_prompt_3, max_sequence_length, device, dtype
+            )
 
+        if prompt_embeds is None:
             prompt_4 = prompt_4 or prompt
             prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
-
-            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
             llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
 
-            _, seq_len, _ = t5_prompt_embeds.shape
-            t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-            _, _, seq_len, dim = llama3_prompt_embeds.shape
-            llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
-            llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_4 = negative_prompt_4 or ""
+            negative_prompt_4 = (
+                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+            )
+            negative_llama3_prompt_embeds = self._get_llama3_prompt_embeds(
+                negative_prompt_4, max_sequence_length, device, dtype
+            )
 
+        if prompt_embeds is None:
             prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
 
-        return prompt_embeds, pooled_prompt_embeds
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_embeds = [negative_t5_prompt_embeds, negative_llama3_prompt_embeds]
+
+        _, seq_len, _ = prompt_embeds[0].shape
+        prompt_embeds[0] = prompt_embeds[0].repeat(1, num_images_per_prompt, 1)
+        prompt_embeds[0] = prompt_embeds[0].view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        _, _, seq_len, dim = prompt_embeds[1].shape
+        prompt_embeds[1] = prompt_embeds[1].repeat(1, 1, num_images_per_prompt, 1)
+        prompt_embeds[1] = prompt_embeds[1].view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+
+        if do_classifier_free_guidance:
+            _, seq_len, _ = negative_prompt_embeds[0].shape
+            negative_prompt_embeds[0] = negative_prompt_embeds[0].repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds[0] = negative_prompt_embeds[0].view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            _, _, seq_len, dim = negative_prompt_embeds[1].shape
+            negative_prompt_embeds[1] = negative_prompt_embeds[1].repeat(1, 1, num_images_per_prompt, 1)
+            negative_prompt_embeds[1] = negative_prompt_embeds[1].view(
+                -1, batch_size * num_images_per_prompt, seq_len, dim
+            )
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
     def enable_vae_slicing(self):
        r"""
