Skip to content

Commit 51e1c56

Browse files
hipsterusernamepsychedelicious
authored andcommitted
ruff
1 parent ca1df60 commit 51e1c56

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

invokeai/backend/flux/extensions/kontext_extension.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212

1313
def generate_img_ids_with_offset(
14-
latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
14+
latent_height: int,
15+
latent_width: int,
16+
batch_size: int,
17+
device: torch.device,
18+
dtype: torch.dtype,
19+
idx_offset: int = 0,
1520
) -> torch.Tensor:
1621
"""Generate tensor of image position ids with an optional offset.
1722
@@ -34,24 +39,24 @@ def generate_img_ids_with_offset(
3439
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
3540
packed_height = latent_height // 2
3641
packed_width = latent_width // 2
37-
42+
3843
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
3944
# The 3 channels represent: [batch_offset, y_position, x_position]
4045
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
41-
46+
4247
# Set the batch offset for all positions
4348
img_ids[..., 0] = idx_offset
44-
49+
4550
# Create y-coordinate indices (vertical positions)
4651
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
4752
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
4853
img_ids[..., 1] = y_indices[:, None]
49-
50-
# Create x-coordinate indices (horizontal positions)
54+
55+
# Create x-coordinate indices (horizontal positions)
5156
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
5257
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
5358
img_ids[..., 2] = x_indices[None, :]
54-
59+
5560
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
5661
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
5762

0 commit comments

Comments
 (0)