Skip to content

Commit 8375ebd

Browse files
committed
Fix rotary embeddings
1 parent bc1bb6c commit 8375ebd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

dalle_pytorch/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)
240240

241241
pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)
242-
pos_emb = rearrange(pos_emb[:-1], 'n d -> () () n d')
242+
pos_emb = rearrange(pos_emb, 'n d -> () n d')
243243

244244
self.register_buffer('pos_emb', pos_emb)
245245

0 commit comments

Comments
 (0)