2
2
#
3
3
# This source code is licensed under the MIT license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import urllib
5
9
import warnings
6
- from typing import Callable , Optional
10
+ from typing import Callable
7
11
8
12
import numpy as np
9
13
10
14
import torch
15
+
16
+ from tensordict import PersistentTensorDict
11
17
from tensordict .tensordict import make_tensordict
12
18
13
19
from torchrl .collectors .utils import split_trajectories
20
+ from torchrl .data .datasets .d4rl_infos import D4RL_DATASETS
14
21
from torchrl .data .replay_buffers import TensorDictReplayBuffer
15
22
from torchrl .data .replay_buffers .samplers import Sampler
16
23
from torchrl .data .replay_buffers .storages import LazyMemmapStorage
@@ -75,18 +82,25 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
75
82
differ. In particular, the ``"truncated"`` key (used to determine the
76
83
end of an episode) may be absent when ``from_env=False`` but present
77
84
otherwise, leading to a different slicing when ``traj_splits`` is enabled.
78
-
85
+ direct_download (bool): if ``True``, the data will be downloaded without
86
+ requiring D4RL. If ``None``, if ``d4rl`` is present in the env it will
87
+ be used to download the dataset, otherwise the download will fall back
88
+ on ``direct_download=True``.
89
+ This is not compatible with ``from_env=True``.
90
+ Defaults to ``None``.
79
91
use_truncated_as_done (bool, optional): if ``True``, ``done = terminated | truncated``.
80
92
Otherwise, only the ``terminated`` key is used. Defaults to ``True``.
93
+ terminate_on_end (bool, optional): Set ``done=True`` on the last timestep
94
+ in a trajectory. Default is ``False``, and will discard the
95
+ last timestep in each trajectory.
81
96
**env_kwargs (key-value pairs): additional kwargs for
82
- :func:`d4rl.qlearning_dataset`. Supports ``terminate_on_end``
83
- (``False`` by default) or other kwargs if defined by D4RL library.
97
+ :func:`d4rl.qlearning_dataset`.
84
98
85
99
86
100
Examples:
87
101
>>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay
88
102
>>> from torchrl.envs import ObservationNorm
89
- >>> data = D4RLExperienceReplay("maze2d-umaze-v1")
103
+ >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128 )
90
104
>>> # we can append transforms to the dataset
91
105
>>> data.append_transform(ObservationNorm(loc=-1, scale=1.0))
92
106
>>> data.sample(128)
@@ -109,34 +123,63 @@ def __init__(
109
123
self ,
110
124
name ,
111
125
batch_size : int ,
112
- sampler : Optional [ Sampler ] = None ,
113
- writer : Optional [ Writer ] = None ,
114
- collate_fn : Optional [ Callable ] = None ,
126
+ sampler : Sampler | None = None ,
127
+ writer : Writer | None = None ,
128
+ collate_fn : Callable | None = None ,
115
129
pin_memory : bool = False ,
116
- prefetch : Optional [ int ] = None ,
117
- transform : Optional [ " Transform"] = None , # noqa-F821
130
+ prefetch : int | None = None ,
131
+ transform : "torchrl.envs. Transform" | None = None , # noqa-F821
118
132
split_trajs : bool = False ,
119
- from_env : bool = True ,
133
+ from_env : bool = None ,
120
134
use_truncated_as_done : bool = True ,
135
+ direct_download : bool = None ,
136
+ terminate_on_end : bool = None ,
121
137
** env_kwargs ,
122
138
):
123
-
124
- type (self )._import_d4rl ()
125
-
126
- if not self ._has_d4rl :
127
- raise ImportError ("Could not import d4rl" ) from self .D4RL_ERR
139
+ if from_env is None :
140
+ warnings .warn (
141
+ "from_env will soon default to ``False``, ie the data will be "
142
+ "downloaded without relying on d4rl by default. "
143
+ "For now, ``True`` will still be the default. "
144
+ "To disable this warning, explicitly pass the ``from_env`` argument "
145
+ "during construction of the dataset." ,
146
+ category = DeprecationWarning ,
147
+ )
148
+ from_env = True
128
149
self .from_env = from_env
129
150
self .use_truncated_as_done = use_truncated_as_done
130
- if from_env :
131
- dataset = self ._get_dataset_from_env (name , env_kwargs )
151
+
152
+ if not from_env and direct_download is None :
153
+ self ._import_d4rl ()
154
+ direct_download = not self ._has_d4rl
155
+
156
+ if not direct_download :
157
+ if terminate_on_end is None :
158
+ # we use the default of d4rl
159
+ terminate_on_end = False
160
+ self ._import_d4rl ()
161
+
162
+ if not self ._has_d4rl :
163
+ raise ImportError ("Could not import d4rl" ) from self .D4RL_ERR
164
+
165
+ if from_env :
166
+ dataset = self ._get_dataset_from_env (name , env_kwargs )
167
+ else :
168
+ if self .use_truncated_as_done :
169
+ warnings .warn (
170
+ "Using use_truncated_as_done=True + terminate_on_end=True "
171
+ "with from_env=False may not have the intended effect "
172
+ "as the timeouts (truncation) "
173
+ "can be absent from the static dataset."
174
+ )
175
+ env_kwargs .update ({"terminate_on_end" : terminate_on_end })
176
+ dataset = self ._get_dataset_direct (name , env_kwargs )
132
177
else :
133
- if self .use_truncated_as_done :
134
- warnings .warn (
135
- "Using terminate_on_end=True with from_env=False "
136
- "may not have the intended effect as the timeouts (truncation) "
137
- "can be absent from the static dataset."
178
+ if terminate_on_end is False :
179
+ raise ValueError (
180
+ "Using terminate_on_end=False is not compatible with direct_download=True."
138
181
)
139
- dataset = self ._get_dataset_direct (name , env_kwargs )
182
+ dataset = self ._get_dataset_direct_download (name , env_kwargs )
140
183
# Fill unknown next states with 0
141
184
dataset ["next" , "observation" ][dataset ["next" , "done" ].squeeze ()] = 0
142
185
@@ -157,6 +200,23 @@ def __init__(
157
200
)
158
201
self .extend (dataset )
159
202
203
def _get_dataset_direct_download(self, name, env_kwargs):
    """Load a D4RL dataset from its download URL, without going through ``d4rl``.

    The dataset name is resolved to a URL via ``D4RL_DATASETS``, the file is
    obtained with ``_download_dataset_from_url``, opened with
    ``PersistentTensorDict.from_h5``, converted with ``to_tensordict`` and
    finally handed to ``self._process_data_from_env`` (called without an env).

    Args:
        name: key looked up in ``D4RL_DATASETS``.
        env_kwargs: extra keyword arguments; must be empty on this path.

    Returns:
        The dataset returned by ``self._process_data_from_env``.

    Raises:
        RuntimeError: if ``env_kwargs`` is not empty.
        KeyError: if ``name`` has no entry in ``D4RL_DATASETS``.
    """
    # Guard clause: this code path has nothing to forward the kwargs to.
    if env_kwargs:
        raise RuntimeError(
            f"Cannot pass env_kwargs when `direct_download=True`. Got env_kwargs keys: {env_kwargs.keys()} "
        )

    dataset_url = D4RL_DATASETS.get(name)
    if dataset_url is None:
        raise KeyError(f"Env {name} not found.")

    local_h5_file = _download_dataset_from_url(dataset_url)
    loaded = PersistentTensorDict.from_h5(local_h5_file).to_tensordict()
    # The processing step is run with the tensordict unlocked.
    with loaded.unlock_():
        processed = self._process_data_from_env(loaded)
    return processed
219
+
160
220
def _get_dataset_direct (self , name , env_kwargs ):
161
221
from torchrl .envs .libs .gym import GymWrapper
162
222
@@ -247,6 +307,10 @@ def _get_dataset_from_env(self, name, env_kwargs):
247
307
}
248
308
)
249
309
dataset = dataset .unflatten_keys ("/" )
310
+ dataset = self ._process_data_from_env (dataset , env )
311
+ return dataset
312
+
313
+ def _process_data_from_env (self , dataset , env = None ):
250
314
if "metadata" in dataset .keys ():
251
315
metadata = dataset .get ("metadata" )
252
316
dataset = dataset .exclude ("metadata" )
@@ -277,10 +341,11 @@ def _get_dataset_from_env(self, name, env_kwargs):
277
341
pass
278
342
279
343
# let's make sure that the dtypes match what's expected
280
- for key , spec in env .observation_spec .items (True , True ):
281
- dataset [key ] = dataset [key ].to (spec .dtype )
282
- dataset ["action" ] = dataset ["action" ].to (env .action_spec .dtype )
283
- dataset ["reward" ] = dataset ["reward" ].to (env .reward_spec .dtype )
344
+ if env is not None :
345
+ for key , spec in env .observation_spec .items (True , True ):
346
+ dataset [key ] = dataset [key ].to (spec .dtype )
347
+ dataset ["action" ] = dataset ["action" ].to (env .action_spec .dtype )
348
+ dataset ["reward" ] = dataset ["reward" ].to (env .reward_spec .dtype )
284
349
285
350
# format done
286
351
dataset ["done" ] = dataset ["done" ].bool ().unsqueeze (- 1 )
@@ -300,7 +365,10 @@ def _get_dataset_from_env(self, name, env_kwargs):
300
365
dataset .clone ()
301
366
) # make sure that all tensors have a different data_ptr
302
367
self ._shift_reward_done (dataset )
303
- self .specs = env .specs .clone ()
368
+ if env is not None :
369
+ self .specs = env .specs .clone ()
370
+ else :
371
+ self .specs = None
304
372
return dataset
305
373
306
374
def _shift_reward_done (self , dataset ):
@@ -313,3 +381,39 @@ def _shift_reward_done(self, dataset):
313
381
dataset [key ] = dataset [key ].clone ()
314
382
dataset [key ][1 :] = dataset [key ][:- 1 ].clone ()
315
383
dataset [key ][0 ] = 0
384
+
385
+
386
def _download_dataset_from_url(dataset_url):
    """Return the local path of the dataset at ``dataset_url``, downloading it if needed.

    The destination is computed by ``_filepath_from_url``. If a file already
    exists there it is reused as-is; otherwise the URL is retrieved.

    Args:
        dataset_url (str): URL of the dataset file.

    Returns:
        str: path of the local copy of the dataset.

    Raises:
        IOError: if the file is still missing after the download attempt.
        Any exception raised by ``urllib.request.urlretrieve`` is propagated.
    """
    # ``import urllib`` alone does not guarantee that the ``request``
    # submodule is loaded; import it explicitly where it is used.
    import urllib.request

    dataset_filepath = _filepath_from_url(dataset_url)
    if not os.path.exists(dataset_filepath):
        print("Downloading dataset:", dataset_url, "to", dataset_filepath)
        # Download next to the final location and rename once complete, so an
        # interrupted transfer never leaves a truncated file that a later call
        # would mistake for a valid cached dataset.
        partial_filepath = dataset_filepath + ".part"
        try:
            urllib.request.urlretrieve(dataset_url, partial_filepath)
            os.replace(partial_filepath, dataset_filepath)
        finally:
            # Only true when the download or the rename failed.
            if os.path.exists(partial_filepath):
                os.remove(partial_filepath)
    if not os.path.exists(dataset_filepath):
        raise IOError(f"Failed to download dataset from {dataset_url}")
    return dataset_filepath
