File tree Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Original file line number Diff line number Diff line change @@ -1933,6 +1933,36 @@ def state_spec_unbatched(self, spec: Composite):
1933
1933
spec = spec .expand (self .batch_size + spec .shape )
1934
1934
self .state_spec = spec
1935
1935
1936
+ def _skip_tensordict (self , tensordict : TensorDictBase ) -> TensorDictBase :
1937
+ # Creates a "skip" tensordict, ie a placeholder for when a step is skipped
1938
+ next_tensordict = self .full_done_spec .zero ()
1939
+ next_tensordict .update (self .full_observation_spec .zero ())
1940
+ next_tensordict .update (self .full_reward_spec .zero ())
1941
+
1942
+ # Copy the data from tensordict in `next`
1943
+ keys = set ()
1944
+
1945
+ def select_and_clone (name , x , y ):
1946
+ keys .add (name )
1947
+ if y is not None :
1948
+ if y .device == x .device :
1949
+ return x .clone ()
1950
+ return x .to (y .device )
1951
+
1952
+ result = tensordict ._fast_apply (
1953
+ select_and_clone ,
1954
+ next_tensordict ,
1955
+ device = next_tensordict .device ,
1956
+ batch_size = next_tensordict .batch_size ,
1957
+ default = None ,
1958
+ filter_empty = True ,
1959
+ is_leaf = _is_leaf_nontensor ,
1960
+ named = True ,
1961
+ nested_keys = True ,
1962
+ )
1963
+ result .update (next_tensordict .exclude (* keys ).filter_empty_ ())
1964
+ return result
1965
+
1936
1966
def step (self , tensordict : TensorDictBase ) -> TensorDictBase :
1937
1967
"""Makes a step in the environment.
1938
1968
You can’t perform that action at this time.
0 commit comments