Skip to content

Commit e0d3eee

Browse files
author
Vincent Moens
committed
"[BugFix] Refactor _skip_tensordict to avoid update calls (#2802)"
ghstack-source-id: 0f31b87 Pull Request resolved: #2802
1 parent 21c4d87 commit e0d3eee

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

torchrl/envs/common.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,31 +1947,35 @@ def state_spec_unbatched(self, spec: Composite):
19471947
spec = spec.expand(self.batch_size + spec.shape)
19481948
self.state_spec = spec
19491949

1950-
def _skip_tensordict(self, tensordict):
1950+
def _skip_tensordict(self, tensordict: TensorDictBase) -> TensorDictBase:
19511951
# Creates a "skip" tensordict, ie a placeholder for when a step is skipped
19521952
next_tensordict = self.full_done_spec.zero()
19531953
next_tensordict.update(self.full_observation_spec.zero())
19541954
next_tensordict.update(self.full_reward_spec.zero())
19551955

19561956
# Copy the data from tensordict in `next`
1957-
def select_and_clone(x, y):
1957+
keys = set()
1958+
1959+
def select_and_clone(name, x, y):
1960+
keys.add(name)
19581961
if y is not None:
19591962
if y.device == x.device:
19601963
return x.clone()
19611964
return x.to(y.device)
19621965

1963-
next_tensordict.update(
1964-
tensordict._fast_apply(
1965-
select_and_clone,
1966-
next_tensordict,
1967-
device=next_tensordict.device,
1968-
batch_size=next_tensordict.batch_size,
1969-
default=None,
1970-
filter_empty=True,
1971-
is_leaf=_is_leaf_nontensor,
1972-
)
1966+
result = tensordict._fast_apply(
1967+
select_and_clone,
1968+
next_tensordict,
1969+
device=next_tensordict.device,
1970+
batch_size=next_tensordict.batch_size,
1971+
default=None,
1972+
filter_empty=True,
1973+
is_leaf=_is_leaf_nontensor,
1974+
named=True,
1975+
nested_keys=True,
19731976
)
1974-
return next_tensordict
1977+
result.update(next_tensordict.exclude(*keys).filter_empty_())
1978+
return result
19751979

19761980
def step(self, tensordict: TensorDictBase) -> TensorDictBase:
19771981
"""Makes a step in the environment.

0 commit comments

Comments
 (0)