394
+
395
+
396
def _filepath_from_url(dataset_url):
    """Map a dataset URL to its location inside ``DATASET_PATH``.

    Args:
        dataset_url (str): URL of the dataset file.

    Returns:
        str: ``DATASET_PATH`` joined with the last path component of the URL.
    """
    # basename() is the second element of os.path.split().
    file_name = os.path.basename(dataset_url)
    return os.path.join(DATASET_PATH, file_name)
400
+
401
+
402
def _set_dataset_path(path):
    """Set the directory used to store downloaded dataset files.

    Args:
        path (str or os.PathLike): target directory. A leading ``~`` is
            expanded to the user's home directory. The directory is created,
            including missing parents, if it does not exist yet.

    Raises:
        OSError: if the directory cannot be created. ``DATASET_PATH`` is left
            untouched in that case.
    """
    global DATASET_PATH
    # A value read verbatim from an environment variable is not expanded by
    # any shell, so "~" has to be resolved here to avoid creating a literal
    # "~" directory under the current working directory.
    resolved_path = os.path.expanduser(path)
    # Create the directory before publishing the global so a failure does not
    # leave DATASET_PATH pointing at a location that does not exist.
    os.makedirs(resolved_path, exist_ok=True)
    DATASET_PATH = resolved_path
406
+
407
+
408
# Initialise the dataset directory at import time: the ``D4RL_DATASET_DIR``
# environment variable takes precedence, otherwise a directory under the
# user's home cache is used.
_set_dataset_path(
    os.getenv(
        "D4RL_DATASET_DIR",
        os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets"),
    )
)
413
+
414
if __name__ == "__main__":
    # Manual smoke test: build the replay buffer for one dataset, print it,
    # then print the first batch it yields (if any).
    replay_buffer = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128)
    print(replay_buffer)
    first_batch = next(iter(replay_buffer), None)
    if first_batch is not None:
        print(first_batch)
0 commit comments