Skip to content

Commit bd1ccd7

Browse files
committed
amend
1 parent e67a1c7 commit bd1ccd7

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6523,17 +6523,15 @@ def _reset_func(
65236523
if self.single_default_value and callable(self.default_value):
65246524
if not _reset.all():
65256525
# FIXME: use masked op
6526-
tensordict_reset = tensordict_reset.clone()
6526+
# tensordict_reset = tensordict_reset.clone()
65276527
reset_val = self.default_value(reset=_reset)
65286528
# This is safe because env.reset calls _update_during_reset which will discard the new data
65296529
tensordict_reset = (
65306530
self.container.full_observation_spec.zero().select(
65316531
*reset_val.keys(True)
65326532
)
65336533
)
6534-
tensordict_reset = tensordict_reset.where(
6535-
_reset, reset_val, update_batch_size=True
6536-
)
6534+
tensordict_reset = torch.where(_reset, reset_val, 0)
65376535
else:
65386536
resets = self.default_value(reset=_reset)
65396537
tensordict_reset.update(resets)

0 commit comments

Comments
 (0)