 import multiprocessing as mp
 from copy import deepcopy, copy
 from textwrap import indent
-from typing import Any, List, Optional, OrderedDict, Sequence, Union
+from typing import Any, List, Optional, OrderedDict, Sequence, Union, Tuple
 from warnings import warn

 import torch
@@ -96,6 +96,9 @@ def __init__(
         out_keys_inv: Optional[Sequence[str]] = None,
     ):
         super().__init__()
+        if isinstance(in_keys, str):
+            in_keys = [in_keys]
+
         self.in_keys = in_keys
         if out_keys is None:
             out_keys = copy(self.in_keys)
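A note on the hunk above: the base `Transform.__init__` now wraps a bare string `in_keys` into a one-element list. A minimal sketch of the resulting behaviour (the key name and the `ObservationNorm` import path are illustrative assumptions):

    from torchrl.envs.transforms import ObservationNorm

    # Both forms are now accepted; previously only the list form was.
    t1 = ObservationNorm(loc=0.0, scale=1.0, in_keys="next_observation")
    t2 = ObservationNorm(loc=0.0, scale=1.0, in_keys=["next_observation"])
    assert t1.in_keys == t2.in_keys == ["next_observation"]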
@@ -1255,19 +1258,30 @@ class ObservationNorm(ObservationTransform):
         >>> _ = transform(td)
         >>> print(torch.isclose(td.get('next_obs').mean(0),
         ...     torch.zeros(3)).all())
-        Tensor(True)
+        tensor(True)
         >>> print(torch.isclose(td.get('next_obs').std(0),
         ...     torch.ones(3)).all())
-        Tensor(True)
+        tensor(True)
+
+    The normalisation stats can be automatically computed:
+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> torch.manual_seed(0)
+        >>> env = GymEnv("Pendulum-v1")
+        >>> env = TransformedEnv(env, ObservationNorm(in_keys=["observation"]))
+        >>> env.set_seed(0)
+        >>> env.transform.init_stats(100)
+        >>> print(env.transform.loc, env.transform.scale)
+        tensor([-1.3752e+01, -6.5087e-03, 2.9294e-03], dtype=torch.float32) tensor([14.9636, 2.5608, 0.6408], dtype=torch.float32)

     """

     inplace = True

     def __init__(
         self,
-        loc: Union[float, torch.Tensor],
-        scale: Union[float, torch.Tensor],
+        loc: Optional[Union[float, torch.Tensor]] = None,
+        scale: Optional[Union[float, torch.Tensor]] = None,
         in_keys: Optional[Sequence[str]] = None,
         # observation_spec_key: =None,
         standard_normal: bool = False,
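With `loc` and `scale` now optional, the constructor supports two usage patterns; a sketch with illustrative values and key names:

    # explicit stats, as before
    t = ObservationNorm(loc=1.5, scale=0.5, in_keys=["next_observation"])
    # deferred stats: loc/scale stay None until init_stats() is called
    # on a transform attached to a parent environment (see the next hunk)
    t = ObservationNorm(in_keys=["next_observation"])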
@@ -1279,18 +1293,79 @@ def __init__(
                 "next_observation_state",
             ]
         super().__init__(in_keys=in_keys)
-        if not isinstance(loc, torch.Tensor):
+        self.standard_normal = standard_normal
+        self.eps = 1e-6
+
+        if loc is not None and not isinstance(loc, torch.Tensor):
             loc = torch.tensor(loc, dtype=torch.float)
-        if not isinstance(scale, torch.Tensor):
+
+        if scale is not None and not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale, dtype=torch.float)
+            scale = scale.clamp_min(self.eps)

         # self.observation_spec_key = observation_spec_key
-        self.standard_normal = standard_normal
         self.register_buffer("loc", loc)
-        eps = 1e-6
-        self.register_buffer("scale", scale.clamp_min(eps))
+        self.register_buffer("scale", scale)
+
+    def init_stats(
+        self,
+        num_iter: int,
+        reduce_dim: Union[int, Tuple[int]] = 0,
+        key: Optional[str] = None,
+    ) -> None:
+        """Initializes the loc and scale stats of the parent environment.
+
+        The normalization constants should ideally make the observation statistics
+        approach those of a standard Gaussian distribution. This method computes a
+        location and scale tensor from the empirical mean and standard deviation of
+        data generated by random rollouts of the parent environment for a given
+        number of steps.
+
+        Args:
+            num_iter (int): number of random iterations to run in the environment.
+            reduce_dim (int, optional): dimension to compute the mean and std over.
+                Defaults to 0.
+            key (str, optional): if provided, the summary statistics will be
+                retrieved from that key in the resulting tensordicts.
+                Otherwise, the first key in :obj:`ObservationNorm.in_keys` will be used.
+
+        """
+        if self.loc is not None or self.scale is not None:
+            raise RuntimeError(
+                f"Loc/Scale are already initialized: ({self.loc}, {self.scale})"
+            )
+
+        if len(self.in_keys) > 1 and key is None:
+            raise RuntimeError(
+                "Transform has multiple in_keys but no specific key was passed as an argument"
+            )
+        key = self.in_keys[0] if key is None else key
+
+        # Collect rollouts from the parent environment until enough frames are gathered.
+        parent = self.parent
+        collected_frames = 0
+        data = []
+        while collected_frames < num_iter:
+            tensordict = parent.rollout(max_steps=num_iter)
+            collected_frames += tensordict.numel()
+            data.append(tensordict.get(key))
+
+        data = torch.cat(data, reduce_dim)
+        loc = data.mean(reduce_dim)
+        scale = data.std(reduce_dim)
+
+        # Convert the empirical mean/std into the (loc, scale) parametrization of the transform.
+        if not self.standard_normal:
+            loc = loc / scale
+            scale = 1 / scale
+
+        self.register_buffer("loc", loc)
+        self.register_buffer("scale", scale.clamp_min(self.eps))

     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+        if self.loc is None or self.scale is None:
+            raise RuntimeError(
+                "Loc/Scale have not been initialized. Either pass in values in the constructor "
+                "or call the init_stats method"
+            )
         if self.standard_normal:
             loc = self.loc
             scale = self.scale
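Taken together, these hunks make `loc`/`scale` optional and let them be filled in from random rollouts. A minimal end-to-end sketch, under the same assumptions as the docstring example above (a Gym backend with `Pendulum-v1` installed; the `TransformedEnv`/`ObservationNorm` import paths are assumed):

    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import ObservationNorm
    from torchrl.envs.libs.gym import GymEnv

    env = TransformedEnv(GymEnv("Pendulum-v1"), ObservationNorm(in_keys=["observation"]))

    # Using the env before init_stats (and without loc/scale in the constructor)
    # would trigger the RuntimeError added in _apply_transform above.
    env.transform.init_stats(num_iter=100)  # fills the loc/scale buffers from random rollouts
    print(env.transform.loc, env.transform.scale)

    # A second call to init_stats raises "Loc/Scale are already initialized".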