Skip to content

Commit b08e7ac

Browse files
author
Vincent Moens
committed
[Feature] VecNormV2: Usage with batched envs
ghstack-source-id: 5e14ed9 Pull Request resolved: #2901
1 parent c3310b8 commit b08e7ac

File tree

4 files changed

+177
-29
lines changed

4 files changed

+177
-29
lines changed

test/test_libs.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,6 @@ def get_gym_pixel_wrapper(): # noqa: F811
237237

238238
gym_version = version.parse(gym_version)
239239

240-
if _has_dmc:
241-
from dm_control import suite
242-
from dm_control.suite.wrappers import pixels
243-
244-
if _has_vmas:
245-
import vmas
246-
247240

248241
if _has_envpool:
249242
import envpool
@@ -1058,7 +1051,8 @@ def test_vecenvs_wrapper(self, envname):
10581051

10591052
with set_gym_backend("gymnasium"):
10601053
self._test_vecenvs_wrapper(
1061-
envname, kwargs={"reset_mode": gymnasium.vector.AutoresetMode.SAME_STEP}
1054+
envname,
1055+
kwargs={"autoreset_mode": gymnasium.vector.AutoresetMode.SAME_STEP},
10621056
)
10631057

10641058
@implement_for("gymnasium", None, "1.0.0")
@@ -1747,8 +1741,12 @@ def test_dmcontrol(self, env_name, task, frame_skip, from_pixels, pixels_only):
17471741
assert final_seed0 == final_seed1
17481742
assert_allclose_td(rollout0, rollout1)
17491743

