`FluxPipeline` has utilities that give us `img_ids` and `txt_ids`:
```python
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    ...
```

```python
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
```
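For context, here is a minimal sketch of what those Flux helpers compute (the standalone function names here are illustrative, not the pipeline's actual API):

```python
import torch

def prepare_latent_image_ids(height, width, device, dtype):
    # Channel 0 stays zero; channels 1 and 2 hold the row/column index of
    # each latent position, flattened to (height * width, 3).
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = ids[..., 1] + torch.arange(height)[:, None]
    ids[..., 2] = ids[..., 2] + torch.arange(width)[None, :]
    return ids.reshape(height * width, 3).to(device=device, dtype=dtype)

def prepare_text_ids(seq_len, device, dtype):
    # Text ids are all zeros, exactly as in the snippet above.
    return torch.zeros(seq_len, 3).to(device=device, dtype=dtype)
```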
As such, these are not created inside the transformer class.
Whereas in HiDream, we have something different. `txt_ids` are created inside the transformer class:
```python
txt_ids = torch.zeros(
```
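The truncated call above presumably fills in a batched, all-zero id tensor for the text tokens, something along these lines (a sketch under that assumption, with illustrative variable names, not the verbatim source):

```python
# Assumed shape: one zero row per text token, built inside forward().
txt_ids = torch.zeros(
    batch_size, text_seq_len, 3, device=img_ids.device, dtype=img_ids.dtype
)
```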
And `img_ids` are overwritten (probably intentional because it's conditioned):
https://github.com/huggingface/diffusers/blob/ce1063acfa0cbc2168a7e9dddd4282ab8013b810/src/diffusers/models/transformers/transformer_hidream_image.py#L771C13-L771C20
Then the entire computation happens inside the pipeline `__call__()` (`diffusers/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py`, lines 726 to 744 at `ce1063a`):
```python
if latents.shape[-2] != latents.shape[-1]:
    B, C, H, W = latents.shape
    pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size

    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)
    img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids

    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
    if self.do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)
else:
    img_sizes = img_ids = None
```
Maybe this could take place inside a method, similar to the `FluxPipeline`?
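For instance, a rough sketch of what such a helper could look like, lifted from the `__call__()` block above (the method name and signature are just a suggestion):

```python
@staticmethod
def _prepare_image_ids(latents, patch_size, max_seq, do_classifier_free_guidance):
    # Same computation as in __call__() above, factored into one place.
    if latents.shape[-2] == latents.shape[-1]:
        return None, None
    B, C, H, W = latents.shape
    pH, pW = H // patch_size, W // patch_size

    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)
    img_ids_pad = torch.zeros(max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids

    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
    if do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)
    return img_sizes, img_ids
```

With something like this, `__call__()` would reduce to a single `self._prepare_image_ids(...)` call, mirroring how `FluxPipeline._prepare_latent_image_ids()` is used.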
In general, this `img_ids`/`txt_ids` preparation could be standardized a bit across pipelines.
Cc: @yiyixuxu @a-r-r-o-w