7 | 7 | import torch
8 | 8 | import torchvision.transforms as tv_transforms
9 | 9 | from PIL import Image
| 10 | +import torchvision.transforms.functional as TVF
10 | 11 | from torchvision.transforms.functional import resize as tv_resize
11 | 12 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
12 | 13 |
17 | 18 | FluxConditioningField,
18 | 19 | FluxFillConditioningField,
19 | 20 | FluxReduxConditioningField,
| 21 | + FluxUnoReferenceField,
20 | 22 | ImageField,
21 | 23 | Input,
22 | 24 | InputField,
27 | 29 | from invokeai.app.invocations.flux_controlnet import FluxControlNetField
28 | 30 | from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
29 | 31 | from invokeai.app.invocations.ip_adapter import IPAdapterField
| 32 | +from invokeai.app.invocations.flux_uno import preprocess_ref
30 | 33 | from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
31 | 34 | from invokeai.app.invocations.primitives import LatentsOutput
32 | 35 | from invokeai.app.services.shared.invocation_context import InvocationContext
42 | 45 | from invokeai.backend.flux.sampling_utils import (
43 | 46 | clip_timestep_schedule_fractional,
44 | 47 | generate_img_ids,
| 48 | + prepare_multi_ip,
45 | 49 | get_noise,
46 | 50 | get_schedule,
47 | 51 | pack,
@@ -109,6 +113,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
109 | 113 | description="FLUX Redux conditioning tensor.",
110 | 114 | input=Input.Connection,
111 | 115 | )
| 116 | + uno_reference: FluxUnoReferenceField | None = InputField(
| 117 | + default=None,
| 118 | + description="FLUX UNO reference images.",
| 119 | + input=Input.Connection,
| 120 | + )
112 | 121 | fill_conditioning: FluxFillConditioningField | None = InputField(
113 | 122 | default=None,
114 | 123 | description="FLUX Fill conditioning.",
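Note on the new field type: FluxUnoReferenceField is added to the fields import above, but its definition is not part of this diff. The denoise node only reads its image_names attribute (see _prep_uno_reference_imgs below), so a minimal sketch of what such a field could look like is given here; the exact definition and defaults are assumptions, not the PR's code:

    from pydantic import BaseModel, Field

    class FluxUnoReferenceField(BaseModel):
        """Hypothetical sketch: the stored image names to use as UNO references."""

        image_names: list[str] = Field(default_factory=list, description="Names of the UNO reference images.")

Any field shape that carries a list of image names would satisfy this call site.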
@@ -284,6 +293,15 @@ def _run_diffusion(
284 | 293 |
285 | 294 | img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
286 | 295 |
| 296 | + is_flux_uno = self.uno_reference is not None
| 297 | + if is_flux_uno:
| 298 | + # Encode reference images and prepare position ids
| 299 | + uno_ref_imgs = self._prep_uno_reference_imgs(context)
| 300 | + uno_ref_imgs, uno_ref_ids = prepare_multi_ip(x, uno_ref_imgs)
| 301 | + else:
| 302 | + uno_ref_imgs = None
| 303 | + uno_ref_ids = None
| 304 | +
287 | 305 | # Pack all latent tensors.
288 | 306 | init_latents = pack(init_latents) if init_latents is not None else None
289 | 307 | inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
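prepare_multi_ip is imported from sampling_utils in this diff, but its body is not shown. The call site only tells us that it takes the main noise latent x plus the list of reference latents and returns packed reference tokens together with their position ids. A rough sketch of that contract, reusing pack and generate_img_ids and giving each reference a distinct index so its tokens do not collide with the main image positions; the index-offset choice and the exact signature are assumptions, not taken from this PR:

    import torch

    from invokeai.backend.flux.sampling_utils import generate_img_ids, pack

    def prepare_multi_ip(img: torch.Tensor, ref_imgs: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Sketch only: pack each reference latent and build position ids for it."""
        # `img` mirrors the call site (the main noise latent); this sketch only
        # uses it to pick the device/dtype for the reference tokens and ids.
        ref_tokens: list[torch.Tensor] = []
        ref_ids: list[torch.Tensor] = []
        for i, ref in enumerate(ref_imgs):
            b, _, h, w = ref.shape
            ref_tokens.append(pack(ref.to(device=img.device, dtype=img.dtype)))
            ids = generate_img_ids(h=h, w=w, batch_size=b, device=img.device, dtype=img.dtype)
            ids[..., 0] = i + 1  # Index 0 stays reserved for the main image tokens.
            ref_ids.append(ids)
        return ref_tokens, ref_ids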
@@ -391,6 +409,8 @@ def _run_diffusion(
391 | 409 | pos_ip_adapter_extensions=pos_ip_adapter_extensions,
392 | 410 | neg_ip_adapter_extensions=neg_ip_adapter_extensions,
393 | 411 | img_cond=img_cond,
| 412 | + uno_ref_imgs=uno_ref_imgs,
| 413 | + uno_ref_ids=uno_ref_ids,
394 | 414 | )
395 | 415 |
396 | 416 | x = unpack(x.float(), self.height, self.width)
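How denoise consumes the two new kwargs is outside this diff. A common way such reference tokens enter a FLUX-style transformer is to concatenate them (and their position ids) onto the packed image sequence so attention runs jointly over target and references, keeping only the target tokens afterwards. Purely illustrative sketch; the helper name and wiring are assumptions, not the PR's denoise implementation:

    import torch

    def with_uno_refs(
        img: torch.Tensor,       # packed target tokens: (b, seq, dim)
        img_ids: torch.Tensor,   # target position ids:  (b, seq, 3)
        uno_ref_imgs: list[torch.Tensor] | None,
        uno_ref_ids: list[torch.Tensor] | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Hypothetical helper: append reference tokens and ids to the sequence."""
        if not uno_ref_imgs or not uno_ref_ids:
            return img, img_ids
        img = torch.cat([img, *uno_ref_imgs], dim=1)
        img_ids = torch.cat([img_ids, *uno_ref_ids], dim=1)
        return img, img_ids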
@@ -657,6 +677,30 @@ def _prep_controlnet_extensions(
657 | 677 | raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
658 | 678 |
659 | 679 | return controlnet_extensions
| 680 | +
| 681 | + def _prep_uno_reference_imgs(self, context: InvocationContext) -> list[torch.Tensor]:
| 682 | + # Load the UNO reference images, resize them, and VAE-encode them to latents.
| 683 | + assert self.controlnet_vae is not None, "controlnet_vae must be set for UNO reference encoding"
| 684 | + vae_info = context.models.load(self.controlnet_vae.vae)
| 685 | +
| 686 | + assert self.uno_reference is not None, "uno_reference must be set for UNO reference encoding"
| 687 | +
| 688 | + ref_img_names: list[str] = self.uno_reference.image_names
| 689 | + ref_latents: list[torch.Tensor] = []
| 690 | +
| 691 | + # TODO: Consider moving the reference long-side selection to the UNO node.
| 692 | + ref_long_side = 512 if len(ref_img_names) <= 1 else 320
| 693 | +
| 694 | + for img_name in ref_img_names:
| 695 | + image_pil = context.images.get_pil(img_name)
| 696 | + image_pil = image_pil.convert("RGB") # Ensure 3 channels before resizing and encoding.
| 697 | + image_pil = preprocess_ref(image_pil, ref_long_side) # Resize and crop the reference image.
| 698 | +
| 699 | + image_tensor = (TVF.to_tensor(image_pil) * 2.0 - 1.0).unsqueeze(0).float()
| 700 | + ref_latent = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
| 701 | + ref_latents.append(ref_latent)
| 702 | +
| 703 | + return ref_latents
660 | 704 |
661 | 705 | def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
662 | 706 | if self.control_lora is None:
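preprocess_ref, used in _prep_uno_reference_imgs above, comes from the new flux_uno module; the inline comment only says it resizes and crops. A plausible sketch under that assumption: scale the long side to the requested length, then center-crop both dimensions down to multiples of 16 so the VAE latents pack cleanly. The actual helper in flux_uno may differ:

    from PIL import Image

    def preprocess_ref(raw_image: Image.Image, long_side: int) -> Image.Image:
        """Sketch only: resize the long side to `long_side`, then center-crop to multiples of 16."""
        w, h = raw_image.size
        scale = long_side / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)

        target_w, target_h = new_w // 16 * 16, new_h // 16 * 16
        left = (new_w - target_w) // 2
        top = (new_h - target_h) // 2
        return image.crop((left, top, left + target_w, top + target_h))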
@@ -714,6 +758,7 @@ def _prep_flux_fill_img_cond(
714 | 758 | cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
715 | 759 | cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
716 | 760 | cond_img = np.array(cond_img)
| 761 | +
717 | 762 | cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
718 | 763 | cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
719 | 764 | cond_img = cond_img.to(device=device, dtype=dtype)