Skip to content

Commit e31bbfc

Browse files
authored
Fix rotary embeddings (lucidrains#383)
1 parent 04b6feb commit e31bbfc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

dalle_pytorch/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __init__(
223223
img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)
224224

225225
pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)
226-
pos_emb = rearrange(pos_emb[:-1], 'n d -> () () n d')
226+
pos_emb = rearrange(pos_emb, 'n d -> () n d')
227227

228228
self.register_buffer('pos_emb', pos_emb)
229229

0 commit comments

Comments
 (0)