@@ -1947,31 +1947,35 @@ def state_spec_unbatched(self, spec: Composite):
1947
1947
spec = spec .expand (self .batch_size + spec .shape )
1948
1948
self .state_spec = spec
1949
1949
1950
def _skip_tensordict(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Create a "skip" tensordict, i.e. a placeholder for a skipped step.

    A zero-filled template is assembled from ``full_done_spec``,
    ``full_observation_spec`` and ``full_reward_spec``. Entries of
    ``tensordict`` that have a counterpart in that template are copied over
    (cloned, or moved to the counterpart's device); template entries that
    were not visited while walking ``tensordict`` are kept as zeros.

    Args:
        tensordict: the input whose matching entries are carried over into
            the placeholder.

    Returns:
        The tensordict produced by ``_fast_apply``, built with the
        template's device and batch size and then completed with the
        remaining zero-filled template entries.
    """
    # Zero-filled template; later updates are applied on top of earlier ones.
    next_tensordict = self.full_done_spec.zero()
    next_tensordict.update(self.full_observation_spec.zero())
    next_tensordict.update(self.full_reward_spec.zero())

    # Names of every entry visited while walking `tensordict`; these are
    # excluded from the template before it is merged back in below.
    keys = set()

    def select_and_clone(name, x, y):
        # `x` comes from `tensordict`, `y` from the zero template.
        keys.add(name)
        if y is None:
            # NOTE(review): presumably `default=None` makes `y` None when the
            # template has no entry under `name`; such entries are dropped
            # from the result -- confirm against `_fast_apply`.
            return None
        if y.device == x.device:
            return x.clone()
        return x.to(y.device)

    # Copy the data from tensordict in `next`
    result = tensordict._fast_apply(
        select_and_clone,
        next_tensordict,
        device=next_tensordict.device,
        batch_size=next_tensordict.batch_size,
        default=None,
        filter_empty=True,
        is_leaf=_is_leaf_nontensor,
        named=True,
        nested_keys=True,
    )
    # Fill in the template entries that were not visited above.
    # NOTE(review): `filter_empty_()` presumably prunes sub-tensordicts left
    # empty by the exclusion -- confirm.
    result.update(next_tensordict.exclude(*keys).filter_empty_())
    return result
1975
1979
1976
1980
def step (self , tensordict : TensorDictBase ) -> TensorDictBase :
1977
1981
"""Makes a step in the environment.
0 commit comments