Skip to content

Commit 37e0c53

Browse files
authored
[Refactor] Minor refactorings to envs (#872)
1 parent 31a6db1 commit 37e0c53

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

torchrl/envs/common.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -376,12 +376,19 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
376376
# sanity check
377377
self._assert_tensordict_shape(tensordict)
378378

379-
tensordict.is_locked = True # make sure _step does not modify the tensordict
379+
tensordict.lock() # make sure _step does not modify the tensordict
380380
tensordict_out = self._step(tensordict)
381-
tensordict.is_locked = False
381+
if tensordict_out is tensordict:
382+
raise RuntimeError(
383+
"EnvBase._step should return outplace changes to the input "
384+
"tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
385+
"tensordict.select()) inside _step before writing new tensors onto this new instance."
386+
)
387+
tensordict.unlock()
388+
382389
obs_keys = set(self.observation_spec.keys())
383390
tensordict_out_select = tensordict_out.select(*obs_keys)
384-
tensordict_out = tensordict_out.exclude(*obs_keys)
391+
tensordict_out = tensordict_out.exclude(*obs_keys, inplace=True)
385392
tensordict_out.set("next", tensordict_out_select)
386393

387394
reward = tensordict_out.get("reward")
@@ -409,12 +416,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
409416
done = done.view(expected_done_shape)
410417
tensordict_out.set("done", done)
411418

412-
if tensordict_out is tensordict:
413-
raise RuntimeError(
414-
"EnvBase._step should return outplace changes to the input "
415-
"tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
416-
"tensordict.select()) inside _step before writing new tensors onto this new instance."
417-
)
418419
if self.run_type_checks:
419420
for key in self._select_observation_keys(tensordict_out):
420421
obs = tensordict_out.get(key)
@@ -432,7 +433,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
432433
)
433434
tensordict.update(tensordict_out, inplace=self._inplace_update)
434435

435-
del tensordict_out
436436
return tensordict
437437

438438
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -726,8 +726,6 @@ def _to_tensor(
726726
value = torch.as_tensor(value, device=device)
727727
else:
728728
value = value.to(device)
729-
# if dtype is not None:
730-
# value = value.to(dtype)
731729
return value
732730

733731
def close(self):

torchrl/envs/gym_like.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -177,7 +177,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
177177
action = tensordict.get("action")
178178
action_np = self.read_action(action)
179179

180-
reward = self.reward_spec.zero(self.batch_size)
180+
reward = self.reward_spec.zero()
181181
for _ in range(self.wrapper_frame_skip):
182182
obs, _reward, done, *info = self._output_transform(
183183
self._env.step(action_np)
@@ -200,7 +200,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
200200
)
201201

202202
if _reward is None:
203-
_reward = self.reward_spec.zero(self.batch_size)
203+
_reward = self.reward_spec.zero()
204204

205205
reward = self.read_reward(reward, _reward)
206206

0 commit comments

Comments
 (0)