11
11
12
12
13
13
def generate_img_ids_with_offset (
14
- latent_height : int , latent_width : int , batch_size : int , device : torch .device , dtype : torch .dtype , idx_offset : int = 0
14
+ latent_height : int ,
15
+ latent_width : int ,
16
+ batch_size : int ,
17
+ device : torch .device ,
18
+ dtype : torch .dtype ,
19
+ idx_offset : int = 0 ,
15
20
) -> torch .Tensor :
16
21
"""Generate tensor of image position ids with an optional offset.
17
22
@@ -34,24 +39,24 @@ def generate_img_ids_with_offset(
34
39
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
35
40
packed_height = latent_height // 2
36
41
packed_width = latent_width // 2
37
-
42
+
38
43
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
39
44
# The 3 channels represent: [batch_offset, y_position, x_position]
40
45
img_ids = torch .zeros (packed_height , packed_width , 3 , device = device , dtype = dtype )
41
-
46
+
42
47
# Set the batch offset for all positions
43
48
img_ids [..., 0 ] = idx_offset
44
-
49
+
45
50
# Create y-coordinate indices (vertical positions)
46
51
y_indices = torch .arange (packed_height , device = device , dtype = dtype )
47
52
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
48
53
img_ids [..., 1 ] = y_indices [:, None ]
49
-
50
- # Create x-coordinate indices (horizontal positions)
54
+
55
+ # Create x-coordinate indices (horizontal positions)
51
56
x_indices = torch .arange (packed_width , device = device , dtype = dtype )
52
57
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
53
58
img_ids [..., 2 ] = x_indices [None , :]
54
-
59
+
55
60
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
56
61
img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = batch_size )
57
62
0 commit comments