Skip to content

Commit 61f1cc4

Browse files
authored
[BugFix] Deprecate tensordict.set check skips in transforms (#951)
1 parent ea9ff22 commit 61f1cc4

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

torchrl/envs/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,6 @@ def fake_tensordict(self) -> TensorDictBase:
830830
},
831831
batch_size=self.batch_size,
832832
device=self.device,
833-
_run_checks=True, # this method should not be run very often. This facilitates debugging
834833
)
835834
return fake_td
836835

torchrl/envs/transforms/transforms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2638,18 +2638,27 @@ def _update(self, key, value, N) -> torch.Tensor:
26382638
value_sum = _sum_left(value, _sum)
26392639
_sum *= self.decay
26402640
_sum += value_sum
2641-
self._td.set_(key + "_sum", _sum, no_check=True)
2641+
self._td.set_(
2642+
key + "_sum",
2643+
_sum,
2644+
)
26422645

26432646
_ssq = self._td.get(key + "_ssq")
26442647
value_ssq = _sum_left(value.pow(2), _ssq)
26452648
_ssq *= self.decay
26462649
_ssq += value_ssq
2647-
self._td.set_(key + "_ssq", _ssq, no_check=True)
2650+
self._td.set_(
2651+
key + "_ssq",
2652+
_ssq,
2653+
)
26482654

26492655
_count = self._td.get(key + "_count")
26502656
_count *= self.decay
26512657
_count += N
2652-
self._td.set_(key + "_count", _count, no_check=True)
2658+
self._td.set_(
2659+
key + "_count",
2660+
_count,
2661+
)
26532662

26542663
mean = _sum / _count
26552664
std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt()

torchrl/envs/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def step_mdp(
2727
exclude_reward: bool = True,
2828
exclude_done: bool = True,
2929
exclude_action: bool = True,
30-
_run_check: bool = True,
3130
) -> TensorDictBase:
3231
"""Creates a new tensordict that reflects a step in time of the input tensordict.
3332

0 commit comments

Comments
 (0)