@@ -55,18 +55,24 @@ class VecNormV2(Transform):
        out_keys (Sequence[NestedKey] | None): The output keys for the normalized data. Defaults to `in_keys` if
            not provided.
        lock (mp.Lock, optional): A lock for thread safety.
-       stateful (bool, optional): Whether the `VecNorm` is stateful. Defaults to `True`.
+       stateful (bool, optional): Whether the `VecNorm` is stateful. Stateless versions of this
+           transform require the data to be carried within the input/output tensordicts.
+           Defaults to `True`.
        decay (float, optional): The decay rate for updating statistics. Defaults to `0.9999`.
+           If `decay=1` is used, the normalizing statistics have an infinite memory (each item is weighted
+           identically). Lower values weight recent data more heavily than older data.
        eps (float, optional): A small value to prevent division by zero. Defaults to `1e-4`.
-       shapes (list[torch.Size], optional): The shapes of the inputs. Defaults to `None`.
        shared_data (TensorDictBase | None, optional): Shared data for initialization. Defaults to `None`.
+       reduce_batch_dims (bool, optional): If `True`, the batch dimensions are reduced by averaging the data
+           before updating the statistics. This is useful when samples are received in batches, as it allows
+           the moving average to be computed over the entire batch rather than over individual elements. Note
+           that this option is only supported in stateful mode (`stateful=True`). Defaults to `False`.

    Attributes:
        stateful (bool): Indicates whether the VecNormV2 is stateful or stateless.
        lock (mp.Lock): A multiprocessing lock to ensure thread safety when updating statistics.
        decay (float): The decay rate for updating statistics.
        eps (float): A small value to prevent division by zero during normalization.
-       shapes (list[torch.Size]): The shapes of the inputs to be normalized.
        frozen (bool): Indicates whether the VecNormV2 is frozen, preventing updates to statistics.
        _cast_int_to_float (bool): Indicates whether integer inputs should be cast to float.
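
A quick way to see the two `decay` regimes documented above is a standalone sketch in plain
torch (illustrative only; it does not use VecNormV2 itself):

    import torch

    data = torch.randn(1000)

    # decay < 1: exponential moving average -- recent samples dominate
    decay = 0.9999
    loc = torch.zeros(())
    for x in data:
        loc.lerp_(x, 1 - decay)

    # decay == 1: infinite memory -- every sample weighted identically,
    # i.e. a plain running mean (the lerp weight shrinks as 1/count)
    loc_inf, count = torch.zeros(()), 0
    for x in data:
        count += 1
        loc_inf.lerp_(x, 1 / count)

    assert torch.allclose(loc_inf, data.mean(), atol=1e-4)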
@@ -99,6 +105,116 @@ class VecNormV2(Transform):

    .. seealso:: :class:`~torchrl.envs.transforms.VecNorm` for the first version of this transform.

+    Examples:
+        >>> import torch
+        >>> from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, SerialEnv, VecNormV2
+        >>>
+        >>> torch.manual_seed(0)
+        >>> env = GymEnv("Pendulum-v1")
+        >>> env_trsf = env.append_transform(
+        ...     VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
+        ... )
+        >>> r = env_trsf.rollout(10)
+        >>> print("Unnormalized rewards", r["next", "reward"])
+        Unnormalized rewards tensor([[ -1.7967],
+                [ -2.1238],
+                [ -2.5911],
+                [ -3.5275],
+                [ -4.8585],
+                [ -6.5028],
+                [ -8.2505],
+                [-10.3169],
+                [-12.1332],
+                [-13.1235]])
+        >>> print("Normalized rewards", r["next", "reward_norm"])
+        Normalized rewards tensor([[-1.6596e-04],
+                [-8.3072e-02],
+                [-1.9170e-01],
+                [-3.9255e-01],
+                [-5.9131e-01],
+                [-7.4671e-01],
+                [-8.3760e-01],
+                [-9.2058e-01],
+                [-9.3484e-01],
+                [-8.6185e-01]])
+        >>> # Aggregate values when using batched envs
+        >>> env = SerialEnv(2, [lambda: GymEnv("Pendulum-v1")] * 2)
+        >>> env_trsf = env.append_transform(
+        ...     VecNormV2(
+        ...         in_keys=["observation", "reward"],
+        ...         out_keys=["observation_norm", "reward_norm"],
+        ...         # Use reduce_batch_dims=True to aggregate values across batch elements
+        ...         reduce_batch_dims=True,
+        ...     )
+        ... )
+        >>> r = env_trsf.rollout(10)
+        >>> print("Unnormalized rewards", r["next", "reward"])
+        Unnormalized rewards tensor([[[-0.1456],
+                 [-0.1862],
+                 [-0.2053],
+                 [-0.2605],
+                 [-0.4046],
+                 [-0.5185],
+                 [-0.8023],
+                 [-1.1364],
+                 [-1.6183],
+                 [-2.5406]],
+
+                [[-0.0920],
+                 [-0.1492],
+                 [-0.2702],
+                 [-0.3917],
+                 [-0.5001],
+                 [-0.7947],
+                 [-1.0160],
+                 [-1.3347],
+                 [-1.9082],
+                 [-2.9679]]])
+        >>> print("Normalized rewards", r["next", "reward_norm"])
+        Normalized rewards tensor([[[-0.2199],
+                 [-0.2918],
+                 [-0.1668],
+                 [-0.2083],
+                 [-0.4981],
+                 [-0.5046],
+                 [-0.7950],
+                 [-0.9791],
+                 [-1.1484],
+                 [-1.4182]],
+
+                [[ 0.2201],
+                 [-0.0403],
+                 [-0.5206],
+                 [-0.7791],
+                 [-0.8282],
+                 [-1.2306],
+                 [-1.2279],
+                 [-1.2907],
+                 [-1.4929],
+                 [-1.7793]]])
+        >>> print("Loc / scale", env_trsf.transform.loc["reward"], env_trsf.transform.scale["reward"])
+        Loc / scale tensor([-0.8626]) tensor([1.1832])
+        >>>
+        >>> # Share values between workers
+        >>> def make_env():
+        ...     env = GymEnv("Pendulum-v1")
+        ...     env_trsf = env.append_transform(
+        ...         VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
+        ...     )
+        ...     return env_trsf
+        ...
+        >>> if __name__ == "__main__":
+        ...     # EnvCreator will share the loc/scale vals
+        ...     make_env = EnvCreator(make_env)
+        ...     # Create a local env to track the loc/scale
+        ...     local_env = make_env()
+        ...     env = ParallelEnv(2, [make_env] * 2)
+        ...     r = env.rollout(10)
+        ...     # Non-zero loc and scale testify that the sub-envs share their summary stats with us
+        ...     print("Remotely updated loc / scale", local_env.transform.loc["reward"], local_env.transform.scale["reward"])
+        ...     env.close()
+        Remotely updated loc / scale tensor([-0.4307]) tensor([0.9613])
+
    """

# TODO:
@@ -114,8 +230,8 @@ def __init__(
        stateful: bool = True,
        decay: float = 0.9999,
        eps: float = 1e-4,
-       shapes: list[torch.Size] = None,
        shared_data: TensorDictBase | None = None,
+       reduce_batch_dims: bool = False,
    ) -> None:
        self.stateful = stateful
        if lock is None:
@@ -126,7 +242,6 @@ def __init__(

        self.lock = lock
        self.decay = decay
-       self.shapes = shapes
        self.eps = eps
        self.frozen = False
        self._cast_int_to_float = False
@@ -145,6 +260,11 @@ def __init__(
        if shared_data:
            # FIXME
            raise NotImplementedError
+       if reduce_batch_dims and not stateful:
+           raise RuntimeError(
+               "reduce_batch_dims=True and stateful=False are not supported."
+           )
+       self.reduce_batch_dims = reduce_batch_dims

    @property
    def in_keys(self) -> Sequence[NestedKey]:
@@ -306,7 +426,9 @@ def _maybe_stateful_init(self, data):
            )
            data_select = data_select.update(data)
            data_select = data_select.select(*self._in_keys_safe, strict=True)
-
+           if self.reduce_batch_dims and data_select.ndim:
+               # collapse the batch-dims
+               data_select = data_select.mean(dim=tuple(range(data.ndim)))
            # For the count, we must use a TD because some keys (eg Reward) may be missing at some steps (eg, reset)
            # We use mean() to eliminate all dims - since it's local we don't need to expand the shape
            count = (
@@ -372,16 +494,33 @@ def _stateful_update(self, data):
        var = self._var
        loc = self._loc
        count = self._count
-       count += 1
        data = self._maybe_cast_to_float(data)
-       if self.decay != 1.0:
-           weight = 1 - self.decay
-           loc.lerp_(end=data, weight=weight)
-           var.lerp_(end=data.pow(2), weight=weight)
+       if self.reduce_batch_dims and data.ndim:
+           # The naive way to do this would be to convert the data to a list and iterate over it, but (1) that is
+           # slow, and (2) it makes the value of the loc/var depend on the order in which we iterate over the data.
+           # The second approach would be to average the data, but that would mean that having one vecnorm per
+           # batched env or one per sub-env would lead to different results, as a batch of N elements would
+           # actually be considered as a single one.
+           # What we go for instead is to average the data (and its squared value), then do the moving average
+           # with an adapted decay.
+           n = data.numel()
+           count += n
+           data2 = data.pow(2).mean(dim=tuple(range(data.ndim)))
+           data_mean = data.mean(dim=tuple(range(data.ndim)))
+           if self.decay != 1.0:
+               weight = 1 - self.decay**n
+           else:
+               weight = n / count
        else:
-           weight = 1 / count
-           loc.lerp_(end=data, weight=weight)
-           var.lerp_(end=data.pow(2), weight=weight)
+           count += 1
+           data2 = data.pow(2)
+           data_mean = data
+           if self.decay != 1.0:
+               weight = 1 - self.decay
+           else:
+               weight = 1 / count
+       loc.lerp_(end=data_mean, weight=weight)
+       var.lerp_(end=data2, weight=weight)

    def _maybe_stateless_init(self, data):
        if not self.initialized or f"{self.prefix}_loc" not in data.keys():
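
The adapted decay in `_stateful_update` above rests on the identity that n single-step lerps
toward a fixed value equal one lerp with weight `1 - decay**n`. A minimal check in plain torch
(illustrative names; not the library implementation):

    import torch

    decay = 0.99
    loc = torch.tensor([0.5])                   # current running statistic
    batch = torch.tensor([1.0, 2.0, 3.0, 4.0])  # a batch of n incoming samples
    n = batch.numel()
    m = batch.mean()

    # n sequential single-sample updates, each against the batch mean
    seq = loc.clone()
    for _ in range(n):
        seq.lerp_(m, 1 - decay)

    # one batched update with the adapted decay, as in the diff above
    batched = loc.clone()
    batched.lerp_(m, 1 - decay**n)

    assert torch.allclose(seq, batched)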