Skip to content

Commit 29ddb7f

Browse files
author
Attashe
committed
ace++ and uno framework integration
1 parent e7e874f commit 29ddb7f

File tree

7 files changed

+483
-1
lines changed

7 files changed

+483
-1
lines changed

invokeai/app/invocations/fields.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, Callable, Optional, Tuple
2+
from typing import Any, Callable, Optional, Tuple, List
33

44
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
55
from pydantic.fields import _Unset
@@ -280,6 +280,15 @@ class FluxReduxConditioningField(BaseModel):
280280
)
281281

282282

283+
class FluxUnoReferenceField(BaseModel):
    """A FLUX UNO reference-image list primitive value."""

    # Defaults to an empty list (not None) so the value always matches the declared
    # list type and consumers can iterate or call len() without a None check.
    image_names: List[str] = Field(
        default_factory=list,
        description="The names of the reference images used for FLUX UNO conditioning. The names are used to "
        "look the images up in the context.",
    )
291+
283292
class FluxFillConditioningField(BaseModel):
284293
"""A FLUX Fill conditioning field."""
285294

invokeai/app/invocations/flux_denoise.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torchvision.transforms as tv_transforms
99
from PIL import Image
10+
import torchvision.transforms.functional as TVF
1011
from torchvision.transforms.functional import resize as tv_resize
1112
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
1213

