File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -6523,17 +6523,15 @@ def _reset_func(
6523
6523
if self .single_default_value and callable (self .default_value ):
6524
6524
if not _reset .all ():
6525
6525
# FIXME: use masked op
6526
- tensordict_reset = tensordict_reset .clone ()
6526
+ # tensordict_reset = tensordict_reset.clone()
6527
6527
reset_val = self .default_value (reset = _reset )
6528
6528
# This is safe because env.reset calls _update_during_reset which will discard the new data
6529
6529
tensordict_reset = (
6530
6530
self .container .full_observation_spec .zero ().select (
6531
6531
* reset_val .keys (True )
6532
6532
)
6533
6533
)
6534
- tensordict_reset = tensordict_reset .where (
6535
- _reset , reset_val , update_batch_size = True
6536
- )
6534
+ tensordict_reset = torch .where (_reset , reset_val , 0 )
6537
6535
else :
6538
6536
resets = self .default_value (reset = _reset )
6539
6537
tensordict_reset .update (resets )
You can’t perform that action at this time.
0 commit comments