Skip to content

Commit c9f462a

Browse files
committed
Revert excess changes
1 parent 14eb932 commit c9f462a

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

dalle_pytorch/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
132132
if exists(rotary_pos_emb):
133133
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
134134

135-
q = q * self.scale
135+
q *= self.scale
136136

137137
((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
138138

@@ -252,7 +252,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
252252
if exists(rotary_pos_emb):
253253
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
254254

255-
q = q * self.scale
255+
q *= self.scale
256256

257257
((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
258258

dalle_pytorch/dalle_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def forward(
319319
return loss
320320

321321
# main DALL-E class
322+
322323
class DALLE(nn.Module):
323324
def __init__(
324325
self,

0 commit comments

Comments
 (0)