 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

+import functools
+
 import importlib
 import json
 import logging
 import shutil
 import tempfile
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Callable, List

 import numpy as np

 import torch
 from tensordict import PersistentTensorDict, TensorDict
+from torch import multiprocessing as mp

 from torchrl._utils import KeyDependentDefaultDict
 from torchrl.data.datasets.utils import _get_root_dir
@@ -96,6 +98,8 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer):
             transform that will be appended to the transform list. Supports
             `int` types (square resizing) or a list/tuple of `int` (rectangular
             resizing). Defaults to ``None`` (no resizing).
+        num_workers (int, optional): the number of workers to download the files.
+            Defaults to ``0`` (no multiprocessing).

     Attributes:
         available_datasets: a list of accepted entries to be downloaded. These
@@ -173,6 +177,7 @@ def __init__(
         split_trajs: bool = False,
         totensor: bool = True,
         image_size: int | List[int] | None = None,
+        num_workers: int = 0,
         **env_kwargs,
     ):
         if not _has_h5py or not _has_hf_hub:
@@ -191,6 +196,7 @@ def __init__(
         self.root = root
         self.split_trajs = split_trajs
         self.download = download
+        self.num_workers = num_workers
         if self.download == "force" or (self.download and not self._is_downloaded()):
             if self.download == "force":
                 try:
@@ -199,7 +205,9 @@ def __init__(
                     shutil.rmtree(self.data_path)
                 except FileNotFoundError:
                     pass
-            storage = self._download_and_preproc(dataset_id, data_path=self.data_path)
+            storage = self._download_and_preproc(
+                dataset_id, data_path=self.data_path, num_workers=self.num_workers
+            )
         elif self.split_trajs and not os.path.exists(self.data_path):
             storage = self._make_split()
         else:
@@ -251,14 +259,23 @@ def _parse_datasets(cls):
         return sibs

     @classmethod
-    def _download_and_preproc(cls, dataset_id, data_path):
+    def _hf_hub_download(cls, subfolder, filename, *, tmpdir):
         from huggingface_hub import hf_hub_download

-        files = []
+        return hf_hub_download(
+            "conglu/vd4rl",
+            subfolder=subfolder,
+            filename=filename,
+            repo_type="dataset",
+            cache_dir=str(tmpdir),
+        )
+
+    @classmethod
+    def _download_and_preproc(cls, dataset_id, data_path, num_workers):
+
         tds = []
         with tempfile.TemporaryDirectory() as tmpdir:
             sibs = cls._parse_datasets()
-            # files = []
             total_steps = 0

             paths_to_proc = []
@@ -270,19 +287,19 @@ def _download_and_preproc(cls, dataset_id, data_path):
                 for file in sibs[path]:
                     paths_to_proc.append(str(path))
                     files_to_proc.append(str(file.parts[-1]))
-
-            with ThreadPoolExecutor(32) as executor:
-                files = executor.map(
-                    lambda path_file: hf_hub_download(
-                        "conglu/vd4rl",
-                        subfolder=path_file[0],
-                        filename=path_file[1],
-                        repo_type="dataset",
-                        cache_dir=str(tmpdir),
-                    ),
-                    zip(paths_to_proc, files_to_proc),
-                )
-                files = list(files)
+            func = functools.partial(cls._hf_hub_download, tmpdir=tmpdir)
+            if num_workers > 0:
+                with mp.Pool(num_workers) as pool:
+                    files = pool.starmap(
+                        func,
+                        zip(paths_to_proc, files_to_proc),
+                    )
+                    files = list(files)
+            else:
+                files = [
+                    func(subfolder, filename)
+                    for (subfolder, filename) in zip(paths_to_proc, files_to_proc)
+                ]
             logging.info("Downloaded, processing files")
             if _has_tqdm:
                 import tqdm
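
Usage sketch (not part of the diff): a minimal, hypothetical example of the
new `num_workers` argument introduced above. The dataset id is illustrative;
valid ids are listed in `VD4RLExperienceReplay.available_datasets`.

    from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay

    # num_workers > 0 downloads the HuggingFace Hub files with a
    # multiprocessing pool; 0 keeps the sequential fallback path.
    data = VD4RLExperienceReplay(
        dataset_id="main/walker_walk/expert/64px",  # illustrative id
        batch_size=32,
        download=True,
        num_workers=4,
    )
    batch = data.sample()  # sample a batch of 32 transitions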