Commit e10096e

Pass sharing args from CLI
1 parent: 44775fc

2 files changed, 23 insertions(+), 11 deletions(-)


dalle_pytorch/dalle_pytorch.py

Lines changed: 9 additions & 5 deletions
@@ -326,7 +326,9 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = True,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -374,7 +376,9 @@ def __init__(
             stable = stable,
             sandwich_norm = sandwich_norm,
             shift_tokens = shift_tokens,
-            rotary_emb = rotary_emb
+            rotary_emb = rotary_emb,
+            shared_attn_ids = shared_attn_ids,
+            shared_ff_ids = shared_ff_ids,
         )

         self.stable = stable
@@ -417,7 +421,7 @@ def generate_texts(
             text_tokens = torch.tensor([[0]]).cuda()
         else:
             text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)
-
+
         for _ in range(text_tokens.shape[1], text_seq_len):
             device = text_tokens.device

@@ -443,9 +447,9 @@ def generate_texts(
             filtered_logits = top_k(logits, thres = filter_thres)
             probs = F.softmax(filtered_logits / temperature, dim = -1)
             sample = torch.multinomial(probs, 1)
-
+
             text_tokens = torch.cat((text_tokens, sample), dim=-1)
-
+
         padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
         texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
         return text_tokens, texts
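
For context, a minimal sketch of how the two new constructor arguments might be used. This is not part of the commit: the hyperparameter values and id tuples are illustrative, and the assumed semantics (layers given the same id share weights; None leaves sharing disabled) follow the help texts of the new CLI flags below.

# Hypothetical usage sketch; values are illustrative, not from this commit.
# Assumption: layers assigned the same id share weights, None disables sharing.
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 4,
    heads = 8,
    shared_attn_ids = (0, 0, 1, 1),  # layers 1-2 and 3-4 each share attention weights
    shared_ff_ids = (0, 0, 1, 1)     # same pairing for the feed forward blocks
)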

train_dalle.py

Lines changed: 14 additions & 6 deletions
@@ -46,9 +46,9 @@
     help='path to your folder of images and text for learning the DALL-E')

 parser.add_argument(
-    '--wds',
-    type = str,
-    default='',
+    '--wds',
+    type = str,
+    default='',
     help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
 )

@@ -134,6 +134,10 @@

 model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')

+model_group.add_argument('--shared_attn_ids', default = None, type = str, help = 'Comma separated list of shared attention layer ids. Default: sharing is disabled')
+
+model_group.add_argument('--shared_ff_ids', default = None, type = str, help = 'Comma separated list of shared feed forward layer ids. Default: sharing is disabled')
+
 args = parser.parse_args()

 # helpers
@@ -191,6 +195,8 @@ def cp_path_to_dir(cp_path, tag):
 ROTARY_EMB = args.rotary_emb

 ATTN_TYPES = tuple(args.attn_types.split(','))
+SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None
+SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None

 DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

@@ -303,6 +309,8 @@ def cp_path_to_dir(cp_path, tag):
         stable=STABLE,
         shift_tokens=SHIFT_TOKENS,
         rotary_emb=ROTARY_EMB,
+        shared_attn_ids=SHARED_ATTN_IDS,
+        shared_ff_ids=SHARED_FF_IDS,
     )
     resume_epoch = 0

@@ -368,7 +376,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
         if myimg not in item:
             return False
         return True
-
+
     w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
     filtered_dataset = w_dataset.select(filter_dataset)
     ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
@@ -600,7 +608,7 @@ def save_model(path, epoch=0):

     if i % SAVE_EVERY_N_STEPS == 0:
         save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-
+
     if i % 100 == 0:
         if distr_backend.is_root_worker():
             sample_text = text[:1]
@@ -633,7 +641,7 @@ def save_model(path, epoch=0):
     distr_scheduler.step(avg_loss)

     save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-
+
     if distr_backend.is_root_worker():
         # save trained model to wandb as an artifact every epoch's end

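To see the end-to-end wiring, here is a hypothetical invocation plus a small helper that mirrors the SHARED_*_IDS parsing above. The parse_ids helper and the flag values are illustrative, not part of the commit.

# Hypothetical invocation (flag values are illustrative):
#   python train_dalle.py --image_text_folder ./data \
#       --shared_attn_ids 0,0,1,1 --shared_ff_ids 0,0,1,1

# parse_ids mirrors the SHARED_*_IDS lines above: a comma separated string
# becomes a tuple of ids, and an omitted flag stays None.
def parse_ids(value):
    return tuple(value.split(',')) if value is not None else None

assert parse_ids('0,0,1,1') == ('0', '0', '1', '1')  # note: ids stay strings
assert parse_ids(None) is None                       # sharing disabled

Note that the parsed ids remain strings; assuming sharing only tests ids for equality, string ids should behave the same as integer ones.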
