
Commit dd59290

Author: Vincent Moens (committed)
[Doc] Better doc for Transform class
ghstack-source-id: 16e563b
Pull Request resolved: #2797
1 parent 2046bc5 commit dd59290

10 files changed: +371 −178 lines


docs/source/reference/envs.rst

Lines changed: 70 additions & 0 deletions
@@ -865,6 +865,8 @@ The inverse process is executed with the output tensordict, where the `in_keys`
 
 Rename transform logic
 
+.. note:: During a call to `inv`, the transforms are executed in reversed order (compared to the forward / step mode).
+
 Transforming Tensors and Specs
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
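Not part of this commit: a minimal sketch of that reversed ordering, written against the `Transform` API documented in the additions below; `AddOne` and `TimesTwo` are hypothetical transforms.

>>> class AddOne(Transform):
...     """Hypothetical: adds 1 to the action on the inverse pass."""
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
...         return action + 1
>>> class TimesTwo(Transform):
...     """Hypothetical: multiplies the action by 2 on the inverse pass."""
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
...         return action * 2
>>> # With TransformedEnv(base_env, Compose(AddOne(), TimesTwo())), the inverse pass
>>> # runs TimesTwo first and AddOne second, so the base env receives (action * 2) + 1.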

@@ -900,6 +902,74 @@ tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand
 environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
 transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.
 
+Designing your own Transform
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To create a basic, custom transform, you need to subclass the `Transform` class and implement the
+:meth:`~torchrl.envs.Transform._apply_transform` method. Here's an example of a simple transform that adds 1 to the
+observation tensor:
+
+>>> class AddOneToObs(Transform):
+...     """A transform that adds 1 to the observation tensor."""
+...
+...     def __init__(self):
+...         super().__init__(in_keys=["observation"], out_keys=["observation"])
+...
+...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+...         return obs + 1
+
+Tips for subclassing `Transform`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+There are various ways of subclassing a transform. The things to take into consideration are:
+
+- Is the transform identical for each tensor / item being transformed? Use
+  :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Does the transform need access to the input data of ``env.step`` as well as its output? Rewrite
+  :meth:`~torchrl.envs.Transform._step`.
+  Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`).
+- Is the transform to be used within a replay buffer? Overwrite :meth:`~torchrl.envs.Transform.forward`,
+  :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or
+  :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Within a transform, you can access (and make calls to) the parent environment using
+  :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms up to this one) or
+  :meth:`~torchrl.envs.Transform.container` (the object that encapsulates the transform).
+- Don't forget to edit the specs if needed: top level: :meth:`~torchrl.envs.Transform.transform_output_spec` and
+  :meth:`~torchrl.envs.Transform.transform_input_spec`.
+  Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`,
+  :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`,
+  :meth:`~torchrl.envs.Transform.transform_reward_spec` and
+  :meth:`~torchrl.envs.Transform.transform_done_spec` (a spec-editing sketch follows this list).
+
+For practical examples, see the methods listed above.
+
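Not part of this commit: a minimal sketch of the spec-editing point above. `Unbounded` (from ``torchrl.data``) and the composite layout of ``observation_spec`` are assumptions; the transform writes a new observation entry and registers it in the spec so rollouts and spec checks stay consistent.

>>> from torchrl.data import Unbounded
>>> class AddObsNorm(Transform):
...     """Illustrative sketch: writes the L2 norm of the observation under a new key."""
...     def __init__(self):
...         super().__init__(in_keys=["observation"], out_keys=["obs_norm"])
...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
...         return obs.norm(dim=-1, keepdim=True)
...     def transform_observation_spec(self, observation_spec):
...         # advertise the new "obs_norm" entry alongside the existing ones
...         observation_spec["obs_norm"] = Unbounded(
...             shape=(*observation_spec.shape, 1), device=observation_spec.device
...         )
...         return observation_spec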
+You can use a transform in an environment by passing it to the :class:`~torchrl.envs.TransformedEnv` constructor:
+
+>>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs())
+
+You can compose multiple transforms together using the :class:`~torchrl.envs.Compose` class:
+
+>>> transform = Compose(AddOneToObs(), RewardSum())
+>>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform)
+
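Not part of this commit: transforms can also be appended to an existing environment, following the same ``append_transform`` pattern that appears in the ``DreamerDecoder`` docstring further down this diff.

>>> env = GymEnv("Pendulum-v1").append_transform(AddOneToObs())
>>> env.append_transform(RewardSum())  # appends a second transform in place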
+Inverse Transforms
+^^^^^^^^^^^^^^^^^^
+
+Some transforms support an inverse transform, which is applied on the input path (for instance to the action before
+it is passed to the base environment). For example, the AddOneToAction transform adds 1 to the action tensor through
+its inverse transform:
+
+>>> class AddOneToAction(Transform):
+...     """A transform that adds 1 to the action tensor."""
+...     def __init__(self):
+...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
+...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
+...         return action + 1
+
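Not part of this commit: a short usage sketch; on the inverse path, the modified action is what the base environment receives.

>>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToAction())
>>> td = env.reset()
>>> td["action"] = env.action_spec.zero()
>>> td = env.step(td)  # the base Pendulum env receives action + 1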
+Using a Transform with a Replay Buffer
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can use a transform with a replay buffer by passing it to the :class:`~torchrl.data.ReplayBuffer` constructor:
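Not part of this commit: a rough sketch of such a call, assuming ``ReplayBuffer`` and ``LazyTensorStorage`` from ``torchrl.data``; the transform's forward pass runs when samples are read back.

>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyTensorStorage(1000), transform=AddOneToObs())
>>> rb.extend(env.rollout(10))
>>> batch = rb.sample(4)  # AddOneToObs.forward is applied to the sampled batch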
 
 Cloning transforms
 ~~~~~~~~~~~~~~~~~~

test/test_collector.py

Lines changed: 5 additions & 3 deletions
@@ -3213,9 +3213,11 @@ def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
 @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
 class TestCollectorsNonTensor:
     class AddNontTensorData(Transform):
-        def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
-            tensordict["nt"] = f"a string! - {tensordict.get('step_count').item()}"
-            return tensordict
+        def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+            next_tensordict[
+                "nt"
+            ] = f"a string! - {next_tensordict.get('step_count').item()}"
+            return next_tensordict
 
         def _reset(
             self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase

torchrl/envs/model_based/dreamer.py

Lines changed: 2 additions & 2 deletions
@@ -80,8 +80,8 @@ class DreamerDecoder(Transform):
     >>> model_based_env_eval = model_based_env.append_transform(DreamerDecoder())
     """
 
-    def _call(self, tensordict):
-        return self.parent.base_env.obs_decoder(tensordict)
+    def _call(self, next_tensordict):
+        return self.parent.base_env.obs_decoder(next_tensordict)
 
     def _reset(self, tensordict, tensordict_reset):
         return self._call(tensordict_reset)

torchrl/envs/transforms/gym_transforms.py

Lines changed: 2 additions & 2 deletions
@@ -138,8 +138,8 @@ def _get_lives(self):
         lives = torch.as_tensor([_lives() for _lives in lives])
         return lives
 
-    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return tensordict
+    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+        return next_tensordict
 
     def _step(self, tensordict, next_tensordict):
         parent = self.parent

torchrl/envs/transforms/r3m.py

Lines changed: 4 additions & 4 deletions
@@ -70,12 +70,12 @@ def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
         self.del_keys = del_keys
 
     @set_lazy_legacy(False)
-    def _call(self, tensordict):
-        with tensordict.view(-1) as tensordict_view:
+    def _call(self, next_tensordict):
+        with next_tensordict.view(-1) as tensordict_view:
             super()._call(tensordict_view)
         if self.del_keys:
-            tensordict.exclude(*self.in_keys, inplace=True)
-        return tensordict
+            next_tensordict.exclude(*self.in_keys, inplace=True)
+        return next_tensordict
 
     forward = _call
 

torchrl/envs/transforms/rlhf.py

Lines changed: 9 additions & 9 deletions
@@ -158,25 +158,25 @@ def _reset(
         tensordict_reset = self._call(tensordict_reset)
         return tensordict_reset
 
-    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = tensordict.get("action", None)
+        action = next_tensordict.get("action", None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
-                tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
-            return tensordict
+                next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
+            return next_tensordict
         with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(tensordict.clone(False))
+            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
         # get the log_prob given the original model
         log_prob = dist.log_prob(action)
         reward_key = self.in_keys[0]
-        reward = tensordict.get("next").get(reward_key)
-        curr_log_prob = tensordict.get(self.sample_log_prob_key)
+        reward = next_tensordict.get("next").get(reward_key)
+        curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
         # we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
         kl = (curr_log_prob - log_prob).view_as(reward)
-        tensordict.set(("next", *self.out_keys[0]), reward + self.coef * kl)
-        return tensordict
+        next_tensordict.set(("next", *self.out_keys[0]), reward + self.coef * kl)
+        return next_tensordict
 
     def _step(
         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
