Skip to content

Commit 2e03511

Browse files
authored
fix: fix the incompatibility of swintransformerv2 in ms2.0 (#719)
1 parent 308825d commit 2e03511

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

mindcv/models/swintransformerv2.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,34 +113,33 @@ def __init__(
113113
self.cpb_act1 = nn.ReLU()
114114
self.cpb_mlp2 = nn.Dense(512, num_heads, has_bias=False)
115115

116-
relative_coords_h = Tensor(np.arange(-(self.window_size[0] - 1), self.window_size[0]), mstype.float32)
117-
relative_coords_w = Tensor(np.arange(-(self.window_size[1] - 1), self.window_size[1]), mstype.float32)
118-
relative_coords_table = ops.stack(ops.meshgrid((relative_coords_h, relative_coords_w), indexing="ij"), axis=0)
119-
relative_coords_table = relative_coords_table.transpose(1, 2, 0)
120-
relative_coords_table = ops.expand_dims(relative_coords_table, axis=0)
116+
relative_coords_h = np.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=float)
117+
relative_coords_w = np.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=float)
118+
relative_coords_table = np.stack(np.meshgrid(relative_coords_h, relative_coords_w, indexing="ij"), axis=0)
119+
relative_coords_table = np.transpose(relative_coords_table, (1, 2, 0))
120+
relative_coords_table = np.expand_dims(relative_coords_table, axis=0)
121121
if pretrained_window_size[0] > 0:
122122
relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
123123
relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
124124
else:
125125
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
126126
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
127127
relative_coords_table *= 8 # normalize to -8, 8
128-
sign = ops.Sign()
129128
relative_coords_table = (
130-
sign(relative_coords_table) * Tensor(np.log2(np.abs(relative_coords_table.asnumpy()) + 1)) / np.log2(8)
129+
np.sign(relative_coords_table) * np.log2(np.abs(relative_coords_table) + 1) / np.log2(8)
131130
)
132131

133132
self.relative_coords_table = Parameter(
134133
Tensor(relative_coords_table, mstype.float32), requires_grad=False
135134
)
136135

137136
# get pair-wise relative position index for each token inside the window
138-
coords_h = Tensor(np.arange(window_size[0]), mstype.int32)
139-
coords_w = Tensor(np.arange(window_size[1]), mstype.int32)
140-
coords = ops.stack(ops.meshgrid((coords_h, coords_w), indexing="ij"), axis=0) # 2, Wh, Ww
141-
coords_flatten = ops.flatten(coords) # 2, Wh*Ww
137+
coords_h = np.arange(window_size[0])
138+
coords_w = np.arange(window_size[1])
139+
coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"), axis=0) # 2, Wh, Ww
140+
coords_flatten = coords.reshape(coords.shape[0], -1) # 2, Wh*Ww
142141
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
143-
relative_coords = relative_coords.transpose(1, 2, 0).asnumpy() # Wh*Ww, Wh*Ww, 2
142+
relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2
144143

145144
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
146145
relative_coords[:, :, 1] += window_size[1] - 1

0 commit comments

Comments
 (0)