Skip to content

Commit 2d1723c

Browse files
authored
[Feature, BugFix] ObservationNorm keep_dims and RewardSum init (#839)
1 parent 9b11d25 commit 2d1723c

File tree

2 files changed

+109
-10
lines changed

2 files changed

+109
-10
lines changed

test/test_transforms.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,59 @@ def make_env():
11141114
t_env.transform.loc.device == t_env.observation_spec["observation"].device
11151115
)
11161116

1117+
@pytest.mark.parametrize("keys", [["pixels"], ["pixels", "stuff"]])
1118+
@pytest.mark.parametrize("size", [1, 3])
1119+
@pytest.mark.parametrize("device", get_available_devices())
1120+
@pytest.mark.parametrize("standard_normal", [True, False])
1121+
@pytest.mark.parametrize("parallel", [True, False])
1122+
def test_observationnorm_init_stats_pixels(
1123+
self, keys, size, device, standard_normal, parallel
1124+
):
1125+
def make_env():
1126+
base_env = DiscreteActionConvMockEnvNumpy(
1127+
seed=0,
1128+
)
1129+
base_env.out_key = "pixels"
1130+
return base_env
1131+
1132+
if parallel:
1133+
base_env = SerialEnv(3, make_env)
1134+
reduce_dim = (0, 1, 3, 4)
1135+
keep_dim = (3, 4)
1136+
cat_dim = 1
1137+
else:
1138+
base_env = make_env()
1139+
reduce_dim = (0, 2, 3)
1140+
keep_dim = (2, 3)
1141+
cat_dim = 0
1142+
1143+
t_env = TransformedEnv(
1144+
base_env,
1145+
transform=ObservationNorm(in_keys=keys, standard_normal=standard_normal),
1146+
)
1147+
if len(keys) > 1:
1148+
t_env.transform.init_stats(
1149+
num_iter=11,
1150+
key="pixels",
1151+
cat_dim=cat_dim,
1152+
reduce_dim=reduce_dim,
1153+
keep_dims=keep_dim,
1154+
)
1155+
else:
1156+
t_env.transform.init_stats(
1157+
num_iter=11,
1158+
reduce_dim=reduce_dim,
1159+
cat_dim=cat_dim,
1160+
keep_dims=keep_dim,
1161+
)
1162+
1163+
assert t_env.transform.loc.shape == torch.Size(
1164+
[t_env.observation_spec["pixels"].shape[0], 1, 1]
1165+
)
1166+
assert t_env.transform.scale.shape == torch.Size(
1167+
[t_env.observation_spec["pixels"].shape[0], 1, 1]
1168+
)
1169+
11171170
def test_observationnorm_stats_already_initialized_error(self):
11181171
transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1)
11191172

torchrl/envs/transforms/transforms.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,7 @@ def init_stats(
13751375
reduce_dim: Union[int, Tuple[int]] = 0,
13761376
cat_dim: Optional[int] = None,
13771377
key: Optional[str] = None,
1378+
keep_dims: Optional[Tuple[int]] = None,
13781379
) -> None:
13791380
"""Initializes the loc and scale stats of the parent environment.
13801381
@@ -1394,6 +1395,10 @@ def init_stats(
13941395
key (str, optional): if provided, the summary statistics will be
13951396
retrieved from that key in the resulting tensordicts.
13961397
Otherwise, the first key in :obj:`ObservationNorm.in_keys` will be used.
1398+
keep_dims (tuple of int, optional): the dimensions to keep in the loc and scale.
1399+
For instance, one may want the location and scale to have shape [C, 1, 1]
1400+
when normalizing a 3D tensor over the last two dimensions, but not the
1401+
third. Defaults to None.
13971402
13981403
"""
13991404
if cat_dim is None:
@@ -1440,12 +1445,23 @@ def raise_initialization_exception(module):
14401445
data.append(tensordict.get(key))
14411446

14421447
data = torch.cat(data, cat_dim)
1443-
loc = data.mean(reduce_dim)
1444-
scale = data.std(reduce_dim)
1448+
if isinstance(reduce_dim, int):
1449+
reduce_dim = [reduce_dim]
1450+
if keep_dims is not None:
1451+
if not all(k in reduce_dim for k in keep_dims):
1452+
raise ValueError("keep_dim elements must be part of reduce_dim list.")
1453+
else:
1454+
keep_dims = []
1455+
loc = data.mean(reduce_dim, keepdim=True)
1456+
scale = data.std(reduce_dim, keepdim=True)
1457+
for r in sorted(reduce_dim, reverse=True):
1458+
if r not in keep_dims:
1459+
loc = loc.squeeze(r)
1460+
scale = scale.squeeze(r)
14451461

14461462
if not self.standard_normal:
1447-
loc = loc / scale
1448-
scale = 1 / scale
1463+
scale = 1 / scale.clamp_min(self.eps)
1464+
loc = -loc * scale
14491465

14501466
if not torch.isfinite(loc).all():
14511467
raise RuntimeError("Non-finite values found in loc")
@@ -2516,9 +2532,22 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
25162532
"""Resets episode rewards."""
25172533
# Non-batched environments
25182534
if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
2519-
for out_key in self.out_keys:
2535+
for in_key, out_key in zip(self.in_keys, self.out_keys):
25202536
if out_key in tensordict.keys():
2521-
tensordict[out_key] = 0.0
2537+
tensordict[out_key] = torch.zeros_like(tensordict[out_key])
2538+
elif in_key == "reward":
2539+
tensordict[out_key] = self.parent.reward_spec.zero()
2540+
else:
2541+
try:
2542+
tensordict[out_key] = self.parent.observation_spec[
2543+
in_key
2544+
].zero()
2545+
except KeyError as err:
2546+
raise KeyError(
2547+
f"The key {in_key} was not found in the parent "
2548+
f"observation_spec with keys "
2549+
f"{list(self.parent.observation_spec.keys())}. "
2550+
) from err
25222551

25232552
# Batched environments
25242553
else:
@@ -2530,9 +2559,27 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
25302559
device=tensordict.device,
25312560
),
25322561
)
2533-
for out_key in self.out_keys:
2562+
for in_key, out_key in zip(self.in_keys, self.out_keys):
25342563
if out_key in tensordict.keys():
2535-
tensordict[out_key][_reset] = 0.0
2564+
z = torch.zeros_like(tensordict[out_key])
2565+
_reset = _reset.view_as(z)
2566+
tensordict[out_key][_reset] = z[_reset]
2567+
elif in_key == "reward":
2568+
# Since the episode reward is not in the tensordict, we need to allocate it
2569+
# with zeros entirely (regardless of the _reset mask)
2570+
z = self.parent.reward_spec.zero(self.parent.batch_size)
2571+
tensordict[out_key] = z
2572+
else:
2573+
try:
2574+
tensordict[out_key] = self.parent.observation_spec[in_key].zero(
2575+
self.parent.batch_size
2576+
)
2577+
except KeyError as err:
2578+
raise KeyError(
2579+
f"The key {in_key} was not found in the parent "
2580+
f"observation_spec with keys "
2581+
f"{list(self.parent.observation_spec.keys())}. "
2582+
) from err
25362583

25372584
return tensordict
25382585

@@ -2554,8 +2601,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
25542601
*tensordict.shape, 1, dtype=reward.dtype, device=reward.device
25552602
),
25562603
)
2557-
tensordict[out_key] += reward
2558-
2604+
tensordict[out_key] = tensordict[out_key] + reward
25592605
return tensordict
25602606

25612607
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:

0 commit comments

Comments (0)