Skip to content

Commit 24d411f

Browse files
authored
add rotary emb option to train dalle (lucidrains#348)
1 parent 208d12e commit 24d411f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

train_dalle.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@
132132

133133
model_group.add_argument('--shift_tokens', default = False, type = bool, help = 'Use the shift tokens feature')
134134

135+
model_group.add_argument('--rotary_emb', default = False, type = bool, help = 'Use rotary embeddings')
136+
135137
args = parser.parse_args()
136138

137139
# helpers
@@ -186,6 +188,7 @@ def cp_path_to_dir(cp_path, tag):
186188
ATTN_DROPOUT = args.attn_dropout
187189
STABLE = args.stable_softmax
188190
SHIFT_TOKENS = args.shift_tokens
191+
ROTARY_EMB = args.rotary_emb
189192

190193
ATTN_TYPES = tuple(args.attn_types.split(','))
191194

@@ -299,6 +302,7 @@ def cp_path_to_dir(cp_path, tag):
299302
attn_dropout=ATTN_DROPOUT,
300303
stable=STABLE,
301304
shift_tokens=SHIFT_TOKENS,
305+
rotary_emb=ROTARY_EMB,
302306
)
303307
resume_epoch = 0
304308

0 commit comments

Comments
 (0)