Skip to content

Commit 2998610

Browse files
committed
amend
1 parent bd1ccd7 commit 2998610

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6526,12 +6526,15 @@ def _reset_func(
65266526
# tensordict_reset = tensordict_reset.clone()
65276527
reset_val = self.default_value(reset=_reset)
65286528
# This is safe because env.reset calls _update_during_reset which will discard the new data
6529-
tensordict_reset = (
6530-
self.container.full_observation_spec.zero().select(
6531-
*reset_val.keys(True)
6532-
)
6533-
)
6534-
tensordict_reset = torch.where(_reset, reset_val, 0)
6529+
# tensordict_reset = (
6530+
# self.container.full_observation_spec.zero().select(
6531+
# *reset_val.keys(True)
6532+
# )
6533+
# )
6534+
tensordict_reset = _reset.new_zeros(_reset.shape)
6535+
print(f"tensordict_reset: {tensordict_reset}")
6536+
print(f"reset_val: {reset_val}")
6537+
tensordict_reset[_reset] = reset_val
65356538
else:
65366539
resets = self.default_value(reset=_reset)
65376540
tensordict_reset.update(resets)

0 commit comments

Comments
 (0)