@@ -695,6 +695,8 @@ def __init__(
695
695
self .split_trajs = split_trajs
696
696
self ._exclude_private_keys = True
697
697
self .interruptor = interruptor
698
+ self ._frames = 0
699
+ self ._iter = - 1
698
700
699
701
# for RPC
700
702
def next (self ):
@@ -745,11 +747,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
745
747
stream = None
746
748
with torch .cuda .stream (stream ):
747
749
total_frames = self .total_frames
748
- i = - 1
749
- self ._frames = 0
750
- while True :
751
- i += 1
752
- self ._iter = i
750
+
751
+ while self ._frames < self .total_frames :
752
+ self ._iter += 1
753
753
tensordict_out = self .rollout ()
754
754
self ._frames += tensordict_out .numel ()
755
755
if self ._frames >= total_frames :
@@ -788,9 +788,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
788
788
# >>> assert data0["done"] is not data1["done"]
789
789
yield tensordict_out .clone ()
790
790
791
- if self ._frames >= self .total_frames :
792
- break
793
-
794
791
def _step_and_maybe_reset (self ) -> None :
795
792
796
793
any_done = False
@@ -985,6 +982,8 @@ def state_dict(self) -> OrderedDict:
985
982
else :
986
983
state_dict = OrderedDict (env_state_dict = env_state_dict )
987
984
985
+ state_dict .update ({"frames" : self ._frames , "iter" : self ._iter })
986
+
988
987
return state_dict
989
988
990
989
def load_state_dict (self , state_dict : OrderedDict , ** kwargs ) -> None :
@@ -1000,6 +999,8 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
1000
999
self .env .load_state_dict (state_dict ["env_state_dict" ], ** kwargs )
1001
1000
if strict or "policy_state_dict" in state_dict :
1002
1001
self .policy .load_state_dict (state_dict ["policy_state_dict" ], ** kwargs )
1002
+ self ._frames = state_dict ["frames" ]
1003
+ self ._iter = state_dict ["iter" ]
1003
1004
1004
1005
def __repr__ (self ) -> str :
1005
1006
env_str = indent (f"env={ self .env } " , 4 * " " )
@@ -1284,6 +1285,8 @@ def device_err_msg(device_name, devices_list):
1284
1285
self .interruptor = None
1285
1286
self ._run_processes ()
1286
1287
self ._exclude_private_keys = True
1288
+ self ._frames = 0
1289
+ self ._iter = - 1
1287
1290
1288
1291
@property
1289
1292
def frames_per_batch_worker (self ):
@@ -1471,6 +1474,7 @@ def state_dict(self) -> OrderedDict:
1471
1474
if msg != "state_dict" :
1472
1475
raise RuntimeError (f"Expected msg='state_dict', got { msg } " )
1473
1476
state_dict [f"worker{ idx } " ] = _state_dict
1477
+ state_dict .update ({"frames" : self ._frames , "iter" : self ._iter })
1474
1478
1475
1479
return state_dict
1476
1480
@@ -1488,6 +1492,8 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
1488
1492
_ , msg = self .pipes [idx ].recv ()
1489
1493
if msg != "loaded" :
1490
1494
raise RuntimeError (f"Expected msg='loaded', got { msg } " )
1495
+ self ._frames = state_dict ["frames" ]
1496
+ self ._iter = state_dict ["iter" ]
1491
1497
1492
1498
1493
1499
@accept_remote_rref_udf_invocation
@@ -1639,27 +1645,26 @@ def _queue_len(self) -> int:
1639
1645
return self .num_workers
1640
1646
1641
1647
def iterator (self ) -> Iterator [TensorDictBase ]:
1642
- i = - 1
1643
- frames = 0
1648
+
1644
1649
self .buffers = {}
1645
1650
dones = [False for _ in range (self .num_workers )]
1646
1651
workers_frames = [0 for _ in range (self .num_workers )]
1647
1652
same_device = None
1648
1653
self .out_buffer = None
1649
1654
1650
- while not all (dones ) and frames < self .total_frames :
1655
+ while not all (dones ) and self . _frames < self .total_frames :
1651
1656
_check_for_faulty_process (self .procs )
1652
1657
if self .update_at_each_batch :
1653
1658
self .update_policy_weights_ ()
1654
1659
1655
1660
for idx in range (self .num_workers ):
1656
- if frames < self .init_random_frames :
1661
+ if self . _frames < self .init_random_frames :
1657
1662
msg = "continue_random"
1658
1663
else :
1659
1664
msg = "continue"
1660
1665
self .pipes [idx ].send ((None , msg ))
1661
1666
1662
- i += 1
1667
+ self . _iter += 1
1663
1668
max_traj_idx = None
1664
1669
1665
1670
if self .interruptor is not None and self .preemptive_threshold < 1.0 :
@@ -1714,10 +1719,10 @@ def iterator(self) -> Iterator[TensorDictBase]:
1714
1719
1715
1720
if self .split_trajs :
1716
1721
out = split_trajectories (self .out_buffer , prefix = "collector" )
1717
- frames += out .get (("collector" , "mask" )).sum ().item ()
1722
+ self . _frames += out .get (("collector" , "mask" )).sum ().item ()
1718
1723
else :
1719
1724
out = self .out_buffer .clone ()
1720
- frames += prod (out .shape )
1725
+ self . _frames += prod (out .shape )
1721
1726
if self .postprocs :
1722
1727
self .postprocs = self .postprocs .to (out .device )
1723
1728
out = self .postprocs (out )
@@ -1894,13 +1899,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
1894
1899
else :
1895
1900
self .pipes [i ].send ((None , "continue" ))
1896
1901
self .running = True
1897
- i = - 1
1898
- self ._frames = 0
1899
1902
1900
1903
workers_frames = [0 for _ in range (self .num_workers )]
1901
1904
while self ._frames < self .total_frames :
1902
1905
_check_for_faulty_process (self .procs )
1903
- i += 1
1906
+ self . _iter += 1
1904
1907
idx , j , out = self ._get_from_queue ()
1905
1908
1906
1909
worker_frames = out .numel ()
0 commit comments