1744+
from dm_control import suite
1745+
17501746
base_env = suite.load(env_name, task)
17511747
if from_pixels:
1748+
from dm_control.suite.wrappers import pixels
1749+
17521750
render_kwargs = {"camera_id": 0}
17531751
base_env = pixels.Wrapper(
17541752
base_env, pixels_only=pixels_only, render_kwargs=render_kwargs
@@ -2634,6 +2632,8 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents):
26342632
def test_vmas_spec_rollout(
26352633
self, scenario_name, num_envs, n_agents, continuous_actions
26362634
):
2635+
import vmas
2636+
26372637
vmas_env = VmasEnv(
26382638
scenario=scenario_name,
26392639
num_envs=num_envs,

test/test_transforms.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9520,12 +9520,17 @@ def _step(
95209520
def _set_seed(self, seed: int | None):
95219521
...
95229522

9523-
def test_vecnorm2_decay1(self):
9524-
env = self.SimpleEnv().append_transform(
9523+
@pytest.mark.parametrize("batched", [False, True])
9524+
def test_vecnorm2_decay1(self, batched):
9525+
env = self.SimpleEnv()
9526+
if batched:
9527+
env = SerialEnv(2, [lambda env=env: env] * 2)
9528+
env = env.append_transform(
95259529
VecNormV2(
95269530
in_keys=["reward", "observation"],
95279531
out_keys=["reward_norm", "obs_norm"],
95289532
decay=1,
9533+
reduce_batch_dims=True,
95299534
)
95309535
)
95319536
s_ = env.reset()
@@ -9537,21 +9542,25 @@ def test_vecnorm2_decay1(self):
95379542
sstack = torch.stack(ss)
95389543
if i >= 2:
95399544
for k in ("reward",):
9540-
loc = sstack[: i + 1]["next", k].mean(0)
9545+
loc = sstack[: i + 1]["next", k].mean().unsqueeze(-1)
95419546
scale = (
95429547
sstack[: i + 1]["next", k]
9543-
.std(0, unbiased=False)
9548+
.std(unbiased=False)
95449549
.clamp_min(1e-6)
9550+
.unsqueeze(-1)
95459551
)
95469552
# Assert that loc and scale match the expected values
95479553
torch.testing.assert_close(
95489554
loc,
95499555
env.transform.loc[k],
9550-
), f"Loc mismatch at step {i}"
9556+
)
95519557
torch.testing.assert_close(
95529558
scale,
95539559
env.transform.scale[k],
9554-
), f"Scale mismatch at step {i}"
9560+
)
9561+
if batched:
9562+
assert env.transform._loc.ndim == 0
9563+
assert env.transform._var.ndim == 0
95559564

95569565
@pytest.mark.skipif(not _has_gym, reason="gym not available")
95579566
@pytest.mark.parametrize("stateful", [True, False])

torchrl/envs/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ def ndim(self):
853853
def append_transform(
854854
self,
855855
transform: Transform | Callable[[TensorDictBase], TensorDictBase], # noqa: F821
856-
) -> EnvBase:
856+
) -> torchrl.envs.TransformedEnv: # noqa
857857
"""Returns a transformed environment where the callable/transform passed is applied.
858858
859859
Args:

torchrl/envs/transforms/vecnorm.py

Lines changed: 153 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,24 @@ class VecNormV2(Transform):
5555
out_keys (Sequence[NestedKey] | None): The output keys for the normalized data. Defaults to `in_keys` if
5656
not provided.
5757
lock (mp.Lock, optional): A lock for thread safety.
58-
stateful (bool, optional): Whether the `VecNorm` is stateful. Defaults to `True`.
58+
stateful (bool, optional): Whether the `VecNorm` is stateful. Stateless versions of this
59+
transform require the data to be carried within the input/output tensordicts.
60+
Defaults to `True`.
5961
decay (float, optional): The decay rate for updating statistics. Defaults to `0.9999`.
62+
If `decay=1` is used, the normalizing statistics have an infinite memory (each item is weighted
63+
identically). Lower values weight recent data more heavily than older data.
6064
eps (float, optional): A small value to prevent division by zero. Defaults to `1e-4`.
61-
shapes (list[torch.Size], optional): The shapes of the inputs. Defaults to `None`.
6265
shared_data (TensorDictBase | None, optional): Shared data for initialization. Defaults to `None`.
66+
reduce_batch_dims (bool, optional): If `True`, the batch dimensions are reduced by averaging the data
67+
before updating the statistics. This is useful when samples are received in batches, as it allows
68+
the moving average to be computed over the entire batch rather than individual elements. Note that
69+
this option is only supported in stateful mode (`stateful=True`). Defaults to `False`.
6370
6471
Attributes:
6572
stateful (bool): Indicates whether the VecNormV2 is stateful or stateless.
6673
lock (mp.Lock): A multiprocessing lock to ensure thread safety when updating statistics.
6774
decay (float): The decay rate for updating statistics.
6875
eps (float): A small value to prevent division by zero during normalization.
69-
shapes (list[torch.Size]): The shapes of the inputs to be normalized.
7076
frozen (bool): Indicates whether the VecNormV2 is frozen, preventing updates to statistics.
7177
_cast_int_to_float (bool): Indicates whether integer inputs should be cast to float.
7278
@@ -99,6 +105,116 @@ class VecNormV2(Transform):
99105
100106
.. seealso:: :class:`~torchrl.envs.transforms.VecNorm` for the first version of this transform.
101107
108+
Examples:
109+
>>> import torch
110+
>>> from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, SerialEnv, VecNormV2
111+
>>>
112+
>>> torch.manual_seed(0)
113+
>>> env = GymEnv("Pendulum-v1")
114+
>>> env_trsf = env.append_transform(
115+
>>> VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
116+
>>> )
117+
>>> r = env_trsf.rollout(10)
118+
>>> print("Unnormalized rewards", r["next", "reward"])
119+
Unnormalized rewards tensor([[ -1.7967],
120+
[ -2.1238],
121+
[ -2.5911],
122+
[ -3.5275],
123+
[ -4.8585],
124+
[ -6.5028],
125+
[ -8.2505],
126+
[-10.3169],
127+
[-12.1332],
128+
[-13.1235]])
129+
>>> print("Normalized rewards", r["next", "reward_norm"])
130+
Normalized rewards tensor([[-1.6596e-04],
131+
[-8.3072e-02],
132+
[-1.9170e-01],
133+
[-3.9255e-01],
134+
[-5.9131e-01],
135+
[-7.4671e-01],
136+
[-8.3760e-01],
137+
[-9.2058e-01],
138+
[-9.3484e-01],
139+
[-8.6185e-01]])
140+
>>> # Aggregate values when using batched envs
141+
>>> env = SerialEnv(2, [lambda: GymEnv("Pendulum-v1")] * 2)
142+
>>> env_trsf = env.append_transform(
143+
>>> VecNormV2(
144+
>>> in_keys=["observation", "reward"],
145+
>>> out_keys=["observation_norm", "reward_norm"],
146+
>>> # Use reduce_batch_dims=True to aggregate values across batch elements
147+
>>> reduce_batch_dims=True, )
148+
>>> )
149+
>>> r = env_trsf.rollout(10)
150+
>>> print("Unnormalized rewards", r["next", "reward"])
151+
Unnormalized rewards tensor([[[-0.1456],
152+
[-0.1862],
153+
[-0.2053],
154+
[-0.2605],
155+
[-0.4046],
156+
[-0.5185],
157+
[-0.8023],
158+
[-1.1364],
159+
[-1.6183],
160+
[-2.5406]],
161+
162+
[[-0.0920],
163+
[-0.1492],
164+
[-0.2702],
165+
[-0.3917],
166+
[-0.5001],
167+
[-0.7947],
168+
[-1.0160],
169+
[-1.3347],
170+
[-1.9082],
171+
[-2.9679]]])
172+
>>> print("Normalized rewards", r["next", "reward_norm"])
173+
Normalized rewards tensor([[[-0.2199],
174+
[-0.2918],
175+
[-0.1668],
176+
[-0.2083],
177+
[-0.4981],
178+
[-0.5046],
179+
[-0.7950],
180+
[-0.9791],
181+
[-1.1484],
182+
[-1.4182]],
183+
184+
[[ 0.2201],
185+
[-0.0403],
186+
[-0.5206],
187+
[-0.7791],
188+
[-0.8282],
189+
[-1.2306],
190+
[-1.2279],
191+
[-1.2907],
192+
[-1.4929],
193+
[-1.7793]]])
194+
>>> print("Loc / scale", env_trsf.transform.loc["reward"], env_trsf.transform.scale["reward"])
195+
Loc / scale tensor([-0.8626]) tensor([1.1832])
196+
>>>
197+
>>> # Share values between workers
198+
>>> def make_env():
199+
... env = GymEnv("Pendulum-v1")
200+
... env_trsf = env.append_transform(
201+
... VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
202+
... )
203+
... return env_trsf
204+
...
205+
...
206+
>>> if __name__ == "__main__":
207+
... # EnvCreator will share the loc/scale vals
208+
... make_env = EnvCreator(make_env)
209+
... # Create a local env to track the loc/scale
210+
... local_env = make_env()
211+
... env = ParallelEnv(2, [make_env] * 2)
212+
... r = env.rollout(10)
213+
... # Non-zero loc and scale testify that the sub-envs share their summary stats with us
214+
... print("Remotely updated loc / scale", local_env.transform.loc["reward"], local_env.transform.scale["reward"])
215+
Remotely updated loc / scale tensor([-0.4307]) tensor([0.9613])
216+
... env.close()
217+
102218
"""
103219

104220
# TODO:
@@ -114,8 +230,8 @@ def __init__(
114230
stateful: bool = True,
115231
decay: float = 0.9999,
116232
eps: float = 1e-4,
117-
shapes: list[torch.Size] = None,
118233
shared_data: TensorDictBase | None = None,
234+
reduce_batch_dims: bool = False,
119235
) -> None:
120236
self.stateful = stateful
121237
if lock is None:
@@ -126,7 +242,6 @@ def __init__(
126242

127243
self.lock = lock
128244
self.decay = decay
129-
self.shapes = shapes
130245
self.eps = eps
131246
self.frozen = False
132247
self._cast_int_to_float = False
@@ -145,6 +260,11 @@ def __init__(
145260
if shared_data:
146261
# FIXME
147262
raise NotImplementedError
263+
if reduce_batch_dims and not stateful:
264+
raise RuntimeError(
265+
"reduce_batch_dims=True and stateful=False are not supported."
266+
)
267+
self.reduce_batch_dims = reduce_batch_dims
148268

149269
@property
150270
def in_keys(self) -> Sequence[NestedKey]:
@@ -306,7 +426,9 @@ def _maybe_stateful_init(self, data):
306426
)
307427
data_select = data_select.update(data)
308428
data_select = data_select.select(*self._in_keys_safe, strict=True)
309-
429+
if self.reduce_batch_dims and data_select.ndim:
430+
# collapse the batch-dims
431+
data_select = data_select.mean(dim=tuple(range(data.ndim)))
310432
# For the count, we must use a TD because some keys (eg Reward) may be missing at some steps (eg, reset)
311433
# We use mean() to eliminate all dims - since it's local we don't need to expand the shape
312434
count = (
@@ -372,16 +494,33 @@ def _stateful_update(self, data):
372494
var = self._var
373495
loc = self._loc
374496
count = self._count
375-
count += 1
376497
data = self._maybe_cast_to_float(data)
377-
if self.decay != 1.0:
378-
weight = 1 - self.decay
379-
loc.lerp_(end=data, weight=weight)
380-
var.lerp_(end=data.pow(2), weight=weight)
498+
if self.reduce_batch_dims and data.ndim:
499+
# The naive way to do this would be to convert the data to a list and iterate over it, but (1) that is
500+
slow, and (2) it makes the value of the loc/var depend on the order in which we iterate over the data.
501+
# The second approach would be to average the data, but that would mean that having one vecnorm per batched
502+
# env or one per sub-env will lead to different results as a batch of N elements will actually be
503+
# considered as a single one.
504+
# What we go for instead is to average the data (and its squared value) then do the moving average with
505+
# adapted decay.
506+
n = data.numel()
507+
count += n
508+
data2 = data.pow(2).mean(dim=tuple(range(data.ndim)))
509+
data_mean = data.mean(dim=tuple(range(data.ndim)))
510+
if self.decay != 1.0:
511+
weight = 1 - self.decay**n
512+
else:
513+
weight = n / count
381514
else:
382-
weight = 1 / count
383-
loc.lerp_(end=data, weight=weight)
384-
var.lerp_(end=data.pow(2), weight=weight)
515+
count += 1
516+
data2 = data.pow(2)
517+
data_mean = data
518+
if self.decay != 1.0:
519+
weight = 1 - self.decay
520+
else:
521+
weight = 1 / count
522+
loc.lerp_(end=data_mean, weight=weight)
523+
var.lerp_(end=data2, weight=weight)
385524

386525
def _maybe_stateless_init(self, data):
387526
if not self.initialized or f"{self.prefix}_loc" not in data.keys():

0 commit comments

Comments
 (0)