
Commit 3b355dd

Author: Vincent Moens

[Feature] step_and_maybe_reset in env (#1611)

1 parent 8d2bc8b · commit 3b355dd

File tree: 41 files changed, +2595 / -1110 lines


.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ dependencies:
   - pytest-mock
   - pytest-instafail
   - pytest-rerunfailures
+  - pytest-timeout
   - expecttest
   - pyyaml
   - scipy

.github/unittest/linux/scripts/run_all.sh

Lines changed: 4 additions & 2 deletions

@@ -184,10 +184,12 @@ pytest test/smoke_test.py -v --durations 200
 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
 if [ "${CU_VERSION:-}" != cpu ] ; then
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
-    --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py
+    --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+    --timeout=120
 else
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
-    --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py
+    --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
+    --timeout=120
 fi

 coverage combine
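For context: the pytest-timeout plugin added to the environment above is what provides the new --timeout=120 flag, which fails any single test that runs longer than 120 seconds. Individual tests can still opt for a different budget via a marker. A minimal sketch; the test names are hypothetical, not from the TorchRL suite:

    # Sketch: per-test override of the global cap set by --timeout=120.
    import time

    import pytest


    def test_fast():
        # subject to the global 120 s cap from the command line
        time.sleep(0.1)


    @pytest.mark.timeout(300)
    def test_slow_distributed_rollout():
        # the marker overrides the global cap for this test only
        time.sleep(1.0)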

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 6 additions & 2 deletions

@@ -37,13 +37,17 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
   optim.updates_per_episode=3 \
   optim.warmup_steps=10 \
   optim.device=cuda:0 \
-  logger.backend=
+  logger.backend= \
+  env.backend=gymnasium \
+  env.name=HalfCheetah-v4
 python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \
   optim.pretrain_gradient_steps=55 \
   optim.updates_per_episode=3 \
   optim.warmup_steps=10 \
   optim.device=cuda:0 \
-  logger.backend=
+  logger.backend= \
+  env.backend=gymnasium \
+  env.name=HalfCheetah-v4

 # ==================================================================================== #
 # ================================ Gymnasium ========================================= #
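The env.backend=gymnasium and env.name=HalfCheetah-v4 arguments above are plain Hydra command-line overrides: each key=value pair replaces the matching entry of the script's YAML config before main() runs. A sketch of the same overrides applied programmatically, assuming a Hydra version recent enough to accept the version_base argument used in dt.py:

    # Programmatic equivalent of the CLI overrides above (a sketch).
    from hydra import compose, initialize

    with initialize(config_path=".", version_base="1.1"):
        cfg = compose(
            config_name="dt_config",
            overrides=["env.backend=gymnasium", "env.name=HalfCheetah-v4"],
        )
    print(cfg.env.backend)  # gymnasium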

benchmarks/ecosystem/gym_env_throughput.py

Lines changed: 38 additions & 44 deletions

@@ -16,7 +16,8 @@
 """
 import time

-import myosuite  # noqa: F401
+# import myosuite  # noqa: F401
+import torch
 import tqdm
 from torchrl._utils import timeit
 from torchrl.collectors import (
@@ -29,6 +30,10 @@
 from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend

 if __name__ == "__main__":
+    avail_devices = ("cpu",)
+    if torch.cuda.device_count():
+        avail_devices = avail_devices + ("cuda:0",)
+
     for envname in [
         "CartPole-v1",
         "HalfCheetah-v4",
@@ -69,24 +74,25 @@ def make(envname=envname, gym_backend=gym_backend):
             log.flush()

             # regular parallel env
-            for device in (
-                "cuda:0",
-                "cpu",
-            ):
+            for device in avail_devices:

                 def make(envname=envname, gym_backend=gym_backend, device=device):
                     with set_gym_backend(gym_backend):
                         return GymEnv(envname, device=device)

-                env_make = EnvCreator(make)
-                penv = ParallelEnv(num_workers, env_make)
-                # warmup
-                penv.rollout(2)
-                pbar = tqdm.tqdm(total=num_workers * 10_000)
-                t0 = time.time()
-                for _ in range(100):
-                    data = penv.rollout(100, break_when_any_done=False)
-                    pbar.update(100 * num_workers)
+                # env_make = EnvCreator(make)
+                penv = ParallelEnv(num_workers, EnvCreator(make))
+                with torch.inference_mode():
+                    # warmup
+                    penv.rollout(2)
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    t0 = time.time()
+                    data = None
+                    for _ in range(100):
+                        data = penv.rollout(
+                            100, break_when_any_done=False, out=data
+                        )
+                        pbar.update(100 * num_workers)
                 log.write(
                     f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
                 )
@@ -95,7 +101,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
             timeit.print()
             del penv

-            for device in ("cuda:0", "cpu"):
+            for device in avail_devices:

                 def make(envname=envname, gym_backend=gym_backend, device=device):
                     with set_gym_backend(gym_backend):
@@ -109,29 +115,26 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
                     RandomPolicy(penv.action_spec),
                     frames_per_batch=1024,
                     total_frames=num_workers * 10_000,
+                    device=device,
+                    storing_device=device,
                 )
                 pbar = tqdm.tqdm(total=num_workers * 10_000)
                 total_frames = 0
-                for i, data in enumerate(collector):
-                    if i == num_collectors:
-                        t0 = time.time()
-                    if i >= num_collectors:
-                        total_frames += data.numel()
-                        pbar.update(data.numel())
-                        pbar.set_description(
-                            f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
-                        )
+                t0 = time.time()
+                for data in collector:
+                    total_frames += data.numel()
+                    pbar.update(data.numel())
+                    pbar.set_description(
+                        f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
+                    )
                 log.write(
                     f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                 )
                 log.flush()
                 collector.shutdown()
                 del collector

-            for device in (
-                "cuda:0",
-                "cpu",
-            ):
+            for device in avail_devices:
                 # gym parallel env
                 def make_env(
                     envname=envname,
@@ -158,10 +161,7 @@ def make_env(
                 penv.close()
                 del penv

-            for device in (
-                "cuda:0",
-                "cpu",
-            ):
+            for device in avail_devices:
                 # async collector
                 # + torchrl parallel env
                 def make_env(
@@ -179,6 +179,7 @@ def make_env(
                     frames_per_batch=1024,
                     total_frames=num_workers * 10_000,
                     device=device,
+                    storing_device=device,
                 )
                 pbar = tqdm.tqdm(total=num_workers * 10_000)
                 total_frames = 0
@@ -198,10 +199,7 @@ def make_env(
                 collector.shutdown()
                 del collector

-            for device in (
-                "cuda:0",
-                "cpu",
-            ):
+            for device in avail_devices:
                 # async collector
                 # + gym async env
                 def make_env(
@@ -226,6 +224,7 @@ def make_env(
                     total_frames=num_workers * 10_000,
                     num_sub_threads=num_workers // num_collectors,
                     device=device,
+                    storing_device=device,
                 )
                 pbar = tqdm.tqdm(total=num_workers * 10_000)
                 total_frames = 0
@@ -245,10 +244,7 @@ def make_env(
                 collector.shutdown()
                 del collector

-            for device in (
-                "cuda:0",
-                "cpu",
-            ):
+            for device in avail_devices:
                 # sync collector
                 # + torchrl parallel env
                 def make_env(
@@ -266,6 +262,7 @@ def make_env(
                     frames_per_batch=1024,
                     total_frames=num_workers * 10_000,
                     device=device,
+                    storing_device=device,
                 )
                 pbar = tqdm.tqdm(total=num_workers * 10_000)
                 total_frames = 0
@@ -285,10 +282,7 @@ def make_env(
                 collector.shutdown()
                 del collector

-            for device in (
-                "cuda:0",
-                "cpu",
-            ):
+            for device in avail_devices:
                 # sync collector
                 # + gym async env
                 def make_env(
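Two benchmark changes above deserve a standalone note: device lists are now derived from what the machine actually exposes (the old hard-coded ("cuda:0", "cpu") tuple would fail on CPU-only runners), and rollouts now run under torch.inference_mode() with the output buffer recycled through out=data so a fresh TensorDict is not allocated on every iteration. The availability guard in isolation:

    # The device-availability guard from the diff above, in isolation.
    import torch

    avail_devices = ("cpu",)
    if torch.cuda.device_count():
        # only benchmark CUDA when a GPU is actually visible
        avail_devices = avail_devices + ("cuda:0",)

    for device in avail_devices:
        print(f"benchmarking on {device}")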

docs/source/reference/envs.rst

Lines changed: 106 additions & 1 deletion

@@ -58,6 +58,23 @@ With these, the following methods are implemented:
 - :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input
   containing an input action as well as other inputs (for model-based or stateless
   environments, for instance).
+- :meth:`env.step_and_maybe_reset`: executes a step and (partially) resets the
+  environments if needed. It returns the updated input with a ``"next"``
+  key containing the data of the next step, as well as a tensordict containing
+  the input data for the next step (i.e., the output of a reset or the result of
+  :func:`~torchrl.envs.utils.step_mdp`).
+  This is done by reading the ``done_keys`` and
+  assigning a ``"_reset"`` signal to each done state. This method makes it
+  possible to write non-stopping rollout functions with little effort:
+
+    >>> data_ = env.reset()
+    >>> result = []
+    >>> for i in range(N):
+    ...     data, data_ = env.step_and_maybe_reset(data_)
+    ...     result.append(data)
+    ...
+    >>> result = torch.stack(result)
+
 - :meth:`env.set_seed`: a seeding method that will return the next seed
   to be used in a multi-env setting. This next seed is deterministically computed
   from the preceding one, such that one can seed multiple environments with a different
@@ -169,7 +186,95 @@ one can simply call:
     >>> print(a)
     9.81

-It is also possible to reset some but not all of the environments:
+TorchRL uses a private ``"_reset"`` key to indicate to the environment which
+component (sub-environments or agents) should be reset.
+This makes it possible to reset some, but not all, of the components.
+
+The ``"_reset"`` key has two distinct functionalities:
+
+1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
+   not be present in the input tensordict. TorchRL's convention is that the
+   absence of the ``"_reset"`` key at a given ``"done"`` level indicates
+   a total reset of that level (unless a ``"_reset"`` key was found at a level
+   above, see details below).
+   If it is present, it is expected that only those components
+   where the ``"_reset"`` entry is ``True`` (along both key and shape dimensions)
+   will be reset.
+
+   The way an environment deals with the ``"_reset"`` keys in its
+   :meth:`~.EnvBase._reset` method is specific to each class.
+   Designing an environment that behaves according to ``"_reset"`` inputs is the
+   developer's responsibility, as TorchRL has no control over the inner logic
+   of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be
+   kept in mind when designing that method.
+
+2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the
+   ``"_reset"`` entries, and the output of the previous :meth:`~.EnvBase.step`
+   will be written wherever the ``"_reset"`` was ``False``. In practice, this
+   means that if a ``"_reset"`` modifies data that isn't exposed by it, this
+   modification will be lost. After this masking operation, the ``"_reset"``
+   entries will be erased from the :meth:`~.EnvBase.reset` outputs.
+
+It must be pointed out that ``"_reset"`` is a private key, and it should only be
+used when coding specific environment features that are internal-facing.
+In other words, it should NOT be used outside of the library, and developers
+keep the right to modify the logic of partial resets through ``"_reset"``
+settings without prior notice, as long as they don't affect TorchRL
+internal tests.
+
+Finally, the following assumptions are made and should be kept in mind when
+designing reset functionalities:
+
+- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and,
+  possibly, ``"truncated"``). This means that the following structure is not
+  allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as
+  the ``"_reset"`` lives at a different nesting level than the ``"done"``.
+- A reset at one level does not preclude the presence of a ``"_reset"`` at lower
+  levels, but it annihilates its effects. The reason is simply that
+  whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()``
+  or custom call to the nested ``"done"`` entries cannot be known in advance,
+  and it is explicitly assumed that the ``"_reset"`` at the root was placed
+  there to supersede the nested values (for an example, have a look at the
+  :class:`~.PettingZooWrapper` implementation, where each group has one or more
+  ``"done"`` entries associated with it, which are aggregated at the root level
+  with an ``any`` or ``all`` logic depending on the task).
+- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry
+  that will reset some but not all the done sub-environments, the input data
+  should contain the data of the sub-environments that are __not__ being reset.
+  The reason for this constraint lies in the fact that the output of
+  ``env._reset(data)`` can only be predicted for the entries that are reset.
+  For the others, TorchRL cannot know in advance if they will be meaningful or
+  not. For instance, one could simply pad the values of the non-reset
+  components, in which case the non-reset data would be meaningless and should
+  be discarded.
+
+Below, we give some examples of the expected effect that ``"_reset"`` keys will
+have on an environment returning zeros after reset:
+
+    >>> # single reset at the root
+    >>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, [])
+    >>> env.reset(data)
+    >>> print(data.get("val"))  # only the second value is 0
+    tensor([1, 0])
+    >>> # nested resets
+    >>> data = TensorDict({
+    ...     ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
+    ...     ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
+    ... }, [])
+    >>> env.reset(data)
+    >>> print(data.get(("agent0", "val")))  # only the second value is 0
+    tensor([1, 0])
+    >>> print(data.get(("agent1", "val")))  # only the first value is 0
+    tensor([0, 2])
+    >>> # nested resets are overridden by a "_reset" at the root
+    >>> data = TensorDict({
+    ...     "_reset": [True, True],
+    ...     ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
+    ...     ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
+    ... }, [])
+    >>> env.reset(data)
+    >>> print(data.get(("agent0", "val")))  # reset at the root overrides nested
+    tensor([0, 0])
+    >>> print(data.get(("agent1", "val")))  # reset at the root overrides nested
+    tensor([0, 0])

 .. code-block::
     :caption: Parallel environment reset
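To make point 2 of the documentation above concrete: after _reset runs, entries where "_reset" is False are overwritten with the previous step's output, which is what produces the tensor([1, 0]) result in the first doc example. A toy illustration of that masking rule, as a sketch rather than TorchRL's actual implementation:

    # Toy illustration of the post-_reset masking described above;
    # a sketch, not TorchRL's implementation.
    import torch
    from tensordict import TensorDict

    prev_step = TensorDict({"val": torch.tensor([1, 1])}, [2])  # last step output
    reset_out = TensorDict({"val": torch.tensor([0, 0])}, [2])  # what _reset wrote
    _reset = torch.tensor([False, True])

    # keep the previous step's data wherever "_reset" is False
    merged = torch.where(_reset, reset_out.get("val"), prev_step.get("val"))
    print(merged)  # tensor([1, 0]) -- matches the first example above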

examples/decision_transformer/dt.py

Lines changed: 9 additions & 3 deletions

@@ -28,9 +28,10 @@
 )


-@set_gym_backend("gym")  # D4RL uses gym so we make sure gymnasium is hidden
-@hydra.main(config_path=".", config_name="dt_config")
+@hydra.main(config_path=".", config_name="dt_config", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
+    set_gym_backend(cfg.env.backend).set()
+
     model_device = cfg.optim.device

     # Set seeds
@@ -63,6 +64,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
         policy=policy,
         inference_context=cfg.env.inference_context,
     ).to(model_device)
+    inference_policy.set_tensor_keys(
+        observation="observation_cat",
+        action="action_cat",
+        return_to_go="return_to_go_cat",
+    )

     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

@@ -76,7 +82,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Pretraining
     start_time = time.time()
     for i in range(pretrain_gradient_steps):
-        pbar.update(i)
+        pbar.update(1)

         # Sample data
         data = offline_buffer.sample()
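The pbar.update(i) to pbar.update(1) change above fixes a progress-bar bug: tqdm's update() takes an increment, not an absolute position, so passing the loop index advanced the bar by 0+1+...+(n-1) steps instead of n. A small demonstration:

    # Demonstration of the tqdm fix above: update() takes an increment.
    import tqdm

    total = 10
    pbar = tqdm.tqdm(total=total)
    for i in range(total):
        pbar.update(1)  # advance by exactly one iteration
        # pbar.update(i) would have advanced the bar by 0+1+...+9 = 45 steps
    pbar.close()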

examples/decision_transformer/dt_config.yaml

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ env:
   target_return_mode: reduce
   eval_target_return: 6000
   collect_target_return: 12000
+  backend: gym  # D4RL uses gym so we make sure gymnasium is hidden

 # logger
 logger:

examples/decision_transformer/odt_config.yaml

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ env:
   target_return_mode: reduce
   eval_target_return: 6000
   collect_target_return: 12000
+  backend: gym  # D4RL uses gym so we make sure gymnasium is hidden


 # logger
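The new backend key in both configs is consumed in dt.py above via set_gym_backend(cfg.env.backend).set(). For reference, the same switch can also be scoped to a block, as gym_env_throughput.py does; a brief sketch, with the env name chosen arbitrarily:

    # Sketch of the two set_gym_backend usage patterns seen in this commit.
    from torchrl.envs.libs.gym import GymEnv, set_gym_backend

    # process-wide switch, as in dt.py
    set_gym_backend("gym").set()

    # or scoped to a block, as in gym_env_throughput.py
    with set_gym_backend("gymnasium"):
        env = GymEnv("HalfCheetah-v4")  # built against the gymnasium backend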
