Skip to content

Commit 153e6fe

Browse files
authored
fix rotary and shift options in train_dalle (lucidrains#349)
1 parent f676ac7 commit 153e6fe

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

train_dalle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@
130130

131131
model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')
132132

133-
model_group.add_argument('--shift_tokens', default = False, type = bool, help = 'Use the shift tokens feature')
133+
model_group.add_argument('--shift_tokens', help = 'Use the shift tokens feature', action = 'store_true')
134134

135-
model_group.add_argument('--rotary_emb', default = False, type = bool, help = 'Use rotary embeddings')
135+
model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')
136136

137137
args = parser.parse_args()
138138

0 commit comments

Comments
 (0)