Skip to content

Commit 3b64392

Browse files
authored
[BugFix] Fix zero-ing from specs in RewardSum (#860)
1 parent 6bebea7 commit 3b64392

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/test_rb.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -830,6 +830,8 @@ def test_smoke_replay_buffer_transform(transform):
830830

831831
@pytest.mark.parametrize("transform", transforms)
832832
def test_smoke_replay_buffer_transform_no_inkeys(transform):
833+
if PinMemoryTransform is PinMemoryTransform and not torch.cuda.is_available():
834+
raise pytest.skip("No CUDA device detected, skipping PinMemory")
833835
rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=transform())
834836

835837
td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])

torchrl/envs/transforms/transforms.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2632,19 +2632,18 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
26322632
)
26332633
for in_key, out_key in zip(self.in_keys, self.out_keys):
26342634
if out_key in tensordict.keys():
2635-
z = torch.zeros_like(tensordict[out_key])
2636-
_reset = _reset.view_as(z)
2637-
tensordict[out_key][_reset] = z[_reset]
2635+
value = tensordict[out_key]
2636+
dtype = value.dtype
2637+
tensordict[out_key] = value * (~_reset).to(dtype)
26382638
elif in_key == "reward":
26392639
# Since the episode reward is not in the tensordict, we need to allocate it
26402640
# with zeros entirely (regardless of the _reset mask)
2641-
z = self.parent.reward_spec.zero(self.parent.batch_size)
2642-
tensordict[out_key] = z
2641+
tensordict[out_key] = self.parent.reward_spec.zero()
26432642
else:
26442643
try:
2645-
tensordict[out_key] = self.parent.observation_spec[in_key].zero(
2646-
self.parent.batch_size
2647-
)
2644+
tensordict[out_key] = self.parent.observation_spec[
2645+
in_key
2646+
].zero()
26482647
except KeyError as err:
26492648
raise KeyError(
26502649
f"The key {in_key} was not found in the parent "

0 commit comments

Comments
 (0)