|
46 | 46 | help='path to your folder of images and text for learning the DALL-E')
|
47 | 47 |
|
48 | 48 | parser.add_argument(
|
49 |
| - '--wds', |
50 |
| - type = str, |
51 |
| - default='', |
| 49 | + '--wds', |
| 50 | + type = str, |
| 51 | + default='', |
52 | 52 | help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
|
53 | 53 | )
|
54 | 54 |
|
|
134 | 134 |
|
135 | 135 | model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')
|
136 | 136 |
|
| 137 | +model_group.add_argument('--shared_attn_ids', default = None, type = str, help = 'Comma separated list of shared attention layer ids. Default: sharing is disabled') |
| 138 | + |
| 139 | +model_group.add_argument('--shared_ff_ids', default = None, type = str, help = 'Comma separated list of shared feed forward layer ids. Default: sharing is disabled') |
| 140 | + |
137 | 141 | args = parser.parse_args()
|
138 | 142 |
|
139 | 143 | # helpers
|
@@ -191,6 +195,8 @@ def cp_path_to_dir(cp_path, tag):
|
191 | 195 | ROTARY_EMB = args.rotary_emb
|
192 | 196 |
|
193 | 197 | ATTN_TYPES = tuple(args.attn_types.split(','))
|
| 198 | +SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None |
| 199 | +SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None |
194 | 200 |
|
195 | 201 | DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
|
196 | 202 |
|
@@ -303,6 +309,8 @@ def cp_path_to_dir(cp_path, tag):
|
303 | 309 | stable=STABLE,
|
304 | 310 | shift_tokens=SHIFT_TOKENS,
|
305 | 311 | rotary_emb=ROTARY_EMB,
|
| 312 | + shared_attn_ids=SHARED_ATTN_IDS, |
| 313 | + shared_ff_ids=SHARED_FF_IDS, |
306 | 314 | )
|
307 | 315 | resume_epoch = 0
|
308 | 316 |
|
@@ -368,7 +376,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
|
368 | 376 | if myimg not in item:
|
369 | 377 | return False
|
370 | 378 | return True
|
371 |
| - |
| 379 | + |
372 | 380 | w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
|
373 | 381 | filtered_dataset = w_dataset.select(filter_dataset)
|
374 | 382 | ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
|
@@ -600,7 +608,7 @@ def save_model(path, epoch=0):
|
600 | 608 |
|
601 | 609 | if i % SAVE_EVERY_N_STEPS == 0:
|
602 | 610 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
|
603 |
| - |
| 611 | + |
604 | 612 | if i % 100 == 0:
|
605 | 613 | if distr_backend.is_root_worker():
|
606 | 614 | sample_text = text[:1]
|
@@ -633,7 +641,7 @@ def save_model(path, epoch=0):
|
633 | 641 | distr_scheduler.step(avg_loss)
|
634 | 642 |
|
635 | 643 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
|
636 |
| - |
| 644 | + |
637 | 645 | if distr_backend.is_root_worker():
|
638 | 646 | # save trained model to wandb as an artifact every epoch's end
|
639 | 647 |
|
|
0 commit comments