
Commit adfce34

Remove excess changes
1 parent 94fda36

1 file changed: +3, -3 lines

dalle_pytorch/attention.py

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         if exists(cache):
             cache[cache_key] = k, v

-        dots = q @ k.swapaxes(-1, -2)
+        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
         mask_value = max_neg_value(dots)

         if exists(mask):
@@ -93,7 +93,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         attn = softmax(dots, dim=-1)

-        out = attn @ v
+        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
         return out
@@ -248,7 +248,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
            nn.Dropout(dropout)
        )

-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

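For reference, a minimal standalone sketch (not part of the commit) showing that the matmul form being removed and the einsum form being restored compute the same attention tensors, assuming q, k, v of shape (batch, heads, seq, dim):

import torch

# Random tensors in the shape the attention layer uses: (batch, heads, seq, dim).
b, h, n, d = 2, 8, 16, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# Attention scores: q @ k^T over the last two dims vs. the einsum from the diff.
dots_matmul = q @ k.swapaxes(-1, -2)
dots_einsum = torch.einsum('b h i d, b h j d -> b h i j', q, k)
assert torch.allclose(dots_matmul, dots_einsum, atol=1e-6)

# Output: attention-weighted sum of values, again in both forms.
attn = dots_einsum.softmax(dim=-1)
out_matmul = attn @ v
out_einsum = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
assert torch.allclose(out_matmul, out_einsum, atol=1e-6)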