Skip to content

Commit 705f70f

Browse files
authored
[BugFix] Dreamer helpers are broken with batched envs (#903)
1 parent 653b2a1 commit 705f70f

File tree

17 files changed

+343
-143
lines changed

17 files changed

+343
-143
lines changed

.circleci/unittest/linux_examples/scripts/run_test.sh

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,93 @@ export MKL_THREADING_LAYER=GNU
2626

2727
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 20
2828
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20
29+
30+
# With batched environments
31+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
32+
total_frames=48 \
33+
init_random_frames=10 \
34+
batch_size=10 \
35+
frames_per_batch=16 \
36+
num_workers=4 \
37+
env_per_collector=2 \
38+
collector_devices=cuda:0 \
39+
optim_steps_per_batch=1 \
40+
record_video=True \
41+
record_frames=4 \
42+
buffer_size=120
43+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
44+
total_frames=48 \
45+
batch_size=10 \
46+
frames_per_batch=16 \
47+
num_workers=4 \
48+
env_per_collector=2 \
49+
collector_devices=cuda:0 \
50+
optim_steps_per_batch=1 \
51+
record_video=True \
52+
record_frames=4 \
53+
logger=csv
54+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
55+
total_frames=48 \
56+
init_random_frames=10 \
57+
batch_size=10 \
58+
frames_per_batch=16 \
59+
num_workers=4 \
60+
env_per_collector=2 \
61+
collector_devices=cuda:0 \
62+
optim_steps_per_batch=1 \
63+
record_video=True \
64+
record_frames=4 \
65+
buffer_size=120
66+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
67+
total_frames=48 \
68+
init_random_frames=10 \
69+
batch_size=10 \
70+
frames_per_batch=16 \
71+
num_workers=4 \
72+
env_per_collector=2 \
73+
collector_devices=cuda:0 \
74+
optim_steps_per_batch=1 \
75+
record_video=True \
76+
record_frames=4 \
77+
buffer_size=120
78+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
79+
total_frames=48 \
80+
init_random_frames=10 \
81+
batch_size=10 \
82+
frames_per_batch=16 \
83+
num_workers=4 \
84+
env_per_collector=2 \
85+
collector_devices=cuda:0 \
86+
optim_steps_per_batch=1 \
87+
record_video=True \
88+
record_frames=4 \
89+
buffer_size=120
90+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
91+
total_frames=48 \
92+
batch_size=10 \
93+
frames_per_batch=16 \
94+
num_workers=4 \
95+
env_per_collector=2 \
96+
collector_devices=cuda:0 \
97+
optim_steps_per_batch=1 \
98+
record_video=True \
99+
record_frames=4 \
100+
lr_scheduler=
101+
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
102+
total_frames=48 \
103+
init_random_frames=10 \
104+
batch_size=10 \
105+
frames_per_batch=200 \
106+
num_workers=4 \
107+
env_per_collector=2 \
108+
collector_devices=cuda:0 \
109+
optim_steps_per_batch=1 \
110+
record_video=True \
111+
record_frames=4 \
112+
buffer_size=120 \
113+
rssm_hidden_dim=17
114+
115+
# With single envs
29116
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
30117
total_frames=48 \
31118
init_random_frames=10 \
@@ -109,5 +196,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/drea
109196
record_frames=4 \
110197
buffer_size=120 \
111198
rssm_hidden_dim=17
199+
112200
coverage combine
113201
coverage xml -i

docs/source/conf.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@
7171
"myst_parser",
7272
]
7373

74+
intersphinx_mapping = {
75+
"torch": ("https://pytorch.org/docs/stable/", None),
76+
"tensordict": ("https://pytorch-labs.github.io/tensordict/", None),
77+
"torchrl": ("https://pytorch.org/rl/", None),
78+
"torchaudio": ("https://pytorch.org/audio/stable/", None),
79+
"torchtext": ("https://pytorch.org/text/stable/", None),
80+
"torchvision": ("https://pytorch.org/vision/stable/", None),
81+
}
82+
7483
sphinx_gallery_conf = {
7584
"examples_dirs": "reference/generated/tutorials/", # path to your example scripts
7685
"gallery_dirs": "tutorials", # path to where to save gallery generated output
@@ -162,14 +171,6 @@
162171
]
163172

