
Commit 7549c12

cursoragent authored and psychedelicious committed
Add FLUX Kontext conditioning support for reference images
Co-authored-by: kent <kent@invoke.ai>
Fix Kontext sequence length handling in Flux denoise invocation
Co-authored-by: kent <kent@invoke.ai>
Fix Kontext step callback to handle combined token sequences
Co-authored-by: kent <kent@invoke.ai>
fix ruff
Fix Flux Kontext
1 parent df8751b commit 7549c12

File tree

invokeai/app/invocations/fields.py
invokeai/app/invocations/flux_denoise.py
invokeai/app/invocations/flux_kontext.py
invokeai/backend/flux/extensions/kontext_extension.py

4 files changed: +207 -5 lines changed


invokeai/app/invocations/fields.py

Lines changed: 7 additions & 0 deletions
@@ -215,6 +215,7 @@ class FieldDescriptions:
     flux_redux_conditioning = "FLUX Redux conditioning tensor"
     vllm_model = "The VLLM model to use"
     flux_fill_conditioning = "FLUX Fill conditioning tensor"
+    flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"


 class ImageField(BaseModel):
@@ -291,6 +292,12 @@ class FluxFillConditioningField(BaseModel):
     mask: TensorField = Field(description="The FLUX Fill inpaint mask.")


+class FluxKontextConditioningField(BaseModel):
+    """A conditioning field for FLUX Kontext (reference image)."""
+
+    image: ImageField = Field(description="The Kontext reference image.")
+
+
 class SD3ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

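For orientation, a minimal sketch of constructing the new field (the image name is hypothetical; FluxKontextConditioningField and ImageField are the pydantic models defined above):

from invokeai.app.invocations.fields import FluxKontextConditioningField, ImageField

# "reference.png" is a hypothetical image name, used only for illustration.
kontext = FluxKontextConditioningField(image=ImageField(image_name="reference.png"))
print(kontext.image.image_name)  # -> reference.png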
invokeai/app/invocations/flux_denoise.py

Lines changed: 48 additions & 5 deletions
@@ -16,6 +16,7 @@
     FieldDescriptions,
     FluxConditioningField,
     FluxFillConditioningField,
+    FluxKontextConditioningField,
     FluxReduxConditioningField,
     ImageField,
     Input,
@@ -34,6 +35,7 @@
 from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
 from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -150,6 +152,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
     )

+    kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
+        default=None,
+        description="FLUX Kontext conditioning (reference image).",
+        input=Input.Connection,
+    )
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = self._run_diffusion(context)
@@ -376,14 +384,39 @@ def _run_diffusion(
             dtype=inference_dtype,
         )

+        # Instantiate the Kontext extension if reference conditioning is provided.
+        kontext_extension = None
+        if self.kontext_conditioning is not None:
+            # We need a VAE to encode the reference image. We can reuse the
+            # controlnet_vae field as it serves a similar purpose (image to latents).
+            if not self.controlnet_vae:
+                raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
+
+            kontext_extension = KontextExtension(
+                kontext_field=self.kontext_conditioning,
+                context=context,
+                vae_field=self.controlnet_vae,  # Pass the VAE field
+                device=TorchDevice.choose_torch_device(),
+                dtype=inference_dtype,
+            )
+
+        # The critical integration point: append the reference tokens to the main image sequence.
+        final_img, final_img_ids = x, img_ids
+        original_seq_len = x.shape[1]  # Store the original sequence length
+        if kontext_extension is not None:
+            final_img, final_img_ids = kontext_extension.apply(final_img, final_img_ids)
+
+        # The denoise function will now use the combined tensors.
         x = denoise(
             model=transformer,
-            img=x,
-            img_ids=img_ids,
+            img=final_img,  # Pass the combined image tokens
+            img_ids=final_img_ids,  # Pass the combined image IDs
             pos_regional_prompting_extension=pos_regional_prompting_extension,
             neg_regional_prompting_extension=neg_regional_prompting_extension,
             timesteps=timesteps,
-            step_callback=self._build_step_callback(context),
+            step_callback=self._build_step_callback(
+                context, original_seq_len if kontext_extension is not None else None
+            ),
             guidance=self.guidance,
             cfg_scale=cfg_scale,
             inpaint_extension=inpaint_extension,
@@ -393,6 +426,10 @@ def _run_diffusion(
             img_cond=img_cond,
         )

+        # Extract only the main image tokens if kontext was applied.
+        if kontext_extension is not None:
+            x = x[:, :original_seq_len, :]  # Keep only the first original_seq_len tokens
+
         x = unpack(x.float(), self.height, self.width)
         return x

@@ -863,9 +900,15 @@ def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatc
             yield (lora_info.model, lora.weight)
         del lora_info

-    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+    def _build_step_callback(
+        self, context: InvocationContext, original_seq_len: Optional[int] = None
+    ) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
-            state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
+            # Extract only main image tokens if Kontext conditioning was applied.
+            latents = state.latents.float()
+            if original_seq_len is not None:
+                latents = latents[:, :original_seq_len, :]
+            state.latents = unpack(latents, self.height, self.width).squeeze()
             context.util.flux_step_callback(state)

         return step_callback
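The sequence-length bookkeeping above is the core of the change: reference tokens ride along through the transformer, then get discarded before unpacking. A self-contained sketch of the pattern, with hypothetical shapes (a 512x512 image becomes a 64x64 latent, packed 2x2 into 1024 tokens of 64 channels):

import torch

x = torch.randn(1, 1024, 64)    # main image tokens (batch, seq, channels)
ref = torch.randn(1, 1024, 64)  # reference (Kontext) tokens for a same-size image

original_seq_len = x.shape[1]
combined = torch.cat([x, ref], dim=1)  # (1, 2048, 64) is what the transformer sees

# ... denoising runs on `combined` in place of `x` ...

x_out = combined[:, :original_seq_len, :]  # slice the main image tokens back out
assert x_out.shape == x.shape

The same slicing appears in the step callback, since intermediate latents carry the combined sequence too.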
invokeai/app/invocations/flux_kontext.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FluxKontextConditioningField,
+    InputField,
+    OutputField,
+)
+from invokeai.app.invocations.primitives import ImageField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+@invocation_output("flux_kontext_output")
+class FluxKontextOutput(BaseInvocationOutput):
+    """The conditioning output of a FLUX Kontext invocation."""
+
+    kontext_cond: FluxKontextConditioningField = OutputField(
+        description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
+    )
+
+
+@invocation(
+    "flux_kontext",
+    title="Kontext Conditioning - FLUX",
+    tags=["conditioning", "kontext", "flux"],
+    category="conditioning",
+    version="1.0.0",
+)
+class FluxKontextInvocation(BaseInvocation):
+    """Prepares a reference image for FLUX Kontext conditioning."""
+
+    image: ImageField = InputField(description="The Kontext reference image.")
+
+    def invoke(self, context: InvocationContext) -> FluxKontextOutput:
+        """Packages the provided image into a Kontext conditioning field."""
+        return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))
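A rough sketch of exercising the node directly (the node id and image name are hypothetical, and `context` stands in for the InvocationContext the runtime supplies):

# Hypothetical usage; in practice the node runs inside an InvokeAI workflow graph.
node = FluxKontextInvocation(id="kontext_node", image=ImageField(image_name="reference.png"))
output = node.invoke(context)  # context: InvocationContext provided by the runtime
assert output.kontext_cond.image.image_name == "reference.png"

The node is deliberately thin: it only packages the image reference, leaving VAE encoding to KontextExtension at denoise time.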
invokeai/backend/flux/extensions/kontext_extension.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+import einops
+import torch
+from einops import repeat
+
+from invokeai.app.invocations.fields import FluxKontextConditioningField
+from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.sampling_utils import pack
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+
+def generate_img_ids_with_offset(
+    h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
+) -> torch.Tensor:
+    """Generate a tensor of image position ids with an optional offset.
+
+    Args:
+        h (int): Height of the image in latent space.
+        w (int): Width of the image in latent space.
+        batch_size (int): Batch size.
+        device (torch.device): Device.
+        dtype (torch.dtype): dtype.
+        idx_offset (int): Offset to add to the first dimension of the image ids.
+
+    Returns:
+        torch.Tensor: Image position ids.
+    """
+
+    if device.type == "mps":
+        orig_dtype = dtype
+        dtype = torch.float16
+
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
+    img_ids[..., 0] = idx_offset  # Set the offset for the first dimension
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+
+    if device.type == "mps":
+        img_ids = img_ids.to(orig_dtype)
+
+    return img_ids
+
+
+class KontextExtension:
+    """Applies FLUX Kontext (reference image) conditioning."""
+
+    def __init__(
+        self,
+        kontext_field: FluxKontextConditioningField,
+        context: InvocationContext,
+        vae_field: VAEField,
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        """
+        Initializes the KontextExtension, pre-processing the reference image
+        into latents and positional IDs.
+        """
+        self._context = context
+        self._device = device
+        self._dtype = dtype
+        self._vae_field = vae_field
+        self.kontext_field = kontext_field
+
+        # Pre-process and cache the kontext latents and ids upon initialization.
+        self.kontext_latents, self.kontext_ids = self._prepare_kontext()
+
+    def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encodes the reference image and prepares its latents and IDs."""
+        image = self._context.images.get_pil(self.kontext_field.image.image_name)
+
+        # Reuse VAE encoding logic from FluxVaeEncodeInvocation.
+        vae_info = self._context.models.load(self._vae_field.vae)
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+        image_tensor = image_tensor.to(self._device)
+
+        kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+
+        # Pack the latents and generate IDs. The idx_offset distinguishes these
+        # tokens from the main image's tokens, which have an index of 0.
+        kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
+        kontext_ids = generate_img_ids_with_offset(
+            h=kontext_latents_unpacked.shape[2],
+            w=kontext_latents_unpacked.shape[3],
+            batch_size=kontext_latents_unpacked.shape[0],
+            device=self._device,
+            dtype=self._dtype,
+            idx_offset=1,  # Distinguishes reference tokens from main image tokens
+        )
+
+        return kontext_latents_packed, kontext_ids
+
+    def apply(
+        self,
+        img: torch.Tensor,
+        img_ids: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Concatenates the pre-processed kontext data to the main image sequence."""
+        # Ensure batch sizes match, repeating kontext data if necessary for batch operations.
+        if img.shape[0] != self.kontext_latents.shape[0]:
+            self.kontext_latents = self.kontext_latents.repeat(img.shape[0], 1, 1)
+            self.kontext_ids = self.kontext_ids.repeat(img.shape[0], 1, 1)
+
+        # Concatenate along the sequence dimension (dim=1).
+        combined_img = torch.cat([img, self.kontext_latents], dim=1)
+        combined_img_ids = torch.cat([img_ids, self.kontext_ids], dim=1)
+
+        return combined_img, combined_img_ids
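A quick sketch of what generate_img_ids_with_offset produces and why the offset matters (CPU/float32, toy 8x8 latent):

import torch

ids_main = generate_img_ids_with_offset(
    h=8, w=8, batch_size=1, device=torch.device("cpu"), dtype=torch.float32, idx_offset=0
)
ids_ref = generate_img_ids_with_offset(
    h=8, w=8, batch_size=1, device=torch.device("cpu"), dtype=torch.float32, idx_offset=1
)

print(ids_main.shape)  # torch.Size([1, 16, 3]) -- (8 // 2) * (8 // 2) tokens
print(ids_main[0, 0])  # tensor([0., 0., 0.]) -- main image tokens sit at index 0
print(ids_ref[0, 0])   # tensor([1., 0., 0.]) -- reference tokens are offset to index 1

Because the two token sets differ in the first id coordinate, the transformer's positional embedding can tell reference tokens apart from main-image tokens even though their row/column coordinates overlap.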
