File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -614,14 +614,18 @@ def _reset_if_necessary(self) -> None:
614
614
steps = steps .clone ()
615
615
if len (self .env .batch_size ):
616
616
self ._tensordict .masked_fill_ (done_or_terminated , 0 )
617
- self ._tensordict .set ("_reset" , done_or_terminated )
617
+ _reset = done_or_terminated
618
+ self ._tensordict .set ("_reset" , _reset )
618
619
else :
620
+ _reset = None
619
621
self ._tensordict .zero_ ()
620
622
self .env .reset (self ._tensordict )
621
623
622
- if self ._tensordict .get ("done" ).any ():
624
+ if (_reset is None and self ._tensordict .get ("done" ).any ()) or (
625
+ _reset is not None and self ._tensordict .get ("done" )[_reset ].any ()
626
+ ):
623
627
raise RuntimeError (
624
- f"Got { sum ( self ._tensordict . get ( 'done' )) } done envs after reset."
628
+ f"Env { self .env } was done after reset on specified '_reset' dimensions. This is (currently) not allowed ."
625
629
)
626
630
traj_ids [done_or_terminated ] = traj_ids .max () + torch .arange (
627
631
1 , done_or_terminated .sum () + 1 , device = traj_ids .device
You can’t perform that action at this time.
0 commit comments