164173

165-
# Example configuration for intersphinx: refer to the Python standard library.
166-
intersphinx_mapping = {
167-
"python": ("https://docs.python.org/3/", None),
168-
"torch": ("https://pytorch.org/docs/stable/", None),
169-
"numpy": ("https://numpy.org/doc/stable/", None),
170-
}
171-
172-
173174
aafig_default_options = {"scale": 1.5, "aspect": 1.0, "proportional": True}
174175

175176
# -- Generate knowledge base references -----------------------------------

docs/source/reference/envs.rst

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ The goal is to be able to swap environments in an experiment with little or no e
99
even if these environments are simulated using different libraries.
1010
TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`,
1111
which we hope can be easily imitated for other libraries.
12-
The parent class :obj:`EnvBase` is a :obj:`torch.nn.Module` subclass that implements
13-
some typical environment methods using :obj:`TensorDict` as a data organiser. This allows this
12+
The parent class :class:`torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements
13+
some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. This allows this
1414
class to be generic and to handle an arbitrary number of input and outputs, as well as
1515
nested or batched data structures.
1616

@@ -25,10 +25,10 @@ Each env will have the following attributes:
2525
This is especially useful for transforms (see below). For parametric environments (e.g.
2626
model-based environments), the device does represent the hardware that will be used to
2727
compute the operations.
28-
- :obj:`env.observation_spec`: a :obj:`CompositeSpec` object containing all the observation key-spec pairs.
29-
- :obj:`env.input_spec`: a :obj:`CompositeSpec` object containing all the input keys (:obj:`"action"` and others).
30-
- :obj:`env.action_spec`: a :obj:`TensorSpec` object representing the action spec.
31-
- :obj:`env.reward_spec`: a :obj:`TensorSpec` object representing the reward spec.
28+
- :obj:`env.observation_spec`: a :class:`torchrl.data.CompositeSpec` object containing all the observation key-spec pairs.
29+
- :obj:`env.input_spec`: a :class:`torchrl.data.CompositeSpec` object containing all the input keys (:obj:`"action"` and others).
30+
- :obj:`env.action_spec`: a :class:`torchrl.data.TensorSpec` object representing the action spec.
31+
- :obj:`env.reward_spec`: a :class:`torchrl.data.TensorSpec` object representing the reward spec.
3232

3333
Importantly, the environment spec shapes should *not* contain the batch size, e.g.
3434
an environment with :obj:`env.batch_size == torch.Size([4])` should not have
@@ -38,9 +38,9 @@ an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])` but sim
3838
With these, the following methods are implemented:
3939

4040
- :obj:`env.reset(tensordict)`: a reset method that may (but not necessarily requires to) take
41-
a :obj:`TensorDict` input. It return the first tensordict of a rollout, usually
41+
a :class:`tensordict.TensorDict` input. It return the first tensordict of a rollout, usually
4242
containing a :obj:`"done"` state and a set of observations.
43-
- :obj:`env.step(tensordict)`: a step method that takes a :obj:`TensorDict` input
43+
- :obj:`env.step(tensordict)`: a step method that takes a :class:`tensordict.TensorDict` input
4444
containing an input action as well as other inputs (for model-based or stateless
4545
environments, for instance).
4646
- :obj:`env.set_seed(integer)`: a seeding method that will return the next seed
@@ -51,7 +51,7 @@ With these, the following methods are implemented:
5151
- :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for
5252
a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`.
5353
The policy should be coded using a :obj:`SafeModule` (or any other
54-
:obj:`TensorDict`-compatible module).
54+
:class:`tensordict.TensorDict`-compatible module).
5555

5656

