Skip to content

Commit 9a2a1b6

Browse files
authored
(webdataset) fix KeyError for C@H (lucidrains#363)
1 parent 499b4c9 commit 9a2a1b6

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

train_dalle.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -361,15 +361,17 @@ def tokenize(s):
361361
image_mapping = {
362362
myimg: imagepreproc
363363
}
364+
365+
def filter_dataset(item):
    """Return True only when *item* contains both the caption and image keys.

    Intended as a sample predicate: e.g. C@H (rarely) has no caption
    available, and such samples must be rejected.  ``mycap`` and ``myimg``
    are read from the enclosing scope.
    """
    # Short-circuits in the same order as before: caption key first, then image key.
    return mycap in item and myimg in item
364371

365-
ds = (
366-
wds.WebDataset(DATASET)
367-
# .shuffle(is_shuffle) # Commented out for WebDataset as the behaviour cannot be predicted yet
368-
.map_dict(**image_text_mapping)
369-
.map_dict(**image_mapping)
370-
.to_tuple(mycap, myimg)
371-
.batched(BATCH_SIZE, partial=False) # It is good to avoid partial batches when using Distributed training
372-
)
372+
# Open the dataset with wds.warn_and_continue as the error handler
# (presumably reports a failing sample and keeps going -- confirm in the webdataset docs).
w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
# Apply filter_dataset as the sample predicate before any per-key mapping runs.
filtered_dataset = w_dataset.select(filter_dataset)
ds = (
    filtered_dataset
    .map_dict(**image_text_mapping)
    .map_dict(**image_mapping)
    .to_tuple(mycap, myimg)
    # NOTE(review): partial=True presumably lets a short final batch through -- confirm.
    .batched(BATCH_SIZE, partial=True)
)
373375
else:
374376
ds = TextImageDataset(
375377
args.image_text_folder,

0 commit comments

Comments
 (0)