Skip to content

Commit 3da4d6a

Browse files
[Features]: Keep actions and rewards across steps in rollout (#460)
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent e0e3cf3 commit 3da4d6a

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

test/mocking_classes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,21 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
109109

110110
def _step(self, tensordict):
111111
self.counter += 1
112-
n = torch.tensor([self.counter]).to(self.device).to(torch.get_default_dtype())
112+
n = torch.tensor(
113+
[self.counter], device=self.device, dtype=torch.get_default_dtype()
114+
)
113115
done = self.counter >= self.max_val
114116
done = torch.tensor([done], dtype=torch.bool, device=self.device)
115-
return TensorDict({"reward": n, "done": done, "next_observation": n}, [])
117+
return TensorDict(
118+
{"reward": n, "done": done, "next_observation": n.clone()}, []
119+
)
116120

117121
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
118122
self.max_val = max(self.counter + 100, self.counter * 2)
119123

120-
n = torch.tensor([self.counter]).to(self.device).to(torch.get_default_dtype())
124+
n = torch.tensor(
125+
[self.counter], device=self.device, dtype=torch.get_default_dtype()
126+
)
121127
done = self.counter >= self.max_val
122128
done = torch.tensor([done], dtype=torch.bool, device=self.device)
123129
return TensorDict({"done": done, "next_observation": n}, [])

torchrl/envs/common.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,11 @@ def __init__(
229229

230230
@classmethod
231231
def __new__(cls, *args, _batch_locked=True, **kwargs):
232-
cls._inplace_update = True
232+
# inplace update will write tensors in-place on the provided tensordict.
233+
# This is risky, especially if gradients need to be passed (in-place copy
234+
# for tensors that are part of computational graphs will result in an error).
235+
# It can also lead to inconsistencies when calling rollout.
236+
cls._inplace_update = False
233237
cls._batch_locked = _batch_locked
234238
return super().__new__(cls)
235239

@@ -552,7 +556,12 @@ def policy(td):
552556
break_when_any_done and tensordict.get("done").any()
553557
) or i == max_steps - 1:
554558
break
555-
tensordict = step_tensordict(tensordict, keep_other=True)
559+
tensordict = step_tensordict(
560+
tensordict,
561+
keep_other=True,
562+
exclude_reward=False,
563+
exclude_action=False,
564+
)
556565

557566
if callback is not None:
558567
callback(self, tensordict)

0 commit comments

Comments (0)