Skip to content

Commit 9f2d5c9

Browse files
authored
Flux with Remote Encode (#11091)
* Flux img2img remote encode
* Flux inpaint remote encode
* Remove `# Copied from` comments (the modified methods no longer match their source)
1 parent dc62e69 commit 9f2d5c9

File tree

6 files changed: +21 additions, −12 deletions

6 files changed: +21 additions, −12 deletions

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
533533

534534
return latents
535535

536-
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
537536
def prepare_latents(
538537
self,
539538
image,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
533533

534534
return latents
535535

536-
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
537536
def prepare_latents(
538537
self,
539538
image,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
561561

562562
return latents
563563

564-
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
565564
def prepare_latents(
566565
self,
567566
image,
@@ -614,7 +613,6 @@ def prepare_latents(
614613
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
615614
return latents, noise, image_latents, latent_image_ids
616615

617-
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
618616
def prepare_mask_latents(
619617
self,
620618
mask,

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,10 @@ def __init__(
225225
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
226226
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
227227
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
228-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
228+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
229+
self.image_processor = VaeImageProcessor(
230+
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
231+
)
229232
self.tokenizer_max_length = (
230233
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
231234
)
@@ -634,7 +637,10 @@ def prepare_latents(
634637
return latents.to(device=device, dtype=dtype), latent_image_ids
635638

636639
image = image.to(device=device, dtype=dtype)
637-
image_latents = self._encode_vae_image(image=image, generator=generator)
640+
if image.shape[1] != self.latent_channels:
641+
image_latents = self._encode_vae_image(image=image, generator=generator)
642+
else:
643+
image_latents = image
638644
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
639645
# expand init_latents for batch_size
640646
additional_image_per_prompt = batch_size // image_latents.shape[0]

src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,13 @@ def __init__(
222222
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
223223
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
224224
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
225-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
226-
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
225+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
226+
self.image_processor = VaeImageProcessor(
227+
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
228+
)
227229
self.mask_processor = VaeImageProcessor(
228230
vae_scale_factor=self.vae_scale_factor * 2,
229-
vae_latent_channels=latent_channels,
231+
vae_latent_channels=self.latent_channels,
230232
do_normalize=False,
231233
do_binarize=True,
232234
do_convert_grayscale=True,
@@ -653,7 +655,10 @@ def prepare_latents(
653655
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
654656

655657
image = image.to(device=device, dtype=dtype)
656-
image_latents = self._encode_vae_image(image=image, generator=generator)
658+
if image.shape[1] != self.latent_channels:
659+
image_latents = self._encode_vae_image(image=image, generator=generator)
660+
else:
661+
image_latents = image
657662

658663
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
659664
# expand init_latents for batch_size
@@ -710,7 +715,9 @@ def prepare_mask_latents(
710715
else:
711716
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
712717

713-
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
718+
masked_image_latents = (
719+
masked_image_latents - self.vae.config.shift_factor
720+
) * self.vae.config.scaling_factor
714721

715722
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
716723
if mask.shape[0] < batch_size:

src/diffusers/utils/remote_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def prepare_encode(
367367
if shift_factor is not None:
368368
parameters["shift_factor"] = shift_factor
369369
if isinstance(image, torch.Tensor):
370-
data = safetensors.torch._tobytes(image, "tensor")
370+
data = safetensors.torch._tobytes(image.contiguous(), "tensor")
371371
parameters["shape"] = list(image.shape)
372372
parameters["dtype"] = str(image.dtype).split(".")[-1]
373373
else:

Commit comments: 0