@@ -113,34 +113,33 @@ def __init__(
113
113
self .cpb_act1 = nn .ReLU ()
114
114
self .cpb_mlp2 = nn .Dense (512 , num_heads , has_bias = False )
115
115
116
- relative_coords_h = Tensor ( np .arange (- (self .window_size [0 ] - 1 ), self .window_size [0 ]), mstype . float32 )
117
- relative_coords_w = Tensor ( np .arange (- (self .window_size [1 ] - 1 ), self .window_size [1 ]), mstype . float32 )
118
- relative_coords_table = ops .stack (ops .meshgrid (( relative_coords_h , relative_coords_w ) , indexing = "ij" ), axis = 0 )
119
- relative_coords_table = relative_coords_table .transpose (1 , 2 , 0 )
120
- relative_coords_table = ops .expand_dims (relative_coords_table , axis = 0 )
116
+ relative_coords_h = np .arange (- (self .window_size [0 ] - 1 ), self .window_size [0 ], dtype = float )
117
+ relative_coords_w = np .arange (- (self .window_size [1 ] - 1 ), self .window_size [1 ], dtype = float )
118
+ relative_coords_table = np .stack (np .meshgrid (relative_coords_h , relative_coords_w , indexing = "ij" ), axis = 0 )
119
+ relative_coords_table = np .transpose (relative_coords_table , ( 1 , 2 , 0 ) )
120
+ relative_coords_table = np .expand_dims (relative_coords_table , axis = 0 )
121
121
if pretrained_window_size [0 ] > 0 :
122
122
relative_coords_table [:, :, :, 0 ] /= pretrained_window_size [0 ] - 1
123
123
relative_coords_table [:, :, :, 1 ] /= pretrained_window_size [1 ] - 1
124
124
else :
125
125
relative_coords_table [:, :, :, 0 ] /= self .window_size [0 ] - 1
126
126
relative_coords_table [:, :, :, 1 ] /= self .window_size [1 ] - 1
127
127
relative_coords_table *= 8 # normalize to -8, 8
128
- sign = ops .Sign ()
129
128
relative_coords_table = (
130
- sign (relative_coords_table ) * Tensor ( np .log2 (np .abs (relative_coords_table . asnumpy ()) + 1 ) ) / np .log2 (8 )
129
+ np . sign (relative_coords_table ) * np .log2 (np .abs (relative_coords_table ) + 1 ) / np .log2 (8 )
131
130
)
132
131
133
132
self .relative_coords_table = Parameter (
134
133
Tensor (relative_coords_table , mstype .float32 ), requires_grad = False
135
134
)
136
135
137
136
# get pair-wise relative position index for each token inside the window
138
- coords_h = Tensor ( np .arange (window_size [0 ]), mstype . int32 )
139
- coords_w = Tensor ( np .arange (window_size [1 ]), mstype . int32 )
140
- coords = ops .stack (ops .meshgrid (( coords_h , coords_w ) , indexing = "ij" ), axis = 0 ) # 2, Wh, Ww
141
- coords_flatten = ops . flatten (coords ) # 2, Wh*Ww
137
+ coords_h = np .arange (window_size [0 ])
138
+ coords_w = np .arange (window_size [1 ])
139
+ coords = np .stack (np .meshgrid (coords_h , coords_w , indexing = "ij" ), axis = 0 ) # 2, Wh, Ww
140
+ coords_flatten = coords . reshape (coords . shape [ 0 ], - 1 ) # 2, Wh*Ww
142
141
relative_coords = coords_flatten [:, :, None ] - coords_flatten [:, None , :] # 2, Wh*Ww, Wh*Ww
143
- relative_coords = relative_coords .transpose (1 , 2 , 0 ). asnumpy ( ) # Wh*Ww, Wh*Ww, 2
142
+ relative_coords = np .transpose (relative_coords , ( 1 , 2 , 0 )) # Wh*Ww, Wh*Ww, 2
144
143
145
144
relative_coords [:, :, 0 ] += window_size [0 ] - 1 # shift to start from 0
146
145
relative_coords [:, :, 1 ] += window_size [1 ] - 1
0 commit comments