@@ -1375,6 +1375,7 @@ def init_stats(
         reduce_dim: Union[int, Tuple[int]] = 0,
         cat_dim: Optional[int] = None,
         key: Optional[str] = None,
+        keep_dims: Optional[Tuple[int]] = None,
     ) -> None:
         """Initializes the loc and scale stats of the parent environment.
@@ -1394,6 +1395,10 @@ def init_stats(
             key (str, optional): if provided, the summary statistics will be
                 retrieved from that key in the resulting tensordicts.
                 Otherwise, the first key in :obj:`ObservationNorm.in_keys` will be used.
+            keep_dims (tuple of int, optional): the dimensions to keep in the loc and scale.
+                For instance, one may want the location and scale to have shape [C, 1, 1]
+                when normalizing a 3D tensor over the last two dimensions, but not the
+                third. Defaults to None.
 
         """
         if cat_dim is None:
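As a quick illustration of the shapes this argument documents (a pure-torch sketch, not `ObservationNorm` itself; the tensor names are made up):

```python
import torch

# Stats over a [C, H, W] observation, reduced over H and W but keeping
# those dimensions as size 1 -- the [C, 1, 1] case from the docstring.
data = torch.randn(3, 32, 32)
loc = data.mean(dim=(1, 2), keepdim=True)   # shape [3, 1, 1]
scale = data.std(dim=(1, 2), keepdim=True)  # shape [3, 1, 1]
assert loc.shape == scale.shape == (3, 1, 1)
```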
@@ -1440,12 +1445,23 @@ def raise_initialization_exception(module):
             data.append(tensordict.get(key))
 
         data = torch.cat(data, cat_dim)
-        loc = data.mean(reduce_dim)
-        scale = data.std(reduce_dim)
+        if isinstance(reduce_dim, int):
+            reduce_dim = [reduce_dim]
+        if keep_dims is not None:
+            if not all(k in reduce_dim for k in keep_dims):
+                raise ValueError("keep_dims elements must be part of reduce_dim list.")
+        else:
+            keep_dims = []
+        loc = data.mean(reduce_dim, keepdim=True)
+        scale = data.std(reduce_dim, keepdim=True)
+        for r in sorted(reduce_dim, reverse=True):
+            if r not in keep_dims:
+                loc = loc.squeeze(r)
+                scale = scale.squeeze(r)
 
         if not self.standard_normal:
-            loc = loc / scale
-            scale = 1 / scale
+            scale = 1 / scale.clamp_min(self.eps)
+            loc = -loc * scale
 
         if not torch.isfinite(loc).all():
             raise RuntimeError("Non-finite values found in loc")
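The `standard_normal=False` branch is a sign fix as well as an eps guard: assuming the transform applies `obs * scale + loc` in that mode, the old parameters produced `(obs + mean) / std`, while the rewritten ones recover `(obs - mean) / std`. A minimal check of the new algebra:

```python
import torch

eps = 1e-6  # stand-in for self.eps
obs = torch.randn(1000)
mean, std = obs.mean(), obs.std()

# New parameterization: scale = 1/std (clamped), loc = -mean * scale.
scale = 1 / std.clamp_min(eps)
loc = -mean * scale

# obs * scale + loc is then the standard z-score normalization.
torch.testing.assert_close(obs * scale + loc, (obs - mean) / std)
```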
@@ -2516,9 +2532,22 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Resets episode rewards."""
         # Non-batched environments
         if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
-            for out_key in self.out_keys:
+            for in_key, out_key in zip(self.in_keys, self.out_keys):
                 if out_key in tensordict.keys():
-                    tensordict[out_key] = 0.0
+                    tensordict[out_key] = torch.zeros_like(tensordict[out_key])
+                elif in_key == "reward":
+                    tensordict[out_key] = self.parent.reward_spec.zero()
+                else:
+                    try:
+                        tensordict[out_key] = self.parent.observation_spec[
+                            in_key
+                        ].zero()
+                    except KeyError as err:
+                        raise KeyError(
+                            f"The key {in_key} was not found in the parent "
+                            f"observation_spec with keys "
+                            f"{list(self.parent.observation_spec.keys())}."
+                        ) from err
 
         # Batched environments
         else:
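The non-batched branch now falls back to allocating zeros from the parent's specs when the accumulated key is absent. A rough stand-in for that fallback chain, using a plain dict and tensors in place of the tensordict and specs (all names here are illustrative):

```python
import torch

def reset_entry(td: dict, in_key: str, out_key: str) -> None:
    if out_key in td:
        # Key already tracked: zero it out-of-place, preserving shape/dtype.
        td[out_key] = torch.zeros_like(td[out_key])
    elif in_key == "reward":
        # Absent key: allocate zeros, as reward_spec.zero() would.
        td[out_key] = torch.zeros(1)
    else:
        raise KeyError(f"The key {in_key} was not found")

td = {"episode_reward": torch.ones(1)}
reset_entry(td, "reward", "episode_reward")
assert (td["episode_reward"] == 0).all()
```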
@@ -2530,9 +2559,27 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
                     device=tensordict.device,
                 ),
             )
-            for out_key in self.out_keys:
+            for in_key, out_key in zip(self.in_keys, self.out_keys):
                 if out_key in tensordict.keys():
-                    tensordict[out_key][_reset] = 0.0
+                    z = torch.zeros_like(tensordict[out_key])
+                    _reset = _reset.view_as(z)
+                    tensordict[out_key][_reset] = z[_reset]
+                elif in_key == "reward":
+                    # Since the episode reward is not in the tensordict, we need to allocate it
+                    # with zeros entirely (regardless of the _reset mask)
+                    z = self.parent.reward_spec.zero(self.parent.batch_size)
+                    tensordict[out_key] = z
+                else:
+                    try:
+                        tensordict[out_key] = self.parent.observation_spec[in_key].zero(
+                            self.parent.batch_size
+                        )
+                    except KeyError as err:
+                        raise KeyError(
+                            f"The key {in_key} was not found in the parent "
+                            f"observation_spec with keys "
+                            f"{list(self.parent.observation_spec.keys())}."
+                        ) from err
 
         return tensordict
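The batched branch only zeroes the rows flagged by `_reset`; the other environments keep their running sums. The masked pattern in isolation:

```python
import torch

episode_reward = torch.arange(4.0).reshape(4, 1)  # running sums per env
_reset = torch.tensor([True, False, True, False]).reshape(4, 1)

z = torch.zeros_like(episode_reward)
episode_reward[_reset] = z[_reset]  # zero only the reset environments
assert episode_reward.flatten().tolist() == [0.0, 1.0, 0.0, 3.0]
```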
@@ -2554,8 +2601,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
                             *tensordict.shape, 1, dtype=reward.dtype, device=reward.device
                         ),
                     )
-                tensordict[out_key] += reward
-
+                tensordict[out_key] = tensordict[out_key] + reward
         return tensordict
 
     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
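The last hunk trades the in-place `+=` for an out-of-place add. A sketch of the difference, assuming the stored tensor may be shared with another reference (e.g. a tensor also held by the input tensordict):

```python
import torch

reward = torch.ones(2, 1)

# In-place: += mutates the stored tensor, so the shared reference changes too.
shared = torch.zeros(2, 1)
store = {"episode_reward": shared}
store["episode_reward"] += reward
assert (shared == 1).all()

# Out-of-place: a fresh tensor is bound to the key; `shared` is untouched.
shared = torch.zeros(2, 1)
store = {"episode_reward": shared}
store["episode_reward"] = store["episode_reward"] + reward
assert (shared == 0).all()
```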