5757
.. autosummary::
@@ -204,6 +204,47 @@ in the environment. The keys to be included in this inverse transform are passed
204204
205205
>>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step
206206
207+
Cloning transforms
208+
~~~~~~~~~~~~~~~~~~
209+
210+
Because transforms appended to an environment are "registered" to this environment
211+
through the ``transform.parent`` property, when manipulating transforms we should keep
212+
in mind that the parent may come and go following what is being done with the transform.
213+
Here are some examples: if we get a single transform from a :class:`Compose` object,
214+
this transform will keep its parent:
215+
216+
>>> third_transform = env.transform[2]
217+
>>> assert third_transform.parent is not None
218+
219+
This means that using this transform for another environment is prohibited, as
220+
the other environment would replace the parent and this may lead to unexpected
221+
behviours. Fortunately, the :class:`Transform` class comes with a :func:`clone`
222+
method that will erase the parent while keeping the identity of all the
223+
registered buffers:
224+
225+
>>> TransformedEnv(base_env, third_transform) # raises an Exception as third_transform already has a parent
226+
>>> TransformedEnv(base_env, third_transform.clone()) # works
227+
228+
On a single process or if the buffers are placed in shared memory, this will
229+
result in all the clone transforms to keep the same behaviour even if the
230+
buffers are changed in place (which is what will happen with the :class:`CatFrames`
231+
transform, for instance). In distributed settings, this may not hold and one
232+
should be careful about the expected behaviour of the cloned transforms in this
233+
context.
234+
Finally, notice that indexing multiple transforms from a :class:`Compose` transform
235+
may also result in loss of parenthood for these transforms: the reason is that
236+
indexing a :class:`Compose` transform results in another :class:`Compose` transform
237+
that does not have a parent environment. Hence, we have to clone the sub-transforms
238+
to be able to create this other composition:
239+
240+
>>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3))
241+
>>> last_two = env.transform[-2:]
242+
>>> assert isinstance(last_two, Compose)
243+
>>> assert last_two.parent is None
244+
>>> assert last_two[0] is not transform2
245+
>>> assert isinstance(last_two[0], transform2) # and the buffers will match
246+
>>> assert last_two[1] is not transform3
247+
>>> assert isinstance(last_two[1], transform3) # and the buffers will match
207248

208249
.. autosummary::
209250
:toctree: generated/

examples/a2c/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
env_library: gym # env_library used for the simulated environment.
33
env_name: HalfCheetah-v4 # name of the environment to be created. Default=Humanoid-v2
44
frame_skip: 2 # frame_skip for the environment.
5+
batch_transform: True
56

67
# Logger
78
logger: wandb # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv'

examples/ddpg/ddpg.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,16 @@ def main(cfg: "DictConfig"): # noqa: F821
159159
logger=logger,
160160
use_env_creator=False,
161161
)()
162-
163-
# remove video recorder from recorder to have matching state_dict keys
164-
if cfg.record_video:
165-
recorder_rm = TransformedEnv(recorder.base_env)
166-
for transform in recorder.transform:
167-
if not isinstance(transform, VideoRecorder):
168-
recorder_rm.append_transform(transform.clone())
169-
else:
170-
recorder_rm = recorder
171-
172162
if isinstance(create_env_fn, ParallelEnv):
173-
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
174-
create_env_fn.close()
163+
raise NotImplementedError("This behaviour is deprecated")
175164
elif isinstance(create_env_fn, EnvCreator):
176-
recorder_rm.load_state_dict(create_env_fn().state_dict())
165+
recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict())
166+
elif isinstance(create_env_fn, TransformedEnv):
167+
recorder.transform = create_env_fn.transform.clone()
177168
else:
178-
recorder_rm.load_state_dict(create_env_fn.state_dict())
169+
raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
170+
if logger is not None and video_tag:
171+
recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))
179172

180173
# reset reward scaling
181174
for t in recorder.transform:

examples/dqn/dqn.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,6 @@ def main(cfg: "DictConfig"): # noqa: F821
112112
make_env=create_env_fn,
113113
actor_model_explore=model_explore,
114114
cfg=cfg,
115-
# make_env_kwargs=[
116-
# {"device": device} if device >= 0 else {}
117-
# for device in args.env_rendering_devices
118-
# ],
119115
)
120116

121117
replay_buffer = make_replay_buffer(device, cfg)
@@ -126,24 +122,19 @@ def main(cfg: "DictConfig"): # noqa: F821
126122
norm_obs_only=True,
127123
obs_norm_state_dict=obs_norm_state_dict,
128124
logger=logger,
125+
use_env_creator=False,
129126
)()
130-
131-
# remove video recorder from recorder to have matching state_dict keys
132-
if cfg.record_video:
133-
recorder_rm = TransformedEnv(recorder.base_env)
134-
for transform in recorder.transform:
135-
if not isinstance(transform, VideoRecorder):
136-
recorder_rm.append_transform(transform.clone())
137-
else:
138-
recorder_rm = recorder
139-
140127
if isinstance(create_env_fn, ParallelEnv):
141-
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
142-
create_env_fn.close()
128+
raise NotImplementedError("This behaviour is deprecated")
143129
elif isinstance(create_env_fn, EnvCreator):
144-
recorder_rm.load_state_dict(create_env_fn().state_dict())
130+
recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict())
131+
elif isinstance(create_env_fn, TransformedEnv):
132+
recorder.transform = create_env_fn.transform.clone()
145133
else:
146-
recorder_rm.load_state_dict(create_env_fn.state_dict())
134+
raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
135+
if logger is not None and video_tag:
136+
recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))
137+
147138
# reset reward scaling
148139
for t in recorder.transform:
149140
if isinstance(t, RewardScaling):

examples/dreamer/dreamer_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,19 @@ def make_env_transforms(
9191
env.append_transform(GrayScale())
9292
env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
9393
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
94-
if stats is None:
94+
if stats is None and obs_norm_state_dict is None:
9595
obs_stats = {
96-
"loc": torch.zeros(env.observation_spec["pixels"].shape),
97-
"scale": torch.ones(env.observation_spec["pixels"].shape),
96+
"loc": torch.zeros(()),
97+
"scale": torch.ones(()),
9898
}
99+
elif stats is None and obs_norm_state_dict is not None:
100+
obs_stats = obs_norm_state_dict
99101
else:
100102
obs_stats = stats
101103
obs_stats["standard_normal"] = True
102104
obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
103-
if obs_norm_state_dict:
104-
obs_norm.load_state_dict(obs_norm_state_dict)
105+
# if obs_norm_state_dict:
106+
# obs_norm.load_state_dict(obs_norm_state_dict)
105107
env.append_transform(obs_norm)
106108
if norm_rewards:
107109
reward_scaling = 1.0
@@ -125,8 +127,10 @@ def make_env_transforms(
125127
)
126128

127129
default_dict = {
128-
"state": UnboundedContinuousTensorSpec(cfg.state_dim),
129-
"belief": UnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
130+
"state": UnboundedContinuousTensorSpec(shape=(*env.batch_size, cfg.state_dim)),
131+
"belief": UnboundedContinuousTensorSpec(
132+
shape=(*env.batch_size, cfg.rssm_hidden_dim)
133+
),
130134
}
131135
env.append_transform(
132136
TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -417,6 +421,6 @@ class EnvConfig:
417421
# Disables grayscale transform.
418422
max_frames_per_traj: int = 1000
419423
# Number of steps before a reset of the environment is called (if it has not been flagged as done before).
420-
batch_transform: bool = False
424+
batch_transform: bool = True
421425
# if True, the transforms will be applied to the parallel env, and not to each individual env.\
422426
image_size: int = 84

examples/ppo/ppo.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -132,23 +132,16 @@ def main(cfg: "DictConfig"): # noqa: F821
132132
logger=logger,
133133
use_env_creator=False,
134134
)()
135-
136-
# remove video recorder from recorder to have matching state_dict keys
137-
if cfg.record_video:
138-
recorder_rm = TransformedEnv(recorder.base_env)
139-
for transform in recorder.transform:
140-
if not isinstance(transform, VideoRecorder):
141-
recorder_rm.append_transform(transform.clone())
142-
else:
143-
recorder_rm = recorder
144-
145135
if isinstance(create_env_fn, ParallelEnv):
146-
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
147-
create_env_fn.close()
136+
raise NotImplementedError("This behaviour is deprecated")
148137
elif isinstance(create_env_fn, EnvCreator):
149-
recorder_rm.load_state_dict(create_env_fn().state_dict())
138+
recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict())
139+
elif isinstance(create_env_fn, TransformedEnv):
140+
recorder.transform = create_env_fn.transform.clone()
150141
else:
151-
recorder_rm.load_state_dict(create_env_fn.state_dict())
142+
raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
143+
if logger is not None and video_tag:
144+
recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))
152145

153146
# reset reward scaling
154147
for t in recorder.transform:

0 commit comments

Comments
 (0)