Skip to content

Commit d918200

Browse files
romainjlnvmoens
andauthored
[Feature] Auto-compute stats for ObservationNorm (#669)
* Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent ac4b987 commit d918200

File tree

2 files changed

+152
-10
lines changed

2 files changed

+152
-10
lines changed

test/test_transforms.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,73 @@ def test_observationnorm(
905905
assert (observation_spec[key].space.minimum == loc).all()
906906
assert (observation_spec[key].space.maximum == scale + loc).all()
907907

908+
@pytest.mark.parametrize(
909+
"keys", [["next_observation"], ["next_observation", "next_pixel"]]
910+
)
911+
@pytest.mark.parametrize("size", [1, 3])
912+
@pytest.mark.parametrize("device", get_available_devices())
913+
@pytest.mark.parametrize("standard_normal", [True, False])
914+
def test_observationnorm_init_stats(self, keys, size, device, standard_normal):
915+
base_env = ContinuousActionVecMockEnv(
916+
observation_spec=CompositeSpec(
917+
next_observation=NdBoundedTensorSpec(
918+
minimum=1, maximum=1, shape=torch.Size([size])
919+
),
920+
next_observation_orig=NdBoundedTensorSpec(
921+
minimum=1, maximum=1, shape=torch.Size([size])
922+
),
923+
),
924+
action_spec=NdBoundedTensorSpec(
925+
minimum=1, maximum=1, shape=torch.Size((size,))
926+
),
927+
seed=0,
928+
)
929+
base_env.out_key = "observation"
930+
t_env = TransformedEnv(
931+
base_env,
932+
transform=ObservationNorm(in_keys=keys, standard_normal=standard_normal),
933+
)
934+
if len(keys) > 1:
935+
t_env.transform.init_stats(num_iter=11, key="next_observation")
936+
else:
937+
t_env.transform.init_stats(num_iter=11)
938+
939+
if standard_normal:
940+
torch.testing.assert_close(t_env.transform.loc, torch.Tensor([1.06] * size))
941+
torch.testing.assert_close(
942+
t_env.transform.scale, torch.Tensor([0.03316621] * size)
943+
)
944+
else:
945+
torch.testing.assert_close(
946+
t_env.transform.loc, torch.Tensor([31.960236] * size)
947+
)
948+
torch.testing.assert_close(
949+
t_env.transform.scale, torch.Tensor([30.151169] * size)
950+
)
951+
952+
def test_observationnorm_stats_already_initialized_error(self):
953+
transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1)
954+
955+
with pytest.raises(RuntimeError, match="Loc/Scale are already initialized"):
956+
transform.init_stats(num_iter=11)
957+
958+
def test_observationnorm_init_stats_multiple_keys_error(self):
959+
transform = ObservationNorm(in_keys=["next_observation", "next_pixels"])
960+
961+
err_msg = "Transform has multiple in_keys but no specific key was passed as an argument"
962+
with pytest.raises(RuntimeError, match=err_msg):
963+
transform.init_stats(num_iter=11)
964+
965+
def test_observationnorm_uninitialized_stats_error(self):
966+
transform = ObservationNorm(in_keys=["next_observation", "next_pixels"])
967+
968+
err_msg = (
969+
"Loc/Scale have not been initialized. Either pass in values in the constructor "
970+
"or call the init_stats method"
971+
)
972+
with pytest.raises(RuntimeError, match=err_msg):
973+
transform._apply_transform(torch.Tensor([1]))
974+
908975
def test_catframes_transform_observation_spec(self):
909976
N = 4
910977
key1 = "first key"

torchrl/envs/transforms/transforms.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import multiprocessing as mp
1010
from copy import deepcopy, copy
1111
from textwrap import indent
12-
from typing import Any, List, Optional, OrderedDict, Sequence, Union
12+
from typing import Any, List, Optional, OrderedDict, Sequence, Union, Tuple
1313
from warnings import warn
1414

1515
import torch
@@ -96,6 +96,9 @@ def __init__(
9696
out_keys_inv: Optional[Sequence[str]] = None,
9797
):
9898
super().__init__()
99+
if isinstance(in_keys, str):
100+
in_keys = [in_keys]
101+
99102
self.in_keys = in_keys
100103
if out_keys is None:
101104
out_keys = copy(self.in_keys)
@@ -1255,19 +1258,30 @@ class ObservationNorm(ObservationTransform):
12551258
>>> _ = transform(td)
12561259
>>> print(torch.isclose(td.get('next_obs').mean(0),
12571260
... torch.zeros(3)).all())
1258-
Tensor(True)
1261+
tensor(True)
12591262
>>> print(torch.isclose(td.get('next_obs').std(0),
12601263
... torch.ones(3)).all())
1261-
Tensor(True)
1264+
tensor(True)
1265+
1266+
The normalisation stats can be automatically computed:
1267+
Examples:
1268+
>>> from torchrl.envs.libs.gym import GymEnv
1269+
>>> torch.manual_seed(0)
1270+
>>> env = GymEnv("Pendulum-v1")
1271+
>>> env = TransformedEnv(env, ObservationNorm(in_keys=["observation"]))
1272+
>>> env.set_seed(0)
1273+
>>> env.transform.init_stats(100)
1274+
>>> print(env.transform.loc, env.transform.scale)
1275+
tensor([-1.3752e+01, -6.5087e-03, 2.9294e-03], dtype=torch.float32) tensor([14.9636, 2.5608, 0.6408], dtype=torch.float32)
12621276
12631277
"""
12641278

12651279
inplace = True
12661280

12671281
def __init__(
12681282
self,
1269-
loc: Union[float, torch.Tensor],
1270-
scale: Union[float, torch.Tensor],
1283+
loc: Optional[float, torch.Tensor] = None,
1284+
scale: Optional[float, torch.Tensor] = None,
12711285
in_keys: Optional[Sequence[str]] = None,
12721286
# observation_spec_key: =None,
12731287
standard_normal: bool = False,
@@ -1279,18 +1293,79 @@ def __init__(
12791293
"next_observation_state",
12801294
]
12811295
super().__init__(in_keys=in_keys)
1282-
if not isinstance(loc, torch.Tensor):
1296+
self.standard_normal = standard_normal
1297+
self.eps = 1e-6
1298+
1299+
if loc is not None and not isinstance(loc, torch.Tensor):
12831300
loc = torch.tensor(loc, dtype=torch.float)
1284-
if not isinstance(scale, torch.Tensor):
1301+
1302+
if scale is not None and not isinstance(scale, torch.Tensor):
12851303
scale = torch.tensor(scale, dtype=torch.float)
1304+
scale.clamp_min(self.eps)
12861305

12871306
# self.observation_spec_key = observation_spec_key
1288-
self.standard_normal = standard_normal
12891307
self.register_buffer("loc", loc)
1290-
eps = 1e-6
1291-
self.register_buffer("scale", scale.clamp_min(eps))
1308+
self.register_buffer("scale", scale)
1309+
1310+
def init_stats(
1311+
self,
1312+
num_iter: int,
1313+
reduce_dim: Union[int, Tuple[int]] = 0,
1314+
key: Optional[str] = None,
1315+
) -> None:
1316+
"""Initializes the loc and scale stats of the parent environment.
1317+
1318+
Normalization constant should ideally make the observation statistics approach
1319+
those of a standard Gaussian distribution. This method computes a location
1320+
and scale tensor that will empirically compute the mean and standard
1321+
deviation of a Gaussian distribution fitted on data generated randomly with
1322+
the parent environment for a given number of steps.
1323+
1324+
Args:
1325+
num_iter (int): number of random iterations to run in the environment.
1326+
reduce_dim (int, optional): dimension to compute the mean and std over.
1327+
Defaults to 0.
1328+
key (str, optional): if provided, the summary statistics will be
1329+
retrieved from that key in the resulting tensordicts.
1330+
Otherwise, the first key in :obj:`ObservationNorm.in_keys` will be used.
1331+
1332+
"""
1333+
if self.loc is not None or self.scale is not None:
1334+
raise RuntimeError(
1335+
f"Loc/Scale are already initialized: ({self.loc}, {self.scale})"
1336+
)
1337+
1338+
if len(self.in_keys) > 1 and key is None:
1339+
raise RuntimeError(
1340+
"Transform has multiple in_keys but no specific key was passed as an argument"
1341+
)
1342+
key = self.in_keys[0] if key is None else key
1343+
1344+
parent = self.parent
1345+
collected_frames = 0
1346+
data = []
1347+
while collected_frames < num_iter:
1348+
tensordict = parent.rollout(max_steps=num_iter)
1349+
collected_frames += tensordict.numel()
1350+
data.append(tensordict.get(key))
1351+
1352+
data = torch.cat(data, reduce_dim)
1353+
loc = data.mean(reduce_dim)
1354+
scale = data.std(reduce_dim)
1355+
1356+
if not self.standard_normal:
1357+
loc = loc / scale
1358+
scale = 1 / scale
1359+
1360+
self.register_buffer("loc", loc)
1361+
self.register_buffer("scale", scale.clamp_min(self.eps))
12921362

12931363
def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
1364+
if self.loc is None or self.scale is None:
1365+
raise RuntimeError(
1366+
"Loc/Scale have not been initialized. Either pass in values in the constructor "
1367+
"or call the init_stats method"
1368+
)
12941369
if self.standard_normal:
12951370
loc = self.loc
12961371
scale = self.scale

0 commit comments

Comments
 (0)