
Commit 69bdeaf

Browse files
authored
[Doc] More doc about environments (#683)
* amend * amend * amend * amend * amend * amend
1 parent 0e3f066 commit 69bdeaf

File tree

1 file changed (+193, -33 lines)


docs/source/reference/envs.rst

Lines changed: 193 additions & 33 deletions
@@ -3,52 +3,131 @@

torchrl.envs package
====================

TorchRL offers an API to handle environments of different backends, such as gym,
dm-control, dm-lab, model-based environments as well as custom environments.
The goal is to be able to swap environments in an experiment with little or no effort,
even if these environments are simulated using different libraries.
TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`,
which we hope can be easily imitated for other libraries.
The parent class :obj:`EnvBase` is a :obj:`torch.nn.Module` subclass that implements
some typical environment methods using :obj:`TensorDict` as a data organiser. This allows the
class to be generic and to handle an arbitrary number of inputs and outputs, as well as
nested or batched data structures.

Each env will have the following attributes:

- :obj:`env.batch_size`: a :obj:`torch.Size` representing the number of envs batched together.
- :obj:`env.device`: the device where the input and output tensordicts are expected to live.
  The environment device does not mean that the actual step operations will be computed on device
  (this is the responsibility of the backend, with which TorchRL can do little). The device of
  an environment simply represents the device where the data is expected to live when passed to the
  environment or retrieved from it. TorchRL takes care of mapping the data to the desired device.
  This is especially useful for transforms (see below). For parametric environments (e.g.
  model-based environments), the device does represent the hardware that will be used to
  compute the operations.
- :obj:`env.observation_spec`: a :obj:`CompositeSpec` object containing all the observation key-spec pairs.
- :obj:`env.input_spec`: a :obj:`CompositeSpec` object containing all the input keys (:obj:`"action"` and others).
- :obj:`env.action_spec`: a :obj:`TensorSpec` object representing the action spec.
- :obj:`env.reward_spec`: a :obj:`TensorSpec` object representing the reward spec.

Importantly, the environment spec shapes should *not* contain the batch size, e.g.
an environment with :obj:`env.batch_size == torch.Size([4])` should not have
an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])` but simply
:obj:`torch.Size([action_size])`.

With these, the following methods are implemented (see the sketch after this list):

- :obj:`env.reset(tensordict)`: a reset method that may (but does not have to) take
  a :obj:`TensorDict` input. It returns the first tensordict of a rollout, usually
  containing a :obj:`"done"` state and a set of observations.
- :obj:`env.step(tensordict)`: a step method that takes a :obj:`TensorDict` input
  containing an input action as well as other inputs (for model-based or stateless
  environments, for instance).
- :obj:`env.set_seed(integer)`: a seeding method that will return the next seed
  to be used in a multi-env setting. This next seed is deterministically computed
  from the preceding one, such that one can seed multiple environments with different
  seeds without risking overlapping seeds across consecutive experiments, while still
  having reproducible results.
- :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for
  a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`.
  The policy should be coded using a :obj:`TensorDictModule` (or any other
  :obj:`TensorDict`-compatible module).

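A minimal usage sketch of these methods, assuming the gym backend and the
:obj:`Pendulum-v1` task are available (outputs and comments are indicative only;
the exact keys written at each step depend on the environment):

.. code-block::
    :caption: Basic environment usage (sketch)

    >>> from torchrl.envs.libs.gym import GymEnv
    >>> env = GymEnv("Pendulum-v1")
    >>> seed = env.set_seed(0)  # returns the next seed to be used for another env
    >>> tensordict = env.reset()  # contains "done" and the initial observations
    >>> tensordict["action"] = env.action_spec.rand()  # a random action with the (unbatched) spec shape
    >>> tensordict = env.step(tensordict)  # writes the reward, done state and next observations
    >>> rollout = env.rollout(max_steps=10)  # with no policy, random actions are drawn
    >>> print(rollout.batch_size)
    torch.Size([10])
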

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    EnvBase
    GymLikeEnv

Vectorized envs
---------------

Vectorized (or better: parallel) environments are a common feature in Reinforcement Learning,
where executing the environment step can be cpu-intensive.
Some libraries such as `gym3 <https://github.com/openai/gym3>`_ or `EnvPool <https://github.com/sail-sg/envpool>`_
offer interfaces to execute batches of environments simultaneously.
While they often offer a very competitive computational advantage, they do not
necessarily scale to the wide variety of environment libraries supported by TorchRL.
Therefore, TorchRL offers its own, generic :obj:`ParallelEnv` class to run multiple
environments in parallel.
As this class inherits from :obj:`EnvBase`, it enjoys the exact same API as other environments.
Of course, a :obj:`ParallelEnv` will have a batch size that corresponds to its environment count:

.. code-block::
    :caption: Parallel environment

    >>> def make_env():
    ...     return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0")
    >>> env = ParallelEnv(4, make_env)
    >>> print(env.batch_size)
    torch.Size([4])

:obj:`ParallelEnv` allows retrieving the attributes of its contained environments:
one can simply call:

.. code-block::
    :caption: Parallel environment attributes

    >>> a, b, c, d = env.g  # gets the g-force of the various envs, which we set to 9.81 before
    >>> print(a)
    9.81

It is also possible to reset some but not all of the environments:

.. code-block::
    :caption: Parallel environment reset

    >>> tensordict = TensorDict({"reset_workers": [True, False, True, True]}, [4])
    >>> env.reset(tensordict)
    TensorDict(
        fields={
            done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
            pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
            reset_workers: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
        batch_size=torch.Size([4]),
        device=None,
        is_shared=True)


A note on performance: launching a :obj:`ParallelEnv` can take quite some time,
as it requires launching as many python instances as there are processes. Due to
the time that it takes to run :obj:`import torch` (and other imports), starting the
parallel env can be a bottleneck. This is, for instance, why TorchRL tests are so slow.
Once the environment is launched, a great speedup should be observed.

We also offer the :obj:`SerialEnv` class, which enjoys the exact same API but is executed
serially. This is mostly useful for testing purposes, when one wants to assess the
behaviour of a :obj:`ParallelEnv` without launching the subprocesses.
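
As a small sketch, re-using the hypothetical :obj:`make_env` function defined above,
swapping a :obj:`ParallelEnv` for a :obj:`SerialEnv` only changes the class being instantiated:

.. code-block::
    :caption: Serial environment (sketch)

    >>> env = SerialEnv(4, make_env)  # same constructor arguments as ParallelEnv
    >>> print(env.batch_size)  # the batch size still reflects the number of envs
    torch.Size([4])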

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    SerialEnv
    ParallelEnv


Transforms
----------

@@ -58,6 +137,49 @@

In most cases, the raw output of an environment must be treated before being passed to another object (such as a
policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform
logic of `torch.distributions.Transform` and `torchvision.transforms`.

Transformed environments are built using the :obj:`TransformedEnv` primitive.
Composed transforms are built using the :obj:`Compose` class:

.. code-block::
    :caption: Transformed environment

    >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0")
    >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"]))
    >>> env = TransformedEnv(base_env, transform)


By default, the transformed environment will inherit the device of the
:obj:`base_env` that is passed to it. The transforms will then be executed on that device.
It is now apparent that this can bring a significant speedup depending on the kind of
operations that are to be computed.
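
A minimal sketch of this device inheritance, re-using :obj:`base_env` and :obj:`env` from
the example above (assuming a CUDA device is available; the device string is illustrative):

.. code-block::
    :caption: Transformed environment device (sketch)

    >>> print(base_env.device)
    cuda:0
    >>> print(env.device)  # inherited from base_env: transforms will be executed on cuda:0
    cuda:0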

A great advantage of environment wrappers is that one can consult the environment as it stands up to that wrapper.
The same can be achieved with TorchRL transformed environments: the :obj:`parent` attribute will
return a new :obj:`TransformedEnv` with all the transforms up to the transform of interest.
Re-using the example above:

.. code-block::
    :caption: Transform parent

    >>> resize_parent = env.transform[-1].parent  # returns the same as TransformedEnv(base_env, transform[:-1])


Transformed environments can be used with vectorized environments.
Since each transform uses an :obj:`"in_keys"`/:obj:`"out_keys"` set of keyword arguments, it is
also easy to route the transform graph to each component of the observation data (e.g.
pixels or states etc.), as in the sketch below.
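
A minimal sketch of that routing, assuming a base environment that returns both a
:obj:`"pixels"` and an :obj:`"observation"` entry (the :obj:`pixels_only` argument, the key
names and the normalization constants are illustrative assumptions):

.. code-block::
    :caption: Keyed transforms (sketch)

    >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, pixels_only=False, device="cuda:0")
    >>> env = TransformedEnv(
    ...     base_env,
    ...     Compose(
    ...         ToTensorImage(in_keys=["pixels"]),                       # only acts on the image entry
    ...         Resize(64, 64, in_keys=["pixels"]),
    ...         ObservationNorm(loc=0.0, scale=1.0, in_keys=["observation"]),  # only acts on the state entry
    ...     ),
    ... )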

Transforms also have an :obj:`inv` method that is called, in reverse order over the
composed transform chain, before the action is passed to the base environment:
this allows transforms to be applied to input data (such as the action) before it is
used in the environment. The keys to be included in this inverse transform are passed through the
:obj:`"in_keys_inv"` keyword argument:

.. code-block::
    :caption: Inverse transform

    >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step


.. autosummary::
    :toctree: generated/
@@ -88,3 +210,41 @@

    TensorDictPrimer
    R3MTransform
    VIPTransform

Helpers
-------
.. currentmodule:: torchrl.envs.utils

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    step_mdp
    get_available_libraries
    set_exploration_mode
    exploration_mode
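
As a hedged sketch of how :obj:`step_mdp` and :obj:`set_exploration_mode` are typically
combined around a rollout step (the :obj:`policy` below is a hypothetical
:obj:`TensorDictModule`; the exact keys carried over by :obj:`step_mdp` depend on the environment):

.. code-block::
    :caption: Environment helpers (sketch)

    >>> from torchrl.envs.utils import step_mdp, set_exploration_mode
    >>> tensordict = env.reset()
    >>> with set_exploration_mode("random"):
    ...     tensordict = policy(tensordict)
    >>> tensordict = env.step(tensordict)
    >>> tensordict = step_mdp(tensordict)  # keeps the data needed to compute the next step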

Domain-specific
---------------
.. currentmodule:: torchrl.envs

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    ModelBasedEnvBase
    model_based.dreamer.DreamerEnv


Libraries
---------
.. currentmodule:: torchrl.envs.libs

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    gym.GymEnv
    gym.GymWrapper
    dm_control.DMControlEnv
    dm_control.DMControlWrapper
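
As a brief, hedged illustration of the two gym entry points listed above (assuming the gym
backend is installed): :obj:`GymEnv` builds the environment from its name, while
:obj:`GymWrapper` wraps an environment that has already been instantiated:

.. code-block::
    :caption: Gym entry points (sketch)

    >>> from torchrl.envs.libs.gym import GymEnv, GymWrapper
    >>> env = GymEnv("Pendulum-v1")  # builds the gym environment from its name
    >>> import gym
    >>> wrapped = GymWrapper(gym.make("Pendulum-v1"))  # wraps an already-instantiated gym env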
