@@ -75,7 +75,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
        if exists(cache):
            cache[cache_key] = k, v

-       dots = q @ k.swapaxes(-1, -2)
+       dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
@@ -93,7 +93,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        attn = softmax(dots, dim = -1)

-       out = attn @ v
+       out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
@@ -248,7 +248,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
            nn.Dropout(dropout)
        )

-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        softmax = torch.softmax if not self.stable else stable_softmax
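For reference, a minimal standalone sketch (not part of this commit) checking that the einsum expressions introduced above produce the same tensors as the matmul forms they replace. The tensor names q, k, v and the (batch, heads, seq, head_dim) layout are assumptions read off the subscripts 'b h i d' / 'b h j d'.

```python
import torch

# assumed toy shapes: batch, heads, sequence length, per-head dim
b, h, n, d = 2, 8, 16, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# attention scores: einsum form vs. matmul against a transposed key
dots_einsum = torch.einsum('b h i d, b h j d -> b h i j', q, k)
dots_matmul = q @ k.swapaxes(-1, -2)
assert torch.allclose(dots_einsum, dots_matmul, atol=1e-6)

attn = dots_einsum.softmax(dim=-1)

# weighted sum of values: einsum form vs. plain matmul
out_einsum = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out_matmul = attn @ v
assert torch.allclose(out_einsum, out_matmul, atol=1e-6)
```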