
Commit 7c0a28b

Do not try to wrap WebDatasets with DeepSpeed (lucidrains#367)
Wrapping causes errors because PyTorch's `torch.utils.data.DistributedSampler` cannot be applied to `torch.utils.data.IterableDataset`s (which WebDatasets implement). Fixes lucidrains#359.
Parent commit: 9a2a1b6
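
To make the incompatibility described above concrete, here is a minimal, self-contained sketch (`ToyIterable` is a hypothetical stand-in for a WebDataset-style stream; `num_replicas` and `rank` are passed explicitly so the example runs without an initialized process group):

```python
from torch.utils.data import DistributedSampler, IterableDataset

class ToyIterable(IterableDataset):
    # Hypothetical stand-in for a WebDataset-style streaming dataset:
    # it yields items but has no __len__.
    def __iter__(self):
        return iter(range(10))

# DistributedSampler shards by *index*, so it calls len(dataset) during
# construction; an IterableDataset defines no __len__, and this raises.
try:
    DistributedSampler(ToyIterable(), num_replicas=2, rank=0)
except TypeError as err:
    print(err)  # object of type 'ToyIterable' has no len()
```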


train_dalle.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -492,7 +492,11 @@ def filter_dataset(item):  # For e.g. C@H which (rarely) has no caption available
     model=dalle,
     optimizer=opt,
     model_parameters=get_trainable_params(dalle),
-    training_data=ds if using_deepspeed else dl,
+    training_data=(
+        (None if ENABLE_WEBDATASET else ds)
+        if using_deepspeed
+        else dl
+    ),
     # Do not pass the LR scheduler to DeepSpeed so we can manually
     # advance it.
     lr_scheduler=scheduler if LR_DECAY and not using_deepspeed else None,
```
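
For context, a hedged sketch of the call site this patch changes (identifiers `dalle`, `opt`, `ds`, `dl`, `get_trainable_params`, and the flags come from the diff above; the destructured names are illustrative, not necessarily the file's own): `deepspeed.initialize` returns `(engine, optimizer, dataloader, lr_scheduler)`, and the dataloader slot is `None` when no `training_data` was supplied, so the pre-built WebDataset loader keeps being used.

```python
import deepspeed

# Sketch only: mirrors the patched call, not train_dalle.py verbatim.
(distr_dalle, distr_opt, distr_dl, distr_scheduler) = deepspeed.initialize(
    model=dalle,
    optimizer=opt,
    model_parameters=get_trainable_params(dalle),
    training_data=(None if ENABLE_WEBDATASET else ds) if using_deepspeed else dl,
)
# DeepSpeed returned no loader in the WebDataset case; fall back to our own.
loader = distr_dl if distr_dl is not None else dl
```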