@@ -17,6 +18,7 @@
1718
FluxConditioningField,
1819
FluxFillConditioningField,
1920
FluxReduxConditioningField,
21+
FluxUnoReferenceField,
2022
ImageField,
2123
Input,
2224
InputField,
@@ -27,6 +29,7 @@
2729
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
2830
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
2931
from invokeai.app.invocations.ip_adapter import IPAdapterField
32+
from invokeai.app.invocations.flux_uno import preprocess_ref
3033
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
3134
from invokeai.app.invocations.primitives import LatentsOutput
3235
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -42,6 +45,7 @@
4245
from invokeai.backend.flux.sampling_utils import (
4346
clip_timestep_schedule_fractional,
4447
generate_img_ids,
48+
prepare_multi_ip,
4549
get_noise,
4650
get_schedule,
4751
pack,
@@ -109,6 +113,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
109113
description="FLUX Redux conditioning tensor.",
110114
input=Input.Connection,
111115
)
116+
# Optional list of reference images for FLUX UNO; when set, the references are
# VAE-encoded and passed to the denoiser alongside the main latents.
uno_reference: FluxUnoReferenceField | None = InputField(
    default=None,
    description="FLUX UNO reference images.",
    input=Input.Connection,
)
112121
fill_conditioning: FluxFillConditioningField | None = InputField(
113122
default=None,
114123
description="FLUX Fill conditioning.",
@@ -284,6 +293,15 @@ def _run_diffusion(
284293

285294
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
286295

296+
is_flux_uno = self.uno_reference is not None
297+
if is_flux_uno:
298+
# Encode reference images and prepare position ids
299+
uno_ref_imgs = self._prep_uno_reference_imgs(context)
300+
uno_ref_imgs, uno_ref_ids = prepare_multi_ip(x, uno_ref_imgs)
301+
else:
302+
uno_ref_imgs = None
303+
uno_ref_ids = None
304+
287305
# Pack all latent tensors.
288306
init_latents = pack(init_latents) if init_latents is not None else None
289307
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
@@ -391,6 +409,8 @@ def _run_diffusion(
391409
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
392410
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
393411
img_cond=img_cond,
412+
uno_ref_imgs=uno_ref_imgs,
413+
uno_ref_ids=uno_ref_ids,
394414
)
395415

396416
x = unpack(x.float(), self.height, self.width)
@@ -657,6 +677,30 @@ def _prep_controlnet_extensions(
657677
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
658678

659679
return controlnet_extensions
680+
681+
def _prep_uno_reference_imgs(self, context: InvocationContext) -> list[torch.Tensor]:
    """Encode the UNO reference images into VAE latents.

    Each reference image is loaded by name in RGB mode, resized and center-cropped by
    ``preprocess_ref``, rescaled to the [-1, 1] range and encoded with the VAE referenced by
    ``self.controlnet_vae``.

    Args:
        context: Invocation context used to load the VAE model and the reference images.

    Returns:
        One latent tensor per reference image, in the order of ``self.uno_reference.image_names``.

    Raises:
        ValueError: If ``self.controlnet_vae`` or ``self.uno_reference`` is not set.
    """
    # Explicit raises instead of asserts: asserts are stripped when Python runs with -O.
    if self.controlnet_vae is None:
        raise ValueError("Controlnet Vae must be set for UNO encoding")
    if self.uno_reference is None:
        raise ValueError("Needs reference images for UNO")

    vae_info = context.models.load(self.controlnet_vae.vae)

    # Guard against a missing name list so len() and iteration below cannot fail on None.
    ref_img_names: list[str] = self.uno_reference.image_names or []

    # TODO: Maybe move reference size to UNO Node
    # A single reference is resized to a longer side than each of several references.
    ref_long_side = 512 if len(ref_img_names) <= 1 else 320

    ref_latents: list[torch.Tensor] = []
    for img_name in ref_img_names:
        # Load as RGB so resizing operates on a plain 3-channel image.
        image_pil = context.images.get_pil(img_name, mode="RGB")
        image_pil = preprocess_ref(image_pil, ref_long_side)  # resize and crop

        # to_tensor yields values in [0, 1]; rescale to [-1, 1] and add a batch dimension.
        image_tensor = (TVF.to_tensor(image_pil) * 2.0 - 1.0).unsqueeze(0).float()
        ref_latent = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
        ref_latents.append(ref_latent)

    return ref_latents
660704

661705
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
662706
if self.control_lora is None:
@@ -714,6 +758,7 @@ def _prep_flux_fill_img_cond(
714758
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
715759
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
716760
cond_img = np.array(cond_img)
761+
717762
cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
718763
cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
719764
cond_img = cond_img.to(device=device, dtype=dtype)

invokeai/app/invocations/flux_uno.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Literal, Optional
2+
3+
from PIL import Image
4+
import torchvision.transforms.functional as TVF
5+
6+
from invokeai.app.invocations.baseinvocation import (
7+
BaseModel,
8+
BaseInvocation,
9+
BaseInvocationOutput,
10+
Classification,
11+
invocation,
12+
invocation_output,
13+
)
14+
from invokeai.app.invocations.fields import (
15+
InputField,
16+
OutputField,
17+
FluxUnoReferenceField
18+
)
19+
from invokeai.app.invocations.primitives import ImageField
20+
from invokeai.app.services.shared.invocation_context import InvocationContext
21+
22+
23+
def preprocess_ref(raw_image: Image.Image, long_size: int = 512) -> Image.Image:
    """Resize a reference image and center-crop it to multiples of 16 pixels.

    The image is scaled proportionally so that its longer side equals ``long_size``, then
    center-cropped so that both sides are multiples of 16, and finally converted to RGB.

    Adapted from https://github.com/bytedance/UNO/blob/main/uno/flux/pipeline.py

    Args:
        raw_image: The reference image to preprocess.
        long_size: Length in pixels of the longer side after resizing. Must be at least 16.

    Returns:
        The resized and cropped image in RGB mode.

    Raises:
        ValueError: If ``long_size`` is smaller than 16, which would leave nothing after the crop.
    """
    if long_size < 16:
        raise ValueError(f"long_size must be at least 16, got {long_size}")

    image_w, image_h = raw_image.size

    # Scale so the longer side becomes long_size. The shorter side is clamped to 16px so
    # that extreme aspect ratios cannot truncate it to 0 (resize would fail) or to fewer
    # than 16 pixels (the crop below would then produce an empty image).
    if image_w >= image_h:
        new_w = long_size
        new_h = max(16, int((long_size / image_w) * image_h))
    else:
        new_h = long_size
        new_w = max(16, int((long_size / image_h) * image_w))

    raw_image = raw_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)

    # Largest multiples of 16 that fit inside the resized image.
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16

    # Center the crop window inside the resized image.
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    raw_image = raw_image.crop((left, top, left + target_w, top + target_h))

    return raw_image.convert("RGB")
55+
56+
57+
@invocation_output("flux_uno_output")
class FluxUnoOutput(BaseInvocationOutput):
    """The output of a FLUX UNO invocation: a container of reference images."""

    uno_refs: FluxUnoReferenceField = OutputField(
        title="Reference images",
        description="Reference images container",
    )
64+
65+
# TODO(attashe): adjust tags and category
66+
@invocation(
    "flux_uno",
    title="FLUX UNO",
    tags=["ip_adapter", "control"],
    category="ip_adapter",
    version="2.1.0",
    classification=Classification.Beta,
)
class FluxReduxInvocation(BaseInvocation):
    """Collects up to four reference images for FLUX UNO conditioning."""

    # NOTE(review): the class name appears to be carried over from a FLUX Redux node; if a
    # class with the same name is registered elsewhere, consider renaming - confirm.

    image: ImageField = InputField(description="The 1st FLUX UNO reference image.")
    image2: Optional[ImageField] = InputField(default=None, description="2nd reference")
    image3: Optional[ImageField] = InputField(default=None, description="3rd reference")
    image4: Optional[ImageField] = InputField(default=None, description="4th reference")

    def invoke(self, context: InvocationContext) -> FluxUnoOutput:
        """Save a copy of every provided reference image and return the names of the copies.

        Args:
            context: Invocation context used to load and save the images.

        Returns:
            A ``FluxUnoOutput`` whose ``uno_refs.image_names`` lists the saved copies in input order.
        """
        image_names: list[str] = []
        for reference in (self.image, self.image2, self.image3, self.image4):
            if reference is None:
                # Unconnected optional inputs are skipped.
                continue
            # NOTE(review): each image is loaded and re-saved under a new name rather than
            # reusing the existing name - confirm that the extra copy is intended.
            image_pil = context.images.get_pil(reference.image_name)
            image_names.append(context.images.save(image=image_pil).image_name)

        return FluxUnoOutput(uno_refs=FluxUnoReferenceField(image_names=image_names))

0 commit comments

Comments
 (0)