Commit 79374d8

Author: Vincent Moens
[BugFix] Fix VD4RL (#1834)
Parent: 156a668

1 file changed: torchrl/data/datasets/vd4rl.py (+35, -18)
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import functools
+
 import importlib
 import json
 import logging
@@ -12,14 +14,14 @@
 import shutil
 import tempfile
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Callable, List
 
 import numpy as np
 
 import torch
 from tensordict import PersistentTensorDict, TensorDict
+from torch import multiprocessing as mp
 
 from torchrl._utils import KeyDependentDefaultDict
 from torchrl.data.datasets.utils import _get_root_dir
@@ -96,6 +98,8 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer):
             transform that will be appended to the transform list. Supports
             `int` types (square resizing) or a list/tuple of `int` (rectangular
             resizing). Defaults to ``None`` (no resizing).
+        num_workers (int, optional): the number of workers to download the files.
+            Defaults to ``0`` (no multiprocessing).
 
     Attributes:
         available_datasets: a list of accepted entries to be downloaded. These
@@ -173,6 +177,7 @@ def __init__(
         split_trajs: bool = False,
         totensor: bool = True,
         image_size: int | List[int] | None = None,
+        num_workers: int = 0,
         **env_kwargs,
     ):
         if not _has_h5py or not _has_hf_hub:
@@ -191,6 +196,7 @@ def __init__(
         self.root = root
         self.split_trajs = split_trajs
         self.download = download
+        self.num_workers = num_workers
         if self.download == "force" or (self.download and not self._is_downloaded()):
             if self.download == "force":
                 try:
@@ -199,7 +205,9 @@ def __init__(
                     shutil.rmtree(self.data_path)
                 except FileNotFoundError:
                     pass
-            storage = self._download_and_preproc(dataset_id, data_path=self.data_path)
+            storage = self._download_and_preproc(
+                dataset_id, data_path=self.data_path, num_workers=self.num_workers
+            )
         elif self.split_trajs and not os.path.exists(self.data_path):
             storage = self._make_split()
         else:
@@ -251,14 +259,23 @@ def _parse_datasets(cls):
         return sibs
 
     @classmethod
-    def _download_and_preproc(cls, dataset_id, data_path):
+    def _hf_hub_download(cls, subfolder, filename, *, tmpdir):
         from huggingface_hub import hf_hub_download
 
-        files = []
+        return hf_hub_download(
+            "conglu/vd4rl",
+            subfolder=subfolder,
+            filename=filename,
+            repo_type="dataset",
+            cache_dir=str(tmpdir),
+        )
+
+    @classmethod
+    def _download_and_preproc(cls, dataset_id, data_path, num_workers):
+
         tds = []
         with tempfile.TemporaryDirectory() as tmpdir:
             sibs = cls._parse_datasets()
-            # files = []
             total_steps = 0
 
             paths_to_proc = []
@@ -270,19 +287,19 @@ def _download_and_preproc(cls, dataset_id, data_path):
                 for file in sibs[path]:
                     paths_to_proc.append(str(path))
                     files_to_proc.append(str(file.parts[-1]))
-
-            with ThreadPoolExecutor(32) as executor:
-                files = executor.map(
-                    lambda path_file: hf_hub_download(
-                        "conglu/vd4rl",
-                        subfolder=path_file[0],
-                        filename=path_file[1],
-                        repo_type="dataset",
-                        cache_dir=str(tmpdir),
-                    ),
-                    zip(paths_to_proc, files_to_proc),
-                )
-                files = list(files)
+            func = functools.partial(cls._hf_hub_download, tmpdir=tmpdir)
+            if num_workers > 0:
+                with mp.Pool(num_workers) as pool:
+                    files = pool.starmap(
+                        func,
+                        zip(paths_to_proc, files_to_proc),
+                    )
+                    files = list(files)
+            else:
+                files = [
+                    func(subfolder, filename)
+                    for (subfolder, filename) in zip(paths_to_proc, files_to_proc)
+                ]
             logging.info("Downloaded, processing files")
             if _has_tqdm:
                 import tqdm
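
The substance of the fix is in the last two hunks: the fixed 32-thread ThreadPoolExecutor wrapping a lambda is replaced by a named classmethod, _hf_hub_download, bound with functools.partial and dispatched through a torch.multiprocessing pool when num_workers > 0, with a plain sequential loop as the num_workers=0 default. A process pool pickles the callable it ships to its workers, and a lambda cannot be pickled, which is why the per-file download is promoted to an importable method. Below is a minimal, self-contained sketch of the same pattern; _download_one and fetch_all are illustrative stand-ins, not TorchRL APIs:

import functools

from torch import multiprocessing as mp


def _download_one(subfolder, filename, *, tmpdir):
    # Stand-in for VD4RLExperienceReplay._hf_hub_download: a picklable,
    # module-level callable (unlike the removed lambda) that fetches one file.
    return f"{tmpdir}/{subfolder}/{filename}"


def fetch_all(paths, files, tmpdir, num_workers=0):
    # Bind the keyword-only tmpdir once so only (subfolder, filename)
    # pairs have to be shipped to the workers.
    func = functools.partial(_download_one, tmpdir=tmpdir)
    if num_workers > 0:
        with mp.Pool(num_workers) as pool:
            return pool.starmap(func, zip(paths, files))
    # num_workers == 0: no multiprocessing, mirroring the new default path.
    return [func(subfolder, filename) for subfolder, filename in zip(paths, files)]


if __name__ == "__main__":
    print(fetch_all(["main/walker_walk/random/64px"], ["ep0.hdf5"], tmpdir="/tmp"))

On the user side, the new argument is presumably passed at construction time, along the lines of VD4RLExperienceReplay("main/walker_walk/random/64px", batch_size=32, download=True, num_workers=8); the dataset id and batch_size here are only illustrative, see available_datasets for the accepted entries.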
