@@ -208,7 +208,7 @@ def _get_policy_and_device(
208
208
return policy , None
209
209
210
210
if isinstance (policy , nn .Module ):
211
- param_and_buf = TensorDict .from_module (policy , as_module = True ). data
211
+ param_and_buf = TensorDict .from_module (policy , as_module = True )
212
212
else :
213
213
# Because we want to reach the warning
214
214
param_and_buf = TensorDict ()
@@ -231,19 +231,25 @@ def _get_policy_and_device(
231
231
return policy , None
232
232
233
233
# Create a stateless policy, then populate this copy with params on device
234
- def get_original_weights (policy ):
234
+ def get_original_weights (policy = policy ):
235
235
td = TensorDict .from_module (policy )
236
236
return td .data
237
237
238
238
# We need to use ".data" here; otherwise, buffers may disappear from the output of the `get_original_weights` function
239
239
with param_and_buf .data .to ("meta" ).to_module (policy ):
240
- policy = deepcopy (policy )
240
+ policy_new_device = deepcopy (policy )
241
241
242
- param_and_buf .apply (
242
+ param_and_buf_new_device = param_and_buf .apply (
243
243
functools .partial (_map_weight , policy_device = policy_device ),
244
244
filter_empty = False ,
245
- ).to_module (policy )
246
- return policy , get_original_weights
245
+ )
246
+ param_and_buf_new_device .to_module (policy_new_device )
247
+ # Sanity check
248
+ if set (TensorDict .from_module (policy_new_device ).keys (True , True )) != set (
249
+ get_original_weights ().keys (True , True )
250
+ ):
251
+ raise RuntimeError ("Failed to map weights. The weight sets mismatch." )
252
+ return policy_new_device , get_original_weights
247
253
248
254
def start (self ):
249
255
"""Starts the collector for asynchronous data collection.
@@ -1976,17 +1982,17 @@ def __init__(
1976
1982
for policy_device , env_maker , env_maker_kwargs in _zip_strict (
1977
1983
self .policy_device , self .create_env_fn , self .create_env_kwargs
1978
1984
):
1979
- (policy_copy , get_weights_fn ,) = self ._get_policy_and_device (
1985
+ (policy_new_device , get_weights_fn ,) = self ._get_policy_and_device (
1980
1986
policy = policy ,
1981
1987
policy_device = policy_device ,
1982
1988
env_maker = env_maker ,
1983
1989
env_maker_kwargs = env_maker_kwargs ,
1984
1990
)
1985
- if type (policy_copy ) is not type (policy ):
1986
- policy = policy_copy
1991
+ if type (policy_new_device ) is not type (policy ):
1992
+ policy = policy_new_device
1987
1993
weights = (
1988
- TensorDict .from_module (policy_copy ).data
1989
- if isinstance (policy_copy , nn .Module )
1994
+ TensorDict .from_module (policy_new_device ).data
1995
+ if isinstance (policy_new_device , nn .Module )
1990
1996
else TensorDict ()
1991
1997
)
1992
1998
self ._policy_weights_dict [policy_device ] = weights
0 commit comments