
Commit 1697102

[BugFix] RewardSum transform for multiple reward keys (#1544)
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 106368f commit 1697102
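
In practice, the fix lets RewardSum keep one running total per reward key instead of assuming a single root "reward". A minimal usage sketch, assuming a hypothetical multi-reward base_env; only the in_keys/reset_keys handling comes from this commit:

from torchrl.envs.transforms import RewardSum, TransformedEnv

# `base_env` is a placeholder for any environment exposing several reward keys,
# e.g. ("agent0", "reward") and ("agent1", "reward").
env = TransformedEnv(
    base_env,
    RewardSum(
        in_keys=base_env.reward_keys,    # one cumulative sum per reward key
        reset_keys=base_env.reset_keys,  # matching reset keys, as in the new test below
    ),
)
td = env.rollout(5)
# Each sum is written next to its reward key,
# e.g. td["next", "agent0", "episode_reward"] for ("agent0", "reward").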

File tree: 2 files changed, +173 -98 lines

test/test_transforms.py

Lines changed: 39 additions & 2 deletions
@@ -32,11 +32,14 @@
     IncrementingEnv,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
+    MultiKeyCountingEnv,
+    MultiKeyCountingEnvPolicy,
     NestedCountingEnv,
 )
 from tensordict import unravel_key
 from tensordict.nn import TensorDictSequential
 from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.utils import _unravel_key_to_tuple
 from torch import multiprocessing as mp, nn, Tensor
 from torchrl._utils import prod
 from torchrl.data import (
@@ -104,7 +107,7 @@
 from torchrl.envs.transforms.transforms import _has_tv
 from torchrl.envs.transforms.vc1 import _has_vc
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
-from torchrl.envs.utils import check_env_specs, step_mdp
+from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
 from torchrl.modules import LSTMModule, MLP, ProbabilisticActor, TanhNormal

 TIMEOUT = 100.0
@@ -4527,6 +4530,36 @@ def test_trans_parallel_env_check(self):
         r = env.rollout(4)
         assert r["next", "episode_reward"].unique().numel() > 1

+    @pytest.mark.parametrize("has_in_keys,", [True, False])
+    def test_trans_multi_key(
+        self, has_in_keys, n_workers=2, batch_size=(3, 2), max_steps=5
+    ):
+        torch.manual_seed(0)
+        env_fun = lambda: MultiKeyCountingEnv(batch_size=batch_size)
+        base_env = SerialEnv(n_workers, env_fun)
+        if has_in_keys:
+            t = RewardSum(in_keys=base_env.reward_keys, reset_keys=base_env.reset_keys)
+        else:
+            t = RewardSum()
+        env = TransformedEnv(
+            base_env,
+            Compose(t),
+        )
+        policy = MultiKeyCountingEnvPolicy(
+            full_action_spec=env.action_spec, deterministic=True
+        )
+
+        check_env_specs(env)
+        td = env.rollout(max_steps, policy=policy)
+        for reward_key in env.reward_keys:
+            reward_key = _unravel_key_to_tuple(reward_key)
+            assert (
+                td.get(
+                    ("next", _replace_last(reward_key, f"episode_{reward_key[-1]}"))
+                )[(0,) * (len(batch_size) + 1)][-1]
+                == max_steps
+            ).all()
+
     @pytest.mark.parametrize("in_key", ["reward", ("some", "nested")])
     def test_transform_no_env(self, in_key):
         t = RewardSum(in_keys=[in_key], out_keys=[("some", "nested_sum")])
@@ -4550,7 +4583,8 @@ def test_transform_no_env(self, in_key):
     def test_transform_compose(
        self,
    ):
-        t = Compose(RewardSum())
+        # reset keys should not be needed for offline run
+        t = Compose(RewardSum(in_keys=["reward"], out_keys=["episode_reward"]))
         reward = torch.randn(10)
         td = TensorDict({("next", "reward"): reward}, [])
         with pytest.raises(
@@ -4649,6 +4683,9 @@ def test_sum_reward(self, keys, device):

         # reset environments
         td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device))
+        with pytest.raises(TypeError, match="reset_keys not provided but parent"):
+            rs.reset(td)
+        rs._reset_keys = ["_reset"]
         rs.reset(td)

         # apply a third time, episode_reward should be equal to reward again
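
The last hunk pins down the new behaviour of a parent-less RewardSum: reset() needs explicit reset keys. A minimal standalone sketch (the tensordict contents here are illustrative, not taken from the test):

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import RewardSum

# Without a parent env, reset_keys must be passed at construction time,
# otherwise reset() raises the TypeError matched above.
rs = RewardSum(in_keys=["reward"], out_keys=["episode_reward"], reset_keys=["_reset"])
td = TensorDict(
    {"episode_reward": torch.ones(4, 1), "_reset": torch.ones(4, dtype=torch.bool)},
    batch_size=[4],
)
td = rs.reset(td)  # zeroes the running sum wherever "_reset" is True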

torchrl/envs/transforms/transforms.py

Lines changed: 134 additions & 96 deletions
@@ -40,7 +40,7 @@
 from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict
 from torchrl.envs.transforms import functional as F
 from torchrl.envs.transforms.utils import check_finite
-from torchrl.envs.utils import _sort_keys, step_mdp
+from torchrl.envs.utils import _replace_last, _sort_keys, step_mdp
 from torchrl.objectives.value.functional import reward2go

 try:
@@ -242,7 +242,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None:

         """
         raise NotImplementedError(
-            f"{self.__class__.__name__}_apply_transform is not coded. If the transform is coded in "
+            f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in "
             "transform._call, make sure that this method is called instead of"
             "transform.forward, which is reserved for usage inside nn.Modules"
             "or appended to a replay buffer."
@@ -4342,74 +4342,140 @@ class RewardSum(Transform):
     """Tracks episode cumulative rewards.

     This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative
-    value along each episode. When called, the transform creates a new tensordict key for each in_key named
-    ´episode_{in_key}´ where the cumulative values are written. All ´in_keys´ should be part of the env
-    reward and be present in the env reward_spec.
+    value along the time dimension for each episode.

-    If no in_keys are specified, this transform assumes ´reward´ to be the input key. However, multiple rewards
-    (e.g. reward1 and reward2) can also be specified. If ´in_keys´ are not present in the provided tensordict,
-    this transform hos no effect.
+    When called, the transform writes a new tensordict entry for each ``in_key`` named
+    ``episode_{in_key}`` where the cumulative values are written.

-    .. note:: :class:`~RewardSum` currently only supports ``"done"`` signal at the root.
-        Nested ``"done"``, such as those found in MARL settings, are currently not supported.
-        If this feature is needed, please raise an issue on TorchRL repo.
+    Args:
+        in_keys (list of NestedKeys, optional): Input reward keys.
+            All ´in_keys´ should be part of the environment reward_spec.
+            If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key.
+            However, multiple rewards (e.g. ``"reward1"`` and ``"reward2""``) can also be specified.
+        out_keys (list of NestedKeys, optional): The output sum keys, should be one per each input key.
+        reset_keys (list of NestedKeys, optional): the list of reset_keys to be
+            used, if the parent environment cannot be found. If provided, this
+            value will prevail over the environment ``reset_keys``.

+    Examples:
+        >>> from torchrl.envs.transforms import RewardSum, TransformedEnv
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
+        >>> td = env.reset()
+        >>> print(td["episode_reward"])
+        tensor([0.])
+        >>> td = env.rollout(3)
+        >>> print(td["next", "episode_reward"])
+        tensor([[-0.5926],
+                [-1.4578],
+                [-2.7885]])
     """

     def __init__(
         self,
         in_keys: Optional[Sequence[NestedKey]] = None,
         out_keys: Optional[Sequence[NestedKey]] = None,
+        reset_keys: Optional[Sequence[NestedKey]] = None,
     ):
         """Initialises the transform. Filters out non-reward input keys and defines output keys."""
-        if in_keys is None:
-            in_keys = ["reward"]
-        if out_keys is None and in_keys == ["reward"]:
-            out_keys = ["episode_reward"]
-        elif out_keys is None:
-            raise RuntimeError(
-                "the out_keys must be specified for non-conventional in-keys in RewardSum."
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        self._reset_keys = reset_keys
+
+    @property
+    def in_keys(self):
+        in_keys = self.__dict__.get("_in_keys", None)
+        if in_keys in (None, []):
+            # retrieve rewards from parent env
+            parent = self.parent
+            if parent is None:
+                in_keys = ["reward"]
+            else:
+                in_keys = copy(parent.reward_keys)
+            self._in_keys = in_keys
+        return in_keys
+
+    @in_keys.setter
+    def in_keys(self, value):
+        if value is not None:
+            if isinstance(value, (str, tuple)):
+                value = [value]
+            value = [unravel_key(val) for val in value]
+        self._in_keys = value
+
+    @property
+    def out_keys(self):
+        out_keys = self.__dict__.get("_out_keys", None)
+        if out_keys in (None, []):
+            out_keys = [
+                _replace_last(in_key, f"episode_{_unravel_key_to_tuple(in_key)[-1]}")
+                for in_key in self.in_keys
+            ]
+            self._out_keys = out_keys
+        return out_keys
+
+    @out_keys.setter
+    def out_keys(self, value):
+        # we must access the private attribute because this check occurs before
+        # the parent env is defined
+        if value is not None and len(self._in_keys) != len(value):
+            raise ValueError(
+                "RewardSum expects the same number of input and output keys"
             )
+        if value is not None:
+            if isinstance(value, (str, tuple)):
+                value = [value]
+            value = [unravel_key(val) for val in value]
+        self._out_keys = value

-        super().__init__(in_keys=in_keys, out_keys=out_keys)
+    @property
+    def reset_keys(self):
+        reset_keys = self.__dict__.get("_reset_keys", None)
+        if reset_keys is None:
+            parent = self.parent
+            if parent is None:
+                raise TypeError(
+                    "reset_keys not provided but parent env not found. "
+                    "Make sure that the reset_keys are provided during "
+                    "construction if the transform does not have a container env."
+                )
+            reset_keys = copy(parent.reset_keys)
+            self._reset_keys = reset_keys
+        return reset_keys
+
+    @reset_keys.setter
+    def reset_keys(self, value):
+        if value is not None:
+            if isinstance(value, (str, tuple)):
+                value = [value]
+            value = [unravel_key(val) for val in value]
+        self._reset_keys = value

     def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Resets episode rewards."""
-        # Non-batched environments
-        _reset = tensordict.get("_reset", None)
-        if _reset is None:
-            _reset = torch.ones(
-                self.parent.done_spec.shape if self.parent else tensordict.batch_size,
-                dtype=torch.bool,
-                device=tensordict.device,
-            )
+        for in_key, reset_key, out_key in zip(
+            self.in_keys, self.reset_keys, self.out_keys
+        ):
+            _reset = tensordict.get(reset_key, None)

-        if _reset.any():
-            _reset = _reset.sum(
-                tuple(range(tensordict.batch_dims, _reset.ndim)), dtype=torch.bool
-            )
-            reward_key = self.parent.reward_key if self.parent else "reward"
-            for in_key, out_key in zip(self.in_keys, self.out_keys):
-                if out_key in tensordict.keys(True, True):
-                    value = tensordict[out_key]
-                    tensordict[out_key] = value.masked_fill(
-                        expand_as_right(_reset, value), 0.0
-                    )
-                elif unravel_key(in_key) == unravel_key(reward_key):
+            if _reset is None or _reset.any():
+                value = tensordict.get(out_key, default=None)
+                if value is not None:
+                    if _reset is None:
+                        tensordict.set(out_key, torch.zeros_like(value))
+                    else:
+                        tensordict.set(
+                            out_key,
+                            value.masked_fill(
+                                expand_as_right(_reset.squeeze(-1), value), 0.0
+                            ),
+                        )
+                else:
                     # Since the episode reward is not in the tensordict, we need to allocate it
                     # with zeros entirely (regardless of the _reset mask)
-                    tensordict[out_key] = self.parent.reward_spec.zero()
-                else:
-                    try:
-                        tensordict[out_key] = self.parent.observation_spec[
-                            in_key
-                        ].zero()
-                    except KeyError as err:
-                        raise KeyError(
-                            f"The key {in_key} was not found in the parent "
-                            f"observation_spec with keys "
-                            f"{list(self.parent.observation_spec.keys(True))}. "
-                        ) from err
+                    tensordict.set(
+                        out_key,
+                        self.parent.full_reward_spec[in_key].zero(),
+                    )
         return tensordict

     def _step(
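
For reference, the out_keys property above derives default names by placing each sum next to its reward key with an episode_ prefix on the leaf. A small sketch of that rule, using the same internal helpers the diff imports (the example keys are made up):

from tensordict.utils import _unravel_key_to_tuple
from torchrl.envs.utils import _replace_last

for in_key in ("reward", ("agents", "reward")):  # illustrative keys only
    leaf = _unravel_key_to_tuple(in_key)[-1]
    print(_replace_last(in_key, f"episode_{leaf}"))
# expected: "episode_reward" and ("agents", "episode_reward")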
@@ -4430,76 +4496,48 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
         state_spec = input_spec["full_state_spec"]
         if state_spec is None:
             state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
-        reward_spec = self.parent.output_spec["full_reward_spec"]
-        reward_spec_keys = list(reward_spec.keys(True, True))
+        state_spec.update(self._generate_episode_reward_spec())
+        input_spec["full_state_spec"] = state_spec
+        return input_spec
+
+    def _generate_episode_reward_spec(self) -> CompositeSpec:
+        episode_reward_spec = CompositeSpec()
+        reward_spec = self.parent.full_reward_spec
+        reward_spec_keys = self.parent.reward_keys
         # Define episode specs for all out_keys
         for in_key, out_key in zip(self.in_keys, self.out_keys):
             if (
                 in_key in reward_spec_keys
             ):  # if this out_key has a corresponding key in reward_spec
                 out_key = _unravel_key_to_tuple(out_key)
-                temp_state_spec = state_spec
+                temp_episode_reward_spec = episode_reward_spec
                 temp_rew_spec = reward_spec
                 for sub_key in out_key[:-1]:
                     if (
                         not isinstance(temp_rew_spec, CompositeSpec)
                         or sub_key not in temp_rew_spec.keys()
                     ):
                         break
-                    if sub_key not in temp_state_spec.keys():
-                        temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
+                    if sub_key not in temp_episode_reward_spec.keys():
+                        temp_episode_reward_spec[sub_key] = temp_rew_spec[
+                            sub_key
+                        ].empty()
                     temp_rew_spec = temp_rew_spec[sub_key]
-                    temp_state_spec = temp_state_spec[sub_key]
-                state_spec[out_key] = reward_spec[in_key].clone()
+                    temp_episode_reward_spec = temp_episode_reward_spec[sub_key]
+                episode_reward_spec[out_key] = reward_spec[in_key].clone()
             else:
                 raise ValueError(
                     f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
                 )
-        input_spec["full_state_spec"] = state_spec
-        return input_spec
+        return episode_reward_spec

     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
         """Transforms the observation spec, adding the new keys generated by RewardSum."""
-        # Retrieve parent reward spec
-        reward_spec = self.parent.reward_spec
-        reward_key = self.parent.reward_key if self.parent else "reward"
-
-        episode_specs = {}
-        if isinstance(reward_spec, CompositeSpec):
-            # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
-            if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
-                raise KeyError("Not all in_keys are present in ´reward_spec´")
-
-            # Define episode specs for all out_keys
-            for out_key in self.out_keys:
-                episode_spec = UnboundedContinuousTensorSpec(
-                    shape=reward_spec.shape,
-                    device=reward_spec.device,
-                    dtype=reward_spec.dtype,
-                )
-                episode_specs.update({out_key: episode_spec})
-
-        else:
-            # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
-            if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}:
-                raise KeyError(
-                    "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
-                )
-
-            # Define episode spec
-            episode_spec = UnboundedContinuousTensorSpec(
-                device=reward_spec.device,
-                dtype=reward_spec.dtype,
-                shape=reward_spec.shape,
-            )
-            episode_specs.update({self.out_keys[0]: episode_spec})
-
-        # Update observation_spec with episode_specs
         if not isinstance(observation_spec, CompositeSpec):
             observation_spec = CompositeSpec(
                 observation=observation_spec, shape=self.parent.batch_size
             )
-        observation_spec.update(episode_specs)
+        observation_spec.update(self._generate_episode_reward_spec())
         return observation_spec

     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
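
Because transform_input_spec and transform_observation_spec now share _generate_episode_reward_spec, the episode-reward entries land in both the state spec and the observation spec of a transformed env. A quick sketch, assuming gym's Pendulum-v1 is installed (mirroring the docstring example above):

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import RewardSum, TransformedEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
# The same spec is injected in both places by the shared helper.
assert "episode_reward" in env.observation_spec.keys()
assert "episode_reward" in env.state_spec.keys(True, True)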
