Skip to content

Commit ca1df60

Browse files
hipsterusernamepsychedelicious
authored andcommitted
Explain the Magic
1 parent 7549c12 commit ca1df60

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,6 @@ def _run_diffusion(
384384
dtype=inference_dtype,
385385
)
386386

387-
# Instantiate our new extension if the conditioning is provided
388387
kontext_extension = None
389388
if self.kontext_conditioning is not None:
390389
# We need a VAE to encode the reference image. We can reuse the
@@ -400,7 +399,6 @@ def _run_diffusion(
400399
dtype=inference_dtype,
401400
)
402401

403-
# THE CRITICAL INTEGRATION POINT
404402
final_img, final_img_ids = x, img_ids
405403
original_seq_len = x.shape[1] # Store the original sequence length
406404
if kontext_extension is not None:
@@ -426,7 +424,6 @@ def _run_diffusion(
426424
img_cond=img_cond,
427425
)
428426

429-
# Extract only the main image tokens if kontext was applied
430427
if kontext_extension is not None:
431428
x = x[:, :original_seq_len, :] # Keep only the first original_seq_len tokens
432429

invokeai/backend/flux/extensions/kontext_extension.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,48 @@
1111

1212

1313
def generate_img_ids_with_offset(
14-
h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
14+
latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
1515
) -> torch.Tensor:
1616
"""Generate tensor of image position ids with an optional offset.
1717
1818
Args:
19-
h (int): Height of image in latent space.
20-
w (int): Width of image in latent space.
21-
batch_size (int): Batch size.
22-
device (torch.device): Device.
23-
dtype (torch.dtype): dtype.
19+
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
20+
latent_width (int): Width of image in latent space (after packing, this becomes w//2).
21+
batch_size (int): Number of images in the batch.
22+
device (torch.device): Device to create tensors on.
23+
dtype (torch.dtype): Data type for the tensors.
2424
idx_offset (int): Offset to add to the first dimension of the image ids.
2525
2626
Returns:
27-
torch.Tensor: Image position ids.
27+
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
2828
"""
2929

3030
if device.type == "mps":
3131
orig_dtype = dtype
3232
dtype = torch.float16
3333

34-
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
35-
img_ids[..., 0] = idx_offset # Set the offset for the first dimension
36-
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
37-
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
34+
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
35+
packed_height = latent_height // 2
36+
packed_width = latent_width // 2
37+
38+
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
39+
# The 3 channels represent: [batch_offset, y_position, x_position]
40+
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
41+
42+
# Set the batch offset for all positions
43+
img_ids[..., 0] = idx_offset
44+
45+
# Create y-coordinate indices (vertical positions)
46+
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
47+
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
48+
img_ids[..., 1] = y_indices[:, None]
49+
50+
# Create x-coordinate indices (horizontal positions)
51+
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
52+
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
53+
img_ids[..., 2] = x_indices[None, :]
54+
55+
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
3856
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
3957

4058
if device.type == "mps":
@@ -80,13 +98,17 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
8098

8199
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
82100

101+
# Extract tensor dimensions with descriptive names
102+
# Latent tensor shape: [batch_size, channels, latent_height, latent_width]
103+
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
104+
83105
# Pack the latents and generate IDs. The idx_offset distinguishes these
84106
# tokens from the main image's tokens, which have an index of 0.
85107
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
86108
kontext_ids = generate_img_ids_with_offset(
87-
h=kontext_latents_unpacked.shape[2],
88-
w=kontext_latents_unpacked.shape[3],
89-
batch_size=kontext_latents_unpacked.shape[0],
109+
latent_height=latent_height,
110+
latent_width=latent_width,
111+
batch_size=batch_size,
90112
device=self._device,
91113
dtype=self._dtype,
92114
idx_offset=1, # Distinguishes reference tokens from main image tokens

0 commit comments

Comments
 (0)