Skip to content

Commit f676ac7

Browse files
committed
place the image tokens far away relative to the text tokens
1 parent 24d411f commit f676ac7

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

dalle_pytorch/transformer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,17 @@ def __init__(
205205

206206
text_pos_emb = RotaryEmbedding(dim = rot_dim)
207207
img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')
208+
208209
text_freqs = text_pos_emb(torch.arange(text_len))
210+
img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) # image is given a position far away from text
211+
text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0)
209212

210213
img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size))
211214
img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1)
212215
img_freqs = rearrange(img_freqs, 'h w d -> (h w) d')
213216

214217
text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1]
215218
text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1)
216-
217-
text_freqs = F.pad(text_freqs, (0, 0, 0, img_seq_len))
218219
img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)
219220

220221
pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'dalle-pytorch',
55
packages = find_packages(),
66
include_package_data = True,
7-
version = '1.0.2',
7+
version = '1.0.3',
88
license='MIT',
99
description = 'DALL-E - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)