from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import check_finite
- from torchrl.envs.utils import _sort_keys, step_mdp
+ from torchrl.envs.utils import _replace_last, _sort_keys, step_mdp
from torchrl.objectives.value.functional import reward2go

try:
@@ -242,7 +242,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None:

        """
        raise NotImplementedError(
-             f"{self.__class__.__name__}_apply_transform is not coded. If the transform is coded in "
+             f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in "
            "transform._call, make sure that this method is called instead of"
            "transform.forward, which is reserved for usage inside nn.Modules"
            "or appended to a replay buffer."
@@ -4342,74 +4342,140 @@ class RewardSum(Transform):
    """Tracks episode cumulative rewards.

    This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative
-     value along each episode. When called, the transform creates a new tensordict key for each in_key named
-     ´episode_{in_key}´ where the cumulative values are written. All ´in_keys´ should be part of the env
-     reward and be present in the env reward_spec.
+     value along the time dimension for each episode.

-     If no in_keys are specified, this transform assumes ´reward´ to be the input key. However, multiple rewards
-     (e.g. reward1 and reward2) can also be specified. If ´in_keys´ are not present in the provided tensordict,
-     this transform hos no effect.
+     When called, the transform writes a new tensordict entry for each ``in_key`` named
+     ``episode_{in_key}`` where the cumulative values are written.

-     .. note:: :class:`~RewardSum` currently only supports ``"done"`` signal at the root.
-         Nested ``"done"``, such as those found in MARL settings, are currently not supported.
-         If this feature is needed, please raise an issue on TorchRL repo.
+     Args:
+         in_keys (list of NestedKeys, optional): Input reward keys.
+             All ``in_keys`` should be part of the environment reward_spec.
+             If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key.
+             However, multiple rewards (e.g. ``"reward1"`` and ``"reward2"``) can also be specified.
+         out_keys (list of NestedKeys, optional): The output sum keys; there should be one per input key.
+         reset_keys (list of NestedKeys, optional): the list of reset_keys to be
+             used if the parent environment cannot be found. If provided, this
+             value will prevail over the environment ``reset_keys``.

+     Examples:
+         >>> from torchrl.envs.transforms import RewardSum, TransformedEnv
+         >>> from torchrl.envs.libs.gym import GymEnv
+         >>> env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
+         >>> td = env.reset()
+         >>> print(td["episode_reward"])
+         tensor([0.])
+         >>> td = env.rollout(3)
+         >>> print(td["next", "episode_reward"])
+         tensor([[-0.5926],
+                 [-1.4578],
+                 [-2.7885]])
    """

    def __init__(
        self,
        in_keys: Optional[Sequence[NestedKey]] = None,
        out_keys: Optional[Sequence[NestedKey]] = None,
+         reset_keys: Optional[Sequence[NestedKey]] = None,
    ):
        """Initialises the transform. Filters out non-reward input keys and defines output keys."""
-         if in_keys is None:
-             in_keys = ["reward"]
-         if out_keys is None and in_keys == ["reward"]:
-             out_keys = ["episode_reward"]
-         elif out_keys is None:
-             raise RuntimeError(
-                 "the out_keys must be specified for non-conventional in-keys in RewardSum."
+         super().__init__(in_keys=in_keys, out_keys=out_keys)
+         self._reset_keys = reset_keys
+
+     @property
+     def in_keys(self):
+         in_keys = self.__dict__.get("_in_keys", None)
+         if in_keys in (None, []):
+             # retrieve rewards from parent env
+             parent = self.parent
+             if parent is None:
+                 in_keys = ["reward"]
+             else:
+                 in_keys = copy(parent.reward_keys)
+             self._in_keys = in_keys
+         return in_keys
+
+     @in_keys.setter
+     def in_keys(self, value):
+         if value is not None:
+             if isinstance(value, (str, tuple)):
+                 value = [value]
+             value = [unravel_key(val) for val in value]
+         self._in_keys = value
+
+     @property
+     def out_keys(self):
+         out_keys = self.__dict__.get("_out_keys", None)
+         if out_keys in (None, []):
+             out_keys = [
+                 _replace_last(in_key, f"episode_{_unravel_key_to_tuple(in_key)[-1]}")
+                 for in_key in self.in_keys
+             ]
+             self._out_keys = out_keys
+         return out_keys
+
+     @out_keys.setter
+     def out_keys(self, value):
+         # we must access the private attribute because this check occurs before
+         # the parent env is defined
+         if value is not None and len(self._in_keys) != len(value):
+             raise ValueError(
+                 "RewardSum expects the same number of input and output keys"
            )
+         if value is not None:
+             if isinstance(value, (str, tuple)):
+                 value = [value]
+             value = [unravel_key(val) for val in value]
+         self._out_keys = value

-         super().__init__(in_keys=in_keys, out_keys=out_keys)
+     @property
+     def reset_keys(self):
+         reset_keys = self.__dict__.get("_reset_keys", None)
+         if reset_keys is None:
+             parent = self.parent
+             if parent is None:
+                 raise TypeError(
+                     "reset_keys not provided but parent env not found. "
+                     "Make sure that the reset_keys are provided during "
+                     "construction if the transform does not have a container env."
+                 )
+             reset_keys = copy(parent.reset_keys)
+             self._reset_keys = reset_keys
+         return reset_keys
+
+     @reset_keys.setter
+     def reset_keys(self, value):
+         if value is not None:
+             if isinstance(value, (str, tuple)):
+                 value = [value]
+             value = [unravel_key(val) for val in value]
+         self._reset_keys = value

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Resets episode rewards."""
-         # Non-batched environments
-         _reset = tensordict.get("_reset", None)
-         if _reset is None:
-             _reset = torch.ones(
-                 self.parent.done_spec.shape if self.parent else tensordict.batch_size,
-                 dtype=torch.bool,
-                 device=tensordict.device,
-             )
+         for in_key, reset_key, out_key in zip(
+             self.in_keys, self.reset_keys, self.out_keys
+         ):
+             _reset = tensordict.get(reset_key, None)

-         if _reset.any():
-             _reset = _reset.sum(
-                 tuple(range(tensordict.batch_dims, _reset.ndim)), dtype=torch.bool
-             )
-             reward_key = self.parent.reward_key if self.parent else "reward"
-             for in_key, out_key in zip(self.in_keys, self.out_keys):
-                 if out_key in tensordict.keys(True, True):
-                     value = tensordict[out_key]
-                     tensordict[out_key] = value.masked_fill(
-                         expand_as_right(_reset, value), 0.0
-                     )
-                 elif unravel_key(in_key) == unravel_key(reward_key):
+             if _reset is None or _reset.any():
+                 value = tensordict.get(out_key, default=None)
+                 if value is not None:
+                     if _reset is None:
+                         tensordict.set(out_key, torch.zeros_like(value))
+                     else:
+                         tensordict.set(
+                             out_key,
+                             value.masked_fill(
+                                 expand_as_right(_reset.squeeze(-1), value), 0.0
+                             ),
+                         )
+                 else:
                    # Since the episode reward is not in the tensordict, we need to allocate it
                    # with zeros entirely (regardless of the _reset mask)
-                     tensordict[out_key] = self.parent.reward_spec.zero()
-                 else:
-                     try:
-                         tensordict[out_key] = self.parent.observation_spec[
-                             in_key
-                         ].zero()
-                     except KeyError as err:
-                         raise KeyError(
-                             f"The key {in_key} was not found in the parent "
-                             f"observation_spec with keys "
-                             f"{list(self.parent.observation_spec.keys(True))}. "
-                         ) from err
+                     tensordict.set(
+                         out_key,
+                         self.parent.full_reward_spec[in_key].zero(),
+                     )
        return tensordict

    def _step(
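A minimal usage sketch of the lazily resolved keys added above. The expected values are for a single-agent Gym env and are inferred from the diff rather than captured from a run; the variable names are illustrative.

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import RewardSum, TransformedEnv

t = RewardSum()  # no keys passed: the properties defer to the parent env once attached
env = TransformedEnv(GymEnv("Pendulum-v1"), t)
print(t.in_keys)     # expected ["reward"], copied from parent.reward_keys
print(t.out_keys)    # expected ["episode_reward"], built as episode_{in_key} via _replace_last
print(t.reset_keys)  # expected ["_reset"], copied from parent.reset_keys
# Without a parent env, reset_keys must be passed explicitly, otherwise a TypeError is raised:
t_standalone = RewardSum(in_keys=["reward"], out_keys=["episode_reward"], reset_keys=["_reset"])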
@@ -4430,76 +4496,48 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
        state_spec = input_spec["full_state_spec"]
        if state_spec is None:
            state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
-         reward_spec = self.parent.output_spec["full_reward_spec"]
-         reward_spec_keys = list(reward_spec.keys(True, True))
+         state_spec.update(self._generate_episode_reward_spec())
+         input_spec["full_state_spec"] = state_spec
+         return input_spec
+
+     def _generate_episode_reward_spec(self) -> CompositeSpec:
+         episode_reward_spec = CompositeSpec()
+         reward_spec = self.parent.full_reward_spec
+         reward_spec_keys = self.parent.reward_keys
        # Define episode specs for all out_keys
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if (
                in_key in reward_spec_keys
            ):  # if this out_key has a corresponding key in reward_spec
                out_key = _unravel_key_to_tuple(out_key)
-                 temp_state_spec = state_spec
+                 temp_episode_reward_spec = episode_reward_spec
                temp_rew_spec = reward_spec
                for sub_key in out_key[:-1]:
                    if (
                        not isinstance(temp_rew_spec, CompositeSpec)
                        or sub_key not in temp_rew_spec.keys()
                    ):
                        break
-                     if sub_key not in temp_state_spec.keys():
-                         temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
+                     if sub_key not in temp_episode_reward_spec.keys():
+                         temp_episode_reward_spec[sub_key] = temp_rew_spec[
+                             sub_key
+                         ].empty()
                    temp_rew_spec = temp_rew_spec[sub_key]
-                     temp_state_spec = temp_state_spec[sub_key]
-                 state_spec[out_key] = reward_spec[in_key].clone()
+                     temp_episode_reward_spec = temp_episode_reward_spec[sub_key]
+                 episode_reward_spec[out_key] = reward_spec[in_key].clone()
            else:
                raise ValueError(
                    f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
                )
-         input_spec["full_state_spec"] = state_spec
-         return input_spec
+         return episode_reward_spec

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Transforms the observation spec, adding the new keys generated by RewardSum."""
-         # Retrieve parent reward spec
-         reward_spec = self.parent.reward_spec
-         reward_key = self.parent.reward_key if self.parent else "reward"
-
-         episode_specs = {}
-         if isinstance(reward_spec, CompositeSpec):
-             # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
-             if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
-                 raise KeyError("Not all in_keys are present in ´reward_spec´")
-
-             # Define episode specs for all out_keys
-             for out_key in self.out_keys:
-                 episode_spec = UnboundedContinuousTensorSpec(
-                     shape=reward_spec.shape,
-                     device=reward_spec.device,
-                     dtype=reward_spec.dtype,
-                 )
-                 episode_specs.update({out_key: episode_spec})
-
-         else:
-             # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
-             if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}:
-                 raise KeyError(
-                     "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
-                 )
-
-             # Define episode spec
-             episode_spec = UnboundedContinuousTensorSpec(
-                 device=reward_spec.device,
-                 dtype=reward_spec.dtype,
-                 shape=reward_spec.shape,
-             )
-             episode_specs.update({self.out_keys[0]: episode_spec})
-
-         # Update observation_spec with episode_specs
        if not isinstance(observation_spec, CompositeSpec):
            observation_spec = CompositeSpec(
                observation=observation_spec, shape=self.parent.batch_size
            )
-         observation_spec.update(episode_specs)
+         observation_spec.update(self._generate_episode_reward_spec())
        return observation_spec

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
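A companion sketch for the spec plumbing above, checking that the entry produced by _generate_episode_reward_spec lands in both the observation spec and the state spec of the transformed env. This assumes a single-agent Gym env; observation_spec and state_spec are standard EnvBase properties, and the asserts express expected behaviour rather than a recorded test.

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import RewardSum, TransformedEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
# transform_observation_spec mirrors the reward spec under an "episode_*" name
assert "episode_reward" in env.observation_spec.keys()
# transform_input_spec registers the same entry in the state spec so that reset() can fill it
assert "episode_reward" in env.state_spec.keys()
print(env.observation_spec["episode_reward"])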