Skip to content

Commit 7d9e321

Browse files
committed
Improve tracing of window attn models with simpler reshape logic
1 parent a3c6685 commit 7d9e321

File tree

5 files changed

+15
-15
lines changed

5 files changed

+15
-15
lines changed

timm/models/davit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,9 @@ def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int
217217
Returns:
218218
x: (B, H, W, C)
219219
"""
220-
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
221-
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
222-
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
220+
C = windows.shape[-1]
221+
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
222+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
223223
return x
224224

225225

timm/models/gcvit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ def window_partition(x, window_size: Tuple[int, int]):
243243
@register_notrace_function # reason: int argument is a Proxy
244244
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
245245
H, W = img_size
246-
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
247-
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
248-
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
246+
C = windows.shape[-1]
247+
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
248+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
249249
return x
250250

251251

timm/models/swin_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ def window_reverse(windows, window_size: int, H: int, W: int):
126126
Returns:
127127
x: (B, H, W, C)
128128
"""
129-
B = int(windows.shape[0] / (H * W / window_size / window_size))
130-
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
131-
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
129+
C = windows.shape[-1]
130+
x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
131+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
132132
return x
133133

134134

timm/models/swin_transformer_v2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
120120
x: (B, H, W, C)
121121
"""
122122
H, W = img_size
123-
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
124-
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
125-
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
123+
C = windows.shape[-1]
124+
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
125+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
126126
return x
127127

128128

timm/models/swin_transformer_v2_cr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
139139
x: (B, H, W, C)
140140
"""
141141
H, W = img_size
142-
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
143-
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
144-
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
142+
C = windows.shape[-1]
143+
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
144+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
145145
return x
146146

147147

0 commit comments

Comments
 (0)