@@ -376,12 +376,19 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
376
376
# sanity check
377
377
self ._assert_tensordict_shape (tensordict )
378
378
379
- tensordict .is_locked = True # make sure _step does not modify the tensordict
379
+ tensordict .lock () # make sure _step does not modify the tensordict
380
380
tensordict_out = self ._step (tensordict )
381
- tensordict .is_locked = False
381
+ if tensordict_out is tensordict :
382
+ raise RuntimeError (
383
+ "EnvBase._step should return outplace changes to the input "
384
+ "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
385
+ "tensordict.select()) inside _step before writing new tensors onto this new instance."
386
+ )
387
+ tensordict .unlock ()
388
+
382
389
obs_keys = set (self .observation_spec .keys ())
383
390
tensordict_out_select = tensordict_out .select (* obs_keys )
384
- tensordict_out = tensordict_out .exclude (* obs_keys )
391
+ tensordict_out = tensordict_out .exclude (* obs_keys , inplace = True )
385
392
tensordict_out .set ("next" , tensordict_out_select )
386
393
387
394
reward = tensordict_out .get ("reward" )
@@ -409,12 +416,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
409
416
done = done .view (expected_done_shape )
410
417
tensordict_out .set ("done" , done )
411
418
412
- if tensordict_out is tensordict :
413
- raise RuntimeError (
414
- "EnvBase._step should return outplace changes to the input "
415
- "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
416
- "tensordict.select()) inside _step before writing new tensors onto this new instance."
417
- )
418
419
if self .run_type_checks :
419
420
for key in self ._select_observation_keys (tensordict_out ):
420
421
obs = tensordict_out .get (key )
@@ -432,7 +433,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
432
433
)
433
434
tensordict .update (tensordict_out , inplace = self ._inplace_update )
434
435
435
- del tensordict_out
436
436
return tensordict
437
437
438
438
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
@@ -726,8 +726,6 @@ def _to_tensor(
726
726
value = torch .as_tensor (value , device = device )
727
727
else :
728
728
value = value .to (device )
729
- # if dtype is not None:
730
- # value = value.to(dtype)
731
729
return value
732
730
733
731
def close (self ):
0 commit comments