File tree Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -6526,12 +6526,15 @@ def _reset_func(
6526
6526
# tensordict_reset = tensordict_reset.clone()
6527
6527
reset_val = self .default_value (reset = _reset )
6528
6528
# This is safe because env.reset calls _update_during_reset which will discard the new data
6529
- tensordict_reset = (
6530
- self .container .full_observation_spec .zero ().select (
6531
- * reset_val .keys (True )
6532
- )
6533
- )
6534
- tensordict_reset = torch .where (_reset , reset_val , 0 )
6529
+ # tensordict_reset = (
6530
+ # self.container.full_observation_spec.zero().select(
6531
+ # *reset_val.keys(True)
6532
+ # )
6533
+ # )
6534
+ tensordict_reset = _reset .new_zeros (_reset .shape )
6535
+ print (f"tensordict_reset: { tensordict_reset } " )
6536
+ print (f"reset_val: { reset_val } " )
6537
+ tensordict_reset [_reset ] = reset_val
6535
6538
else :
6536
6539
resets = self .default_value (reset = _reset )
6537
6540
tensordict_reset .update (resets )
You can’t perform that action at this time.
0 commit comments