
Commit 09b28d3

[Feature] add standard_normal for RewardScaling (#682)
* Add standard_normal
* give attribute access
* Update standard_normal
* Update tests
* Fix tests
* Address in-place scaling of reward
* Improvise tests
1 parent 0ab52dd commit 09b28d3

2 files changed: 34 additions & 10 deletions

test/test_transforms.py

Lines changed: 15 additions & 8 deletions
@@ -1289,13 +1289,16 @@ def test_binarized_reward(self, device, batch):
     @pytest.mark.parametrize("loc", [1, 5])
     @pytest.mark.parametrize("keys", [None, ["reward_1"]])
     @pytest.mark.parametrize("device", get_available_devices())
-    def test_reward_scaling(self, batch, scale, loc, keys, device):
+    @pytest.mark.parametrize("standard_normal", [True, False])
+    def test_reward_scaling(self, batch, scale, loc, keys, device, standard_normal):
         torch.manual_seed(0)
         if keys is None:
             keys_total = set([])
         else:
             keys_total = set(keys)
-        reward_scaling = RewardScaling(in_keys=keys, scale=scale, loc=loc)
+        reward_scaling = RewardScaling(
+            in_keys=keys, scale=scale, loc=loc, standard_normal=standard_normal
+        )
         td = TensorDict(
             {
                 **{key: torch.randn(*batch, 1, device=device) for key in keys_total},
@@ -1308,13 +1311,17 @@ def test_reward_scaling(self, batch, scale, loc, keys, device):
         td_copy = td.clone()
         reward_scaling(td)
         for key in keys_total:
-            assert (td.get(key) == td_copy.get(key).mul_(scale).add_(loc)).all()
+            if standard_normal:
+                original_key = td.get(key)
+                scaled_key = (td_copy.get(key) - loc) / scale
+                torch.testing.assert_close(original_key, scaled_key)
+            else:
+                original_key = td.get(key)
+                scaled_key = td_copy.get(key) * scale + loc
+                torch.testing.assert_close(original_key, scaled_key)
         assert (td.get("dont touch") == td_copy.get("dont touch")).all()
-        if len(keys_total) == 0:
-            assert (
-                td.get("reward") == td_copy.get("reward").mul_(scale).add_(loc)
-            ).all()
-        elif len(keys_total) == 1:
+
+        if len(keys_total) == 1:
             reward_spec = UnboundedContinuousTensorSpec(device=device)
             reward_spec = reward_scaling.transform_reward_spec(reward_spec)
             assert reward_spec.shape == torch.Size([1])
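
For reference, the assertion logic the test now branches on reduces to a simple affine check in plain PyTorch. The snippet below is a minimal sketch of the two modes being distinguished; the tensor shapes and the loc/scale values are illustrative, not taken from the test suite.

import torch

loc, scale = 1.0, 5.0
reward = torch.randn(4, 1)

# Default behaviour (standard_normal=False): affine scaling, reward * scale + loc
scaled = reward * scale + loc

# New behaviour (standard_normal=True): standardization, (reward - loc) / scale
standardized = (reward - loc) / scale

# For the same loc and scale, the two modes are inverses of each other
torch.testing.assert_close(standardized * scale + loc, reward)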

torchrl/envs/transforms/transforms.py

Lines changed: 19 additions & 2 deletions
@@ -1500,6 +1500,12 @@ class RewardScaling(Transform):
     Args:
         loc (number or torch.Tensor): location of the affine transform
         scale (number or torch.Tensor): scale of the affine transform
+        standard_normal (bool, optional): if True, the transform will be
+
+            .. math::
+                reward = (reward-loc)/scale
+
+            as it is done for standardization. Default is `False`.
     """

     inplace = True
@@ -1509,10 +1515,13 @@ def __init__(
         loc: Union[float, torch.Tensor],
         scale: Union[float, torch.Tensor],
         in_keys: Optional[Sequence[str]] = None,
+        standard_normal: bool = False,
     ):
         if in_keys is None:
             in_keys = ["reward"]
         super().__init__(in_keys=in_keys)
+        self.standard_normal = standard_normal
+
         if not isinstance(loc, torch.Tensor):
             loc = torch.tensor(loc)
         if not isinstance(scale, torch.Tensor):
@@ -1522,8 +1531,16 @@
         self.register_buffer("scale", scale.clamp_min(1e-6))

     def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
-        reward.mul_(self.scale).add_(self.loc)
-        return reward
+        if self.standard_normal:
+            loc = self.loc
+            scale = self.scale
+            reward = (reward - loc) / scale
+            return reward
+        else:
+            scale = self.scale
+            loc = self.loc
+            reward = reward * scale + loc
+            return reward

     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
         if isinstance(reward_spec, UnboundedContinuousTensorSpec):
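
As a usage sketch of the new option: the snippet below applies RewardScaling with standard_normal=True to a TensorDict carrying a "reward" entry, mirroring the call pattern used in the test above. The import path for TensorDict is an assumption based on the torchrl layout around this commit and may differ in other versions; the loc/scale values and batch size are illustrative.

import torch
from torchrl.data import TensorDict  # assumed import path at the time of this commit
from torchrl.envs.transforms import RewardScaling

# Standardize rewards as (reward - loc) / scale instead of the default reward * scale + loc
transform = RewardScaling(loc=1.0, scale=5.0, standard_normal=True)

td = TensorDict({"reward": torch.randn(4, 1)}, batch_size=[4])
raw = td.get("reward").clone()

transform(td)  # acts on the "reward" key by default (in_keys=["reward"])

torch.testing.assert_close(td.get("reward"), (raw - 1.0) / 5.0)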
