
Commit 3018810

MateuszGuzek and vmoens authored
[Feature] D4rl direct download (#1430)
Co-authored-by: Mateusz Guzek <matguzek@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 02cd86e commit 3018810

File tree

4 files changed: 370 additions & 30 deletions

.github/unittest/linux_libs/scripts_d4rl/run_test.sh

Lines changed: 19 additions & 0 deletions
@@ -40,3 +40,22 @@ python -c "import gym, d4rl"
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips
 coverage combine
 coverage xml -i
+
+## check what happens if we update gym
+#pip install gym -U
+#python -c """
+#from torchrl.data.datasets import D4RLExperienceReplay
+#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=True)
+#for batch in data:
+#  print(batch)
+#  break
+#
+#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=False)
+#for batch in data:
+#  print(batch)
+#  break
+#
+#import d4rl
+#import gym
+#gym.make('halfcheetah-medium-v2')
+#"""
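The commented-out block added above sketches a follow-up sanity check that is not yet enabled in CI. Uncommented and trimmed to the dependency-free path, it would look roughly like the snippet below (a sketch of the intended check, not part of the commit):

# Sketch of the commented-out check above: iterate one batch from the
# directly-downloaded dataset, which should not require d4rl or a gym env.
from torchrl.data.datasets import D4RLExperienceReplay

data = D4RLExperienceReplay(
    "halfcheetah-medium-v2", batch_size=10, from_env=False, direct_download=True
)
for batch in data:
    print(batch)
    break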

test/test_libs.py

Lines changed: 32 additions & 1 deletion
@@ -1775,7 +1775,7 @@ class TestD4RL:
     def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs):
 
         with pytest.warns(
-            UserWarning, match="Using terminate_on_end=True with from_env=False"
+            UserWarning, match="Using use_truncated_as_done=True"
         ) if use_truncated_as_done else nullcontext():
             data_true = D4RLExperienceReplay(
                 task,
@@ -1823,6 +1823,37 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs):
         ]
         assert "truncated" not in leaf_names
 
+    @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
+    def test_direct_download(self, task):
+        data_direct = D4RLExperienceReplay(
+            task,
+            split_trajs=False,
+            from_env=False,
+            batch_size=2,
+            use_truncated_as_done=True,
+            direct_download=True,
+        )
+        data_d4rl = D4RLExperienceReplay(
+            task,
+            split_trajs=False,
+            from_env=True,
+            batch_size=2,
+            use_truncated_as_done=True,
+            direct_download=False,
+            terminate_on_end=True,  # keep the last time step
+        )
+        keys = set(data_direct._storage._storage.keys(True, True))
+        keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))
+        assert len(keys)
+        assert_allclose_td(
+            data_direct._storage._storage.select(*keys).apply(
+                lambda t: t.as_tensor().float()
+            ),
+            data_d4rl._storage._storage.select(*keys).apply(
+                lambda t: t.as_tensor().float()
+            ),
+        )
+
     @pytest.mark.parametrize(
         "task",
         [
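To run only the new test locally, a pytest invocation along these lines should work (a sketch; the `-k` expression and path assume a repo-root working directory):

# Hypothetical local run of the new test through pytest's Python API,
# mirroring the CI command in run_test.sh above.
import pytest

pytest.main(["test/test_libs.py", "-v", "-k", "TestD4RL and direct_download"])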

torchrl/data/datasets/d4rl.py

Lines changed: 133 additions & 29 deletions
@@ -2,15 +2,22 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import os
+import urllib
 import warnings
-from typing import Callable, Optional
+from typing import Callable
 
 import numpy as np
 
 import torch
+
+from tensordict import PersistentTensorDict
 from tensordict.tensordict import make_tensordict
 
 from torchrl.collectors.utils import split_trajectories
+from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS
 from torchrl.data.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage
@@ -75,18 +82,25 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
             differ. In particular, the ``"truncated"`` key (used to determine the
             end of an episode) may be absent when ``from_env=False`` but present
             otherwise, leading to a different slicing when ``traj_splits`` is enabled.
-
+        direct_download (bool): if ``True``, the data will be downloaded without
+            requiring D4RL. If ``None``, if ``d4rl`` is present in the env it will
+            be used to download the dataset, otherwise the download will fall back
+            on ``direct_download=True``.
+            This is not compatible with ``from_env=True``.
+            Defaults to ``None``.
         use_truncated_as_done (bool, optional): if ``True``, ``done = terminated | truncated``.
             Otherwise, only the ``terminated`` key is used. Defaults to ``True``.
+        terminate_on_end (bool, optional): Set ``done=True`` on the last timestep
+            in a trajectory. Default is ``False``, and will discard the
+            last timestep in each trajectory.
         **env_kwargs (key-value pairs): additional kwargs for
-            :func:`d4rl.qlearning_dataset`. Supports ``terminate_on_end``
-            (``False`` by default) or other kwargs if defined by D4RL library.
+            :func:`d4rl.qlearning_dataset`.
 
 
     Examples:
         >>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay
         >>> from torchrl.envs import ObservationNorm
-        >>> data = D4RLExperienceReplay("maze2d-umaze-v1")
+        >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128)
         >>> # we can append transforms to the dataset
         >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0))
         >>> data.sample(128)
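Based on the updated docstring, the new arguments would be exercised roughly as follows (a sketch; the task name is illustrative and mirrors the one used in the new test):

from torchrl.data.datasets.d4rl import D4RLExperienceReplay

# Fetch the raw HDF5 file directly, bypassing the d4rl/gym dependency.
data = D4RLExperienceReplay(
    "walker2d-medium-replay-v2",
    batch_size=32,
    from_env=False,
    direct_download=True,
)
batch = data.sample(32)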
@@ -109,34 +123,63 @@ def __init__(
         self,
         name,
         batch_size: int,
-        sampler: Optional[Sampler] = None,
-        writer: Optional[Writer] = None,
-        collate_fn: Optional[Callable] = None,
+        sampler: Sampler | None = None,
+        writer: Writer | None = None,
+        collate_fn: Callable | None = None,
         pin_memory: bool = False,
-        prefetch: Optional[int] = None,
-        transform: Optional["Transform"] = None,  # noqa-F821
+        prefetch: int | None = None,
+        transform: "torchrl.envs.Transform" | None = None,  # noqa-F821
         split_trajs: bool = False,
-        from_env: bool = True,
+        from_env: bool = None,
         use_truncated_as_done: bool = True,
+        direct_download: bool = None,
+        terminate_on_end: bool = None,
         **env_kwargs,
     ):
-
-        type(self)._import_d4rl()
-
-        if not self._has_d4rl:
-            raise ImportError("Could not import d4rl") from self.D4RL_ERR
+        if from_env is None:
+            warnings.warn(
+                "from_env will soon default to ``False``, ie the data will be "
+                "downloaded without relying on d4rl by default. "
+                "For now, ``True`` will still be the default. "
+                "To disable this warning, explicitly pass the ``from_env`` argument "
+                "during construction of the dataset.",
+                category=DeprecationWarning,
+            )
+            from_env = True
         self.from_env = from_env
         self.use_truncated_as_done = use_truncated_as_done
-        if from_env:
-            dataset = self._get_dataset_from_env(name, env_kwargs)
+
+        if not from_env and direct_download is None:
+            self._import_d4rl()
+            direct_download = not self._has_d4rl
+
+        if not direct_download:
+            if terminate_on_end is None:
+                # we use the default of d4rl
+                terminate_on_end = False
+            self._import_d4rl()
+
+            if not self._has_d4rl:
+                raise ImportError("Could not import d4rl") from self.D4RL_ERR
+
+            if from_env:
+                dataset = self._get_dataset_from_env(name, env_kwargs)
+            else:
+                if self.use_truncated_as_done:
+                    warnings.warn(
+                        "Using use_truncated_as_done=True + terminate_on_end=True "
+                        "with from_env=False may not have the intended effect "
+                        "as the timeouts (truncation) "
+                        "can be absent from the static dataset."
+                    )
+                env_kwargs.update({"terminate_on_end": terminate_on_end})
+                dataset = self._get_dataset_direct(name, env_kwargs)
         else:
-            if self.use_truncated_as_done:
-                warnings.warn(
-                    "Using terminate_on_end=True with from_env=False "
-                    "may not have the intended effect as the timeouts (truncation) "
-                    "can be absent from the static dataset."
+            if terminate_on_end is False:
+                raise ValueError(
+                    "Using terminate_on_end=False is not compatible with direct_download=True."
                 )
-            dataset = self._get_dataset_direct(name, env_kwargs)
+            dataset = self._get_dataset_direct_download(name, env_kwargs)
         # Fill unknown next states with 0
         dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0
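For readability, the flag-resolution logic introduced above can be summarized as a small decision helper (an illustrative paraphrase of the diff, not code shipped by the commit; `has_d4rl` stands in for the `_import_d4rl`/`_has_d4rl` check):

def resolve_loader(from_env, direct_download, has_d4rl):
    """Illustrative paraphrase: which loader the constructor above ends up calling."""
    if from_env is None:
        from_env = True  # deprecation path: keep the old default for now
    if not from_env and direct_download is None:
        direct_download = not has_d4rl  # fall back to direct download if d4rl is missing
    if not direct_download:
        return "_get_dataset_from_env" if from_env else "_get_dataset_direct"
    return "_get_dataset_direct_download"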

@@ -157,6 +200,23 @@ def __init__(
         )
         self.extend(dataset)
 
+    def _get_dataset_direct_download(self, name, env_kwargs):
+        """Directly download and use a D4RL dataset."""
+        if env_kwargs:
+            raise RuntimeError(
+                f"Cannot pass env_kwargs when `direct_download=True`. Got env_kwargs keys: {env_kwargs.keys()}"
+            )
+        url = D4RL_DATASETS.get(name, None)
+        if url is None:
+            raise KeyError(f"Env {name} not found.")
+        h5path = _download_dataset_from_url(url)
+        # h5path_parent = Path(h5path).parent
+        dataset = PersistentTensorDict.from_h5(h5path)
+        dataset = dataset.to_tensordict()
+        with dataset.unlock_():
+            dataset = self._process_data_from_env(dataset)
+        return dataset
+
     def _get_dataset_direct(self, name, env_kwargs):
         from torchrl.envs.libs.gym import GymWrapper
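The new method above reduces to three steps: resolve the URL from `D4RL_DATASETS`, download the HDF5 file, and load it as a TensorDict before reusing the shared post-processing. A minimal standalone sketch of the loading step, assuming `path` already points to a downloaded `.hdf5` file:

from tensordict import PersistentTensorDict

def load_d4rl_h5(path):
    # Open the HDF5 file as a file-backed TensorDict, then materialize it
    # in memory; post-processing (done/reward shifting, dtype fixes) is omitted.
    dataset = PersistentTensorDict.from_h5(path)
    return dataset.to_tensordict()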

@@ -247,6 +307,10 @@ def _get_dataset_from_env(self, name, env_kwargs):
             }
         )
         dataset = dataset.unflatten_keys("/")
+        dataset = self._process_data_from_env(dataset, env)
+        return dataset
+
+    def _process_data_from_env(self, dataset, env=None):
         if "metadata" in dataset.keys():
             metadata = dataset.get("metadata")
             dataset = dataset.exclude("metadata")
@@ -277,10 +341,11 @@ def _get_dataset_from_env(self, name, env_kwargs):
             pass
 
         # let's make sure that the dtypes match what's expected
-        for key, spec in env.observation_spec.items(True, True):
-            dataset[key] = dataset[key].to(spec.dtype)
-        dataset["action"] = dataset["action"].to(env.action_spec.dtype)
-        dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype)
+        if env is not None:
+            for key, spec in env.observation_spec.items(True, True):
+                dataset[key] = dataset[key].to(spec.dtype)
+            dataset["action"] = dataset["action"].to(env.action_spec.dtype)
+            dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype)
 
         # format done
         dataset["done"] = dataset["done"].bool().unsqueeze(-1)
@@ -300,7 +365,10 @@ def _get_dataset_from_env(self, name, env_kwargs):
             dataset.clone()
         )  # make sure that all tensors have a different data_ptr
         self._shift_reward_done(dataset)
-        self.specs = env.specs.clone()
+        if env is not None:
+            self.specs = env.specs.clone()
+        else:
+            self.specs = None
         return dataset
 
     def _shift_reward_done(self, dataset):
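The refactor above moves the tail of `_get_dataset_from_env` into `_process_data_from_env(dataset, env=None)` so the direct-download path can reuse it; when `env` is `None`, the spec-based dtype casts and `self.specs` assignment are simply skipped. A toy illustration of this optional-env pattern (hypothetical names, not the shipped code):

def postprocess(dataset, env=None):
    # Steps that need no env always run ...
    dataset["done"] = dataset["done"].bool().unsqueeze(-1)
    # ... while env-dependent steps are guarded, as in the diff above.
    if env is not None:
        dataset["action"] = dataset["action"].to(env.action_spec.dtype)
    return dataset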
@@ -313,3 +381,39 @@ def _shift_reward_done(self, dataset):
             dataset[key] = dataset[key].clone()
             dataset[key][1:] = dataset[key][:-1].clone()
             dataset[key][0] = 0
+
+
+def _download_dataset_from_url(dataset_url):
+    dataset_filepath = _filepath_from_url(dataset_url)
+    if not os.path.exists(dataset_filepath):
+        print("Downloading dataset:", dataset_url, "to", dataset_filepath)
+        urllib.request.urlretrieve(dataset_url, dataset_filepath)
+    if not os.path.exists(dataset_filepath):
+        raise IOError("Failed to download dataset from %s" % dataset_url)
+    return dataset_filepath
+
+
+def _filepath_from_url(dataset_url):
+    _, dataset_name = os.path.split(dataset_url)
+    dataset_filepath = os.path.join(DATASET_PATH, dataset_name)
+    return dataset_filepath
+
+
+def _set_dataset_path(path):
+    global DATASET_PATH
+    DATASET_PATH = path
+    os.makedirs(path, exist_ok=True)
+
+
+_set_dataset_path(
+    os.environ.get(
+        "D4RL_DATASET_DIR", os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets")
+    )
+)
+
+
+if __name__ == "__main__":
+    data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128)
+    print(data)
+    for sample in data:
+        print(sample)
+        break
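The module-level helpers added above cache downloads under `~/.cache/torchrl/data/d4rl/datasets` by default and read the `D4RL_DATASET_DIR` environment variable at import time. A hedged usage sketch (the cache path is illustrative):

import os

# Point the cache at a custom directory before the module is imported,
# since _set_dataset_path() runs at import time.
os.environ["D4RL_DATASET_DIR"] = "/tmp/d4rl_cache"

from torchrl.data.datasets.d4rl import D4RLExperienceReplay

data = D4RLExperienceReplay(
    "walker2d-medium-replay-v2", batch_size=32, from_env=False, direct_download=True
)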

0 commit comments