Skip to content

Commit 6d95880

Browse files
author
Vincent Moens
committed
"[BugFix] Refactor _skip_tensordict to avoid update calls (#2802)"
ghstack-source-id: 0f31b87 Pull Request resolved: #2802 (cherry picked from commit e0d3eee)
1 parent 276dee6 commit 6d95880

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

torchrl/envs/common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,36 @@ def state_spec_unbatched(self, spec: Composite):
19331933
spec = spec.expand(self.batch_size + spec.shape)
19341934
self.state_spec = spec
19351935

1936+
def _skip_tensordict(self, tensordict: TensorDictBase) -> TensorDictBase:
1937+
# Creates a "skip" tensordict, ie a placeholder for when a step is skipped
1938+
next_tensordict = self.full_done_spec.zero()
1939+
next_tensordict.update(self.full_observation_spec.zero())
1940+
next_tensordict.update(self.full_reward_spec.zero())
1941+
1942+
# Copy the data from tensordict in `next`
1943+
keys = set()
1944+
1945+
def select_and_clone(name, x, y):
1946+
keys.add(name)
1947+
if y is not None:
1948+
if y.device == x.device:
1949+
return x.clone()
1950+
return x.to(y.device)
1951+
1952+
result = tensordict._fast_apply(
1953+
select_and_clone,
1954+
next_tensordict,
1955+
device=next_tensordict.device,
1956+
batch_size=next_tensordict.batch_size,
1957+
default=None,
1958+
filter_empty=True,
1959+
is_leaf=_is_leaf_nontensor,
1960+
named=True,
1961+
nested_keys=True,
1962+
)
1963+
result.update(next_tensordict.exclude(*keys).filter_empty_())
1964+
return result
1965+
19361966
def step(self, tensordict: TensorDictBase) -> TensorDictBase:
19371967
"""Makes a step in the environment.
19381968

0 commit comments

Comments
 (0)