|
11 | 11 |
|
12 | 12 |
|
def generate_img_ids_with_offset(
    latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
) -> torch.Tensor:
    """Generate a tensor of image position ids with an optional offset.

    Args:
        latent_height (int): Height of the image in latent space (before packing; the packed
            height is latent_height // 2 because packing groups 2x2 patches).
        latent_width (int): Width of the image in latent space (before packing; the packed
            width is latent_width // 2).
        batch_size (int): Number of images in the batch.
        device (torch.device): Device to create tensors on.
        dtype (torch.dtype): Data type for the tensors.
        idx_offset (int): Offset written into the first channel of every position id. Used to
            distinguish token groups (e.g. reference-image tokens from main-image tokens).

    Returns:
        torch.Tensor: Image position ids with shape
            [batch_size, (latent_height // 2) * (latent_width // 2), 3].
    """
    if device.type == "mps":
        # MPS does not support all dtypes for tensor creation ops; build the ids in
        # float16 and cast back to the requested dtype at the end.
        orig_dtype = dtype
        dtype = torch.float16

    # After packing, the spatial dimensions are halved due to the 2x2 patch structure.
    packed_height = latent_height // 2
    packed_width = latent_width // 2

    # Base grid of position ids with shape [packed_height, packed_width, 3].
    # The 3 channels represent: [token-group offset, y position, x position].
    img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
    img_ids[..., 0] = idx_offset
    img_ids[..., 1] = torch.arange(packed_height, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = torch.arange(packed_width, device=device, dtype=dtype)[None, :]

    # Flatten the spatial grid and tile across the batch:
    # [packed_height, packed_width, 3] -> [batch_size, packed_height * packed_width, 3].
    # Equivalent to einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size), but uses
    # torch-native ops so this helper has no einops dependency.
    img_ids = img_ids.reshape(1, packed_height * packed_width, 3).repeat(batch_size, 1, 1)

    # NOTE(review): tail reconstructed — the visible chunk was truncated just after the mps
    # check below; confirm against the full file that it only casts back and returns.
    if device.type == "mps":
        img_ids = img_ids.to(dtype=orig_dtype)

    return img_ids
@@ -80,13 +98,17 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
|
80 | 98 |
|
81 | 99 | kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
82 | 100 |
|
| 101 | + # Extract tensor dimensions with descriptive names |
| 102 | + # Latent tensor shape: [batch_size, channels, latent_height, latent_width] |
| 103 | + batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape |
| 104 | + |
83 | 105 | # Pack the latents and generate IDs. The idx_offset distinguishes these
|
84 | 106 | # tokens from the main image's tokens, which have an index of 0.
|
85 | 107 | kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
|
86 | 108 | kontext_ids = generate_img_ids_with_offset(
|
87 |
| - h=kontext_latents_unpacked.shape[2], |
88 |
| - w=kontext_latents_unpacked.shape[3], |
89 |
| - batch_size=kontext_latents_unpacked.shape[0], |
| 109 | + latent_height=latent_height, |
| 110 | + latent_width=latent_width, |
| 111 | + batch_size=batch_size, |
90 | 112 | device=self._device,
|
91 | 113 | dtype=self._dtype,
|
92 | 114 | idx_offset=1, # Distinguishes reference tokens from main image tokens
|
|
0 commit comments