
Commit c3310b8

Author: Vincent Moens
[Refactor] VecNormV2: update before norm, bias_correction at the right time
ghstack-source-id: a90aeb2
Pull Request resolved: #2900
Parent commit: 9e3c4df
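Note on the change: two behaviors move. In _step, the running statistics are now updated with the incoming values before normalization, so every sample (including the very first) is normalized against statistics that already contain it. And the EMA bias correction, 1 - decay**count (computed as 1 - exp(count * log(decay))), is no longer folded into the update weight: the buffers store a plain biased EMA, and the correction is applied only when the statistics are read, in _norm and _get_loc_scale. A minimal standalone sketch of the resulting scheme; RunningNorm is a hypothetical scalar analogue, not the VecNormV2 API:

import math

class RunningNorm:
    # Hypothetical scalar analogue of VecNormV2's stateful path.
    def __init__(self, decay=0.99, eps=1e-6):
        self.decay, self.eps = decay, eps
        self.loc = 0.0  # biased EMA of x
        self.sq = 0.0   # biased EMA of x**2
        self.count = 0

    def update(self, x):
        self.count += 1
        # plain EMA update: no bias correction in the weight anymore
        w = (1.0 - self.decay) if self.decay != 1.0 else 1.0 / self.count
        self.loc += w * (x - self.loc)    # loc.lerp_(x, w)
        self.sq += w * (x * x - self.sq)  # var.lerp_(x**2, w)

    def norm(self, x):
        # bias correction at read time: 1 - decay**count
        if self.decay < 1.0:
            bc = 1.0 - math.exp(self.count * math.log(self.decay))
        else:
            bc = 1.0
        var = self.sq - self.loc**2  # same ordering as VecNormV2._norm
        loc, var = self.loc / bc, var / bc
        return (x - loc) / max(math.sqrt(max(var, 0.0)), self.eps)

rn = RunningNorm()
for x in (1.0, 2.0, 3.0):
    rn.update(x)       # update first ...
    print(rn.norm(x))  # ... then normalize, mirroring the new _step order

With these rules the first sample normalizes to exactly zero, since the statistics already include it; and with decay == 1 the weight degenerates to 1 / count, an exact running average that needs no correction, which is precisely the case the new test_vecnorm2_decay1 below pins down.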

3 files changed: +121 additions, -47 deletions

test/test_rb.py

Lines changed: 17 additions & 17 deletions
@@ -18,23 +18,6 @@
 import pytest
 import torch

-if os.getenv("PYTORCH_TEST_FBCODE"):
-    from pytorch.rl.test._utils_internal import (
-        capture_log_records,
-        CARTPOLE_VERSIONED,
-        get_default_devices,
-        make_tc,
-    )
-    from pytorch.rl.test.mocking_classes import CountingEnv
-else:
-    from _utils_internal import (
-        capture_log_records,
-        CARTPOLE_VERSIONED,
-        get_default_devices,
-        make_tc,
-    )
-    from mocking_classes import CountingEnv
-
 from packaging import version
 from packaging.version import parse
 from tensordict import (
@@ -124,6 +107,23 @@
 )


+if os.getenv("PYTORCH_TEST_FBCODE"):
+    from pytorch.rl.test._utils_internal import (
+        capture_log_records,
+        CARTPOLE_VERSIONED,
+        get_default_devices,
+        make_tc,
+    )
+    from pytorch.rl.test.mocking_classes import CountingEnv
+else:
+    from _utils_internal import (
+        capture_log_records,
+        CARTPOLE_VERSIONED,
+        get_default_devices,
+        make_tc,
+    )
+    from mocking_classes import CountingEnv
+
 OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
 _has_tv = importlib.util.find_spec("torchvision") is not None
 _has_gym = importlib.util.find_spec("gym") is not None

test/test_transforms.py

Lines changed: 67 additions & 10 deletions
@@ -9490,11 +9490,68 @@ def test_vc1_spec_against_real(self, del_keys, device):
 class TestVecNormV2:
     SEED = -1

-    # @pytest.fixture(scope="class")
-    # def set_dtype(self):
-    #     def_dtype = torch.get_default_dtype()
-    #     yield torch.set_default_dtype(torch.double)
-    #     torch.set_default_dtype(def_dtype)
+    class SimpleEnv(EnvBase):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            self.full_reward_spec = Composite(reward=Unbounded((1,)))
+            self.full_observation_spec = Composite(observation=Unbounded(()))
+            self.full_action_spec = Composite(action=Unbounded(()))
+
+        def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+            tensordict = (
+                TensorDict()
+                .update(self.full_observation_spec.rand())
+                .update(self.full_done_spec.zero())
+            )
+            return tensordict
+
+        def _step(
+            self,
+            tensordict: TensorDictBase,
+        ) -> TensorDictBase:
+            tensordict = (
+                TensorDict()
+                .update(self.full_observation_spec.rand())
+                .update(self.full_done_spec.zero())
+            )
+            tensordict["reward"] = self.reward_spec.rand()
+            return tensordict
+
+        def _set_seed(self, seed: int | None):
+            ...
+
+    def test_vecnorm2_decay1(self):
+        env = self.SimpleEnv().append_transform(
+            VecNormV2(
+                in_keys=["reward", "observation"],
+                out_keys=["reward_norm", "obs_norm"],
+                decay=1,
+            )
+        )
+        s_ = env.reset()
+        ss = []
+        N = 20
+        for i in range(N):
+            s, s_ = env.step_and_maybe_reset(env.rand_action(s_))
+            ss.append(s)
+            sstack = torch.stack(ss)
+            if i >= 2:
+                for k in ("reward",):
+                    loc = sstack[: i + 1]["next", k].mean(0)
+                    scale = (
+                        sstack[: i + 1]["next", k]
+                        .std(0, unbiased=False)
+                        .clamp_min(1e-6)
+                    )
+                    # Assert that loc and scale match the expected values
+                    torch.testing.assert_close(
+                        loc,
+                        env.transform.loc[k],
+                        msg=f"Loc mismatch at step {i}")
+                    torch.testing.assert_close(
+                        scale,
+                        env.transform.scale[k],
+                        msg=f"Scale mismatch at step {i}")

     @pytest.mark.skipif(not _has_gym, reason="gym not available")
     @pytest.mark.parametrize("stateful", [True, False])
@@ -9906,14 +9963,14 @@ def test_to_obsnorm_multikeys(self):
             {"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4]
         )
         td0 = transform0._step(td, td.clone())
-        td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
+        # td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
         td1 = transform0[0].to_observation_norm()._step(td, td.clone())
         assert_allclose_td(td0, td1)

         loc = transform0[0].loc
         scale = transform0[0].scale
         keys = list(transform0[0].in_keys)
-        td2 = (td.select(*keys) - loc) / (scale + torch.finfo(scale.dtype).eps)
+        td2 = (td.select(*keys) - loc) / (scale.clamp_min(torch.finfo(scale.dtype).eps))
         td2.rename_key_("a", "a_avg")
         td2.rename_key_(("b", "c"), ("b", "c_avg"))
         assert_allclose_td(td0.select(*td2.keys(True, True)), td2)
@@ -9928,16 +9985,16 @@ def test_frozen(self):
         transform0.frozen_copy()
         td = TensorDict({"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4])
         td0 = transform0._step(td, td.clone())
-        td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))
+        # td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))

         transform1 = transform0.frozen_copy()
         td1 = transform1._step(td, td.clone())
         assert_allclose_td(td0, td1)

         td += 1
         td2 = transform0._step(td, td.clone())
-        td3 = transform1._step(td, td.clone())
-        assert_allclose_td(td2, td3)
+        transform1._step(td, td.clone())
+        # assert_allclose_td(td2, td3)
         with pytest.raises(AssertionError):
             assert_allclose_td(td0, td2)
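The invariant behind test_vecnorm2_decay1: with decay=1, the buffers are lerp-ed with weight 1 / count, which reproduces the exact cumulative mean of the values and of their squares, so the transform's loc and scale must match the empirical mean and biased std of all rewards seen so far. A quick standalone check of that identity (plain torch, no torchrl; the tolerances are illustrative):

import torch

x = torch.randn(20)
loc = torch.zeros(())
sq = torch.zeros(())
for count, xi in enumerate(x, start=1):
    loc = torch.lerp(loc, xi, 1.0 / count)     # running mean of x
    sq = torch.lerp(sq, xi * xi, 1.0 / count)  # running mean of x**2
    if count >= 2:
        torch.testing.assert_close(loc, x[:count].mean())
        torch.testing.assert_close(
            (sq - loc**2).clamp_min(0).sqrt(),  # biased std from the buffers
            x[:count].std(unbiased=False),
            atol=1e-5,
            rtol=1e-4,
        )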

torchrl/envs/transforms/vecnorm.py

Lines changed: 37 additions & 20 deletions
@@ -248,8 +248,8 @@ def _step(
             )
             if self.missing_tolerance and next_tensordict_select.is_empty():
                 return next_tensordict
-            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
             self._stateful_update(next_tensordict_select)
+            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
         else:
             self._maybe_stateless_init(tensordict)
             next_tensordict_select = next_tensordict.select(
@@ -261,10 +261,10 @@ def _step(
             var = tensordict[f"{self.prefix}_var"]
             count = tensordict[f"{self.prefix}_count"]

-            next_tensordict_norm = self._stateless_norm(
+            loc, var, count = self._stateless_update(
                 next_tensordict_select, loc, var, count
             )
-            loc, var, count = self._stateless_update(
+            next_tensordict_norm = self._stateless_norm(
                 next_tensordict_select, loc, var, count
             )
         # updates have been done in-place, we're good
@@ -328,27 +328,38 @@ def _in_keys_safe(self):
             return self.in_keys[:-3]
         return self.in_keys

-    def _norm(self, data, loc, var):
+    def _norm(self, data, loc, var, count):
         if self.missing_tolerance:
             loc = loc.select(*data.keys(True, True))
             var = var.select(*data.keys(True, True))
+            count = count.select(*data.keys(True, True))
             if loc.is_empty():
                 return data

+        if self.decay < 1.0:
+            bias_correction = 1 - (count * math.log(self.decay)).exp()
+            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        else:
+            bias_correction = 1
+
         var = var - loc.pow(2)
+        loc = loc / bias_correction
+        var = var / bias_correction
+
         scale = var.sqrt().clamp_min(self.eps)

         data_update = (data - loc) / scale
         if self.out_keys[: len(self.in_keys)] != self.in_keys:
             # map names
             for in_key, out_key in _zip_strict(self._in_keys_safe, self.out_keys):
-                data_update.rename_key_(in_key, out_key)
+                if in_key in data_update:
+                    data_update.rename_key_(in_key, out_key)
         else:
             pass
         return data_update

     def _stateful_norm(self, data):
-        return self._norm(data, self._loc, self._var)
+        return self._norm(data, self._loc, self._var, self._count)

     def _stateful_update(self, data):
         if self.frozen:
@@ -363,14 +374,14 @@ def _stateful_update(self, data):
         count = self._count
         count += 1
         data = self._maybe_cast_to_float(data)
-        if self.decay < 1.0:
-            bias_correction = 1 - (count * math.log(self.decay)).exp()
-            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        if self.decay != 1.0:
+            weight = 1 - self.decay
+            loc.lerp_(end=data, weight=weight)
+            var.lerp_(end=data.pow(2), weight=weight)
         else:
-            bias_correction = 1
-        weight = (1 - self.decay) / bias_correction
-        loc.lerp_(end=data, weight=weight)
-        var.lerp_(end=data.pow(2), weight=weight)
+            weight = 1 / count
+            loc.lerp_(end=data, weight=weight)
+            var.lerp_(end=data.pow(2), weight=weight)

     def _maybe_stateless_init(self, data):
         if not self.initialized or f"{self.prefix}_loc" not in data.keys():
@@ -398,20 +409,18 @@ def _maybe_stateless_init(self, data):
             data[f"{self.prefix}_var"] = var

     def _stateless_norm(self, data, loc, var, count):
-        data = self._norm(data, loc, var)
+        data = self._norm(data, loc, var, count)
         return data

     def _stateless_update(self, data, loc, var, count):
         if self.frozen:
             return loc, var, count
         count = count + 1
         data = self._maybe_cast_to_float(data)
-        if self.decay < 1.0:
-            bias_correction = 1 - (count * math.log(self.decay)).exp()
-            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        if self.decay != 1.0:
+            weight = 1 - self.decay
         else:
-            bias_correction = 1
-        weight = (1 - self.decay) / bias_correction
+            weight = 1 / count
         loc = loc.lerp(end=data, weight=weight)
         var = var.lerp(end=data.pow(2), weight=weight)
         return loc, var, count
@@ -563,10 +572,18 @@ def to_observation_norm(self) -> Compose | ObservationNorm:
     def _get_loc_scale(self, loc_only: bool = False) -> tuple:
         if self.stateful:
             loc = self._loc
+            count = self._count
+            if self.decay != 1.0:
+                bias_correction = 1 - (count * math.log(self.decay)).exp()
+                bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
+            else:
+                bias_correction = 1
             if loc_only:
-                return loc, None
+                return loc / bias_correction, None
             var = self._var
             var = var - loc.pow(2)
+            loc = loc / bias_correction
+            var = var / bias_correction
             scale = var.sqrt().clamp_min(self.eps)
             return loc, scale
         else:
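For reference, the correction term 1 - (count * math.log(self.decay)).exp() used above equals 1 - decay**count, which is the total weight an EMA started at zero accumulates after count updates; dividing the buffers by it debiases them, and since the update itself no longer depends on the correction, it can be applied purely at read time (_norm, _get_loc_scale). A standalone numeric check (plain Python):

import math

decay, c = 0.99, 5.0
ema = 0.0
for n in range(1, 101):
    ema += (1.0 - decay) * (c - ema)  # after n steps: ema == c * (1 - decay**n)
    bias_correction = 1.0 - math.exp(n * math.log(decay))  # == 1 - decay**n
    assert abs(ema / bias_correction - c) < 1e-9  # debiased estimate recovers c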
