@@ -374,6 +374,14 @@ def __init__(
374
374
375
375
is_spec_locked = EnvBase .is_spec_locked
376
376
377
+ def select_and_clone (self , name , tensor , selected_keys = None ):
378
+ if selected_keys is None :
379
+ selected_keys = self ._selected_step_keys
380
+ if name in selected_keys :
381
+ if self .device is not None and tensor .device != self .device :
382
+ return tensor .to (self .device , non_blocking = self .non_blocking )
383
+ return tensor .clone ()
384
+
377
385
@property
378
386
def non_blocking (self ):
379
387
nb = self ._non_blocking
@@ -1062,12 +1070,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1062
1070
selected_output_keys = self ._selected_reset_keys_filt
1063
1071
1064
1072
# select + clone creates 2 tds, but we can create one only
1065
- def select_and_clone (name , tensor ):
1066
- if name in selected_output_keys :
1067
- return tensor .clone ()
1068
-
1069
1073
out = self .shared_tensordict_parent .named_apply (
1070
- select_and_clone ,
1074
+ lambda * args : self .select_and_clone (
1075
+ * args , selected_keys = selected_output_keys
1076
+ ),
1071
1077
nested_keys = True ,
1072
1078
filter_empty = True ,
1073
1079
)
@@ -1135,14 +1141,14 @@ def _step(
1135
1141
# will be modified in-place at further steps
1136
1142
device = self .device
1137
1143
1138
- def select_and_clone (name , tensor ):
1139
- if name in self ._selected_step_keys :
1140
- return tensor .clone ()
1144
+ selected_keys = self ._selected_step_keys
1141
1145
1142
1146
if partial_steps is not None :
1143
1147
next_td = TensorDict .lazy_stack ([next_td [i ] for i in workers_range ])
1144
1148
out = next_td .named_apply (
1145
- select_and_clone , nested_keys = True , filter_empty = True
1149
+ lambda * args : self .select_and_clone (* args , selected_keys ),
1150
+ nested_keys = True ,
1151
+ filter_empty = True ,
1146
1152
)
1147
1153
if out_tds is not None :
1148
1154
out .update (
@@ -1841,20 +1847,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
1841
1847
next_td = shared_tensordict_parent .get ("next" )
1842
1848
device = self .device
1843
1849
1844
- if next_td .device != device and device is not None :
1845
-
1846
- def select_and_clone (name , tensor ):
1847
- if name in self ._selected_step_keys :
1848
- return tensor .to (device , non_blocking = self .non_blocking )
1849
-
1850
- else :
1851
-
1852
- def select_and_clone (name , tensor ):
1853
- if name in self ._selected_step_keys :
1854
- return tensor .clone ()
1855
-
1856
1850
out = next_td .named_apply (
1857
- select_and_clone ,
1851
+ self . select_and_clone ,
1858
1852
nested_keys = True ,
1859
1853
filter_empty = True ,
1860
1854
device = device ,
@@ -2005,20 +1999,10 @@ def tentative_update(val, other):
2005
1999
selected_output_keys = self ._selected_reset_keys_filt
2006
2000
device = self .device
2007
2001
2008
- if self .shared_tensordict_parent .device != device and device is not None :
2009
-
2010
- def select_and_clone (name , tensor ):
2011
- if name in selected_output_keys :
2012
- return tensor .to (device , non_blocking = self .non_blocking )
2013
-
2014
- else :
2015
-
2016
- def select_and_clone (name , tensor ):
2017
- if name in selected_output_keys :
2018
- return tensor .clone ()
2019
-
2020
2002
out = self .shared_tensordict_parent .named_apply (
2021
- select_and_clone ,
2003
+ lambda * args : self .select_and_clone (
2004
+ * args , selected_keys = selected_output_keys
2005
+ ),
2022
2006
nested_keys = True ,
2023
2007
filter_empty = True ,
2024
2008
device = device ,
0 commit comments