@@ -119,9 +119,11 @@ def __init__(
         dim,
         dim_head = 32,
         dropout = 0.,
-        window_size = 7
+        window_size = 7,
+        num_registers = 1
     ):
         super().__init__()
+        assert num_registers > 0
         assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

         self.heads = dim // dim_head
@@ -142,7 +144,9 @@ def __init__(

         # relative positional bias

-        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
+        num_rel_pos_bias = (2 * window_size - 1) ** 2
+
+        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

         pos = torch.arange(window_size)
         grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
@@ -151,10 +155,11 @@ def __init__(
         rel_pos += window_size - 1
         rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

+        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
         self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

     def forward(self, x):
-        device, h = x.device, self.heads
+        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices

         x = self.norm(x)

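The net effect of the two hunks above is that the relative-position table gains one extra row, and the index matrix is padded once at init so every query/key pair involving a register token points at that row. A minimal standalone sketch, with assumed values (window size 7, one register, four heads) rather than the module itself, shows the resulting shapes:

import torch
import torch.nn.functional as F
from torch import nn

window_size, num_registers, heads = 7, 1, 4

num_rel_pos_bias = (2 * window_size - 1) ** 2             # 169 distinct relative offsets
rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, heads)  # +1 row shared by all register pairs

pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = grid.reshape(2, -1).t()                             # (w*w, 2) coordinates within the window
rel_pos = grid[:, None] - grid[None, :] + window_size - 1  # pairwise offsets shifted into [0, 2w - 2]
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

# prepend rows/columns for the registers, all pointing at the sentinel row
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)

bias = rel_pos_bias(rel_pos_indices)                       # (1 + 49, 1 + 49, heads)
print(bias.shape)                                          # torch.Size([50, 50, 4])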
@@ -176,13 +181,8 @@ def forward(self, x):

         # add positional bias

-        bias = self.rel_pos_bias(self.rel_pos_indices)
-        bias = rearrange(bias, 'i j h -> h i j')
-
-        num_registers = sim.shape[-1] - bias.shape[-1]
-        bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)
-
-        sim = sim + bias
+        bias = self.rel_pos_bias(bias_indices)
+        sim = sim + rearrange(bias, 'i j h -> h i j')

         # attention

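Because the pad is baked into the rel_pos_indices buffer at init, the forward pass above no longer has to infer the register count from sim and zero-pad the bias on every call: self.rel_pos_bias(bias_indices) already covers the register-extended (registers + window area) sequence, and register interactions pick up a learned bias (the extra embedding row) instead of a hard zero.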
@@ -215,6 +215,7 @@ def __init__(
     ):
         super().__init__()
         assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
+        assert num_register_tokens > 0

         # convolutional stem

@@ -256,10 +257,10 @@ def __init__(
                 shrinkage_rate = mbconv_shrinkage_rate
             )

-            block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+            block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
             block_ff = FeedForward(dim = layer_dim, dropout = dropout)

-            grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+            grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
             grid_ff = FeedForward(dim = layer_dim, dropout = dropout)

             register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
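The stage loop above only constructs the modules. As a rough illustration of how the per-stage register_tokens parameter could be combined with such an attention block, a hypothetical helper (assumed shapes and names, not the repository's actual forward code) would copy the registers into every window, prepend them so they sit in the leading positions that the padded indices account for, and drop them again afterwards:

import torch
from einops import repeat

def attend_with_registers(attn, windows, register_tokens):
    # windows: (num_windows, window_area, dim); register_tokens: (num_registers, dim)
    num_registers = register_tokens.shape[0]
    regs = repeat(register_tokens, 'n d -> w n d', w = windows.shape[0])  # same registers in every window
    tokens = torch.cat((regs, windows), dim = 1)  # registers occupy the leading positions, matching the F.pad above
    out = attn(tokens)                            # attn was built with num_registers = register_tokens.shape[0]
    return out[:, num_registers:]                 # keep only the window tokens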