
Commit 29f42ea

[BugFix] Load collector frames and iter (#1557)
Signed-off-by: Matteo Bettini <matbet@meta.com>
1 parent 680412a commit 29f42ea
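
What the fix enables, in a minimal sketch (not part of the commit): a collector's state_dict() now records how many frames and iterations it has already produced, so a freshly built collector that loads that state resumes where the previous one stopped instead of starting from frame 0. The snippet assumes TorchRL with this patch plus gymnasium's Pendulum-v1; any env/policy pair would do, and all names here are illustrative rather than taken from the diff.

import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

def make_env():
    return GymEnv("Pendulum-v1")

env = make_env()
# A trivial linear policy mapping observations to actions.
policy = TensorDictModule(
    torch.nn.Linear(
        env.observation_spec["observation"].shape[-1],
        env.action_spec.shape[-1],
    ),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = SyncDataCollector(
    make_env, policy, frames_per_batch=50, total_frames=200
)
for i, batch in enumerate(collector):
    if i == 1:  # simulate an interruption after two batches (100 frames)
        break
checkpoint = collector.state_dict()  # now also carries "frames" and "iter"
collector.shutdown()

# A fresh collector picks up where the previous one stopped.
collector = SyncDataCollector(
    make_env, policy, frames_per_batch=50, total_frames=200
)
collector.load_state_dict(checkpoint)
for batch in collector:  # only the remaining 100 frames are produced
    pass
collector.shutdown()
env.close()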

File tree

  test/test_collector.py
  torchrl/collectors/collectors.py

2 files changed: +64 / -18 lines

test/test_collector.py

Lines changed: 43 additions & 0 deletions
@@ -1740,6 +1740,49 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
         col.shutdown()
 
 
+@pytest.mark.parametrize(
+    "collector_class",
+    [MultiSyncDataCollector, MultiaSyncDataCollector, SyncDataCollector],
+)
+def test_collector_reloading(collector_class):
+    def make_env():
+        return ContinuousActionVecMockEnv()
+
+    dummy_env = make_env()
+    obs_spec = dummy_env.observation_spec["observation"]
+    policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1])
+    policy = Actor(policy_module, spec=dummy_env.action_spec)
+    policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)
+
+    collector_kwargs = {
+        "create_env_fn": make_env,
+        "policy": policy_explore,
+        "frames_per_batch": 30,
+        "total_frames": 90,
+    }
+    if collector_class is not SyncDataCollector:
+        collector_kwargs["create_env_fn"] = [
+            collector_kwargs["create_env_fn"] for _ in range(3)
+        ]
+
+    collector = collector_class(**collector_kwargs)
+    for i, _ in enumerate(collector):
+        if i == 3:
+            break
+    collector_frames = collector._frames
+    collector_iter = collector._iter
+    collector_state_dict = collector.state_dict()
+    collector.shutdown()
+
+    collector = collector_class(**collector_kwargs)
+    collector.load_state_dict(collector_state_dict)
+    assert collector._frames == collector_frames
+    assert collector._iter == collector_iter
+    for _ in enumerate(collector):
+        raise AssertionError
+    collector.shutdown()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
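
The test round-trips the state_dict in memory. The same mechanism can back a hypothetical on-disk checkpoint; in the sketch below, "collector_ckpt.pt" is a placeholder path and collector, collector_class and collector_kwargs refer to the objects built in the test above.

import torch

# After interrupting collection (as in the test above):
torch.save(collector.state_dict(), "collector_ckpt.pt")
collector.shutdown()

# In a later run, rebuild the collector with the same arguments and restore it;
# iteration then resumes from the stored frame/iteration counters.
collector = collector_class(**collector_kwargs)
collector.load_state_dict(torch.load("collector_ckpt.pt"))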

torchrl/collectors/collectors.py

Lines changed: 21 additions & 18 deletions
@@ -695,6 +695,8 @@ def __init__(
         self.split_trajs = split_trajs
         self._exclude_private_keys = True
         self.interruptor = interruptor
+        self._frames = 0
+        self._iter = -1
 
     # for RPC
     def next(self):
@@ -745,11 +747,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
             stream = None
         with torch.cuda.stream(stream):
             total_frames = self.total_frames
-            i = -1
-            self._frames = 0
-            while True:
-                i += 1
-                self._iter = i
+
+            while self._frames < self.total_frames:
+                self._iter += 1
                 tensordict_out = self.rollout()
                 self._frames += tensordict_out.numel()
                 if self._frames >= total_frames:
@@ -788,9 +788,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 # >>> assert data0["done"] is not data1["done"]
                 yield tensordict_out.clone()
 
-                if self._frames >= self.total_frames:
-                    break
-
     def _step_and_maybe_reset(self) -> None:
 
         any_done = False
@@ -985,6 +982,8 @@ def state_dict(self) -> OrderedDict:
         else:
             state_dict = OrderedDict(env_state_dict=env_state_dict)
 
+        state_dict.update({"frames": self._frames, "iter": self._iter})
+
         return state_dict
 
     def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
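
With this hunk, the single-process collector's state_dict carries the two counters next to the existing env/policy entries. A rough sketch of inspecting it, given any collector built as in the example near the top (exact keys depend on the env and policy):

sd = collector.state_dict()
print(sd["frames"], sd["iter"])  # frames collected so far, index of the last batch
# plus the pre-existing entries, e.g. "env_state_dict" / "policy_state_dict"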
@@ -1000,6 +999,8 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
             self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
         if strict or "policy_state_dict" in state_dict:
             self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs)
+        self._frames = state_dict["frames"]
+        self._iter = state_dict["iter"]
 
     def __repr__(self) -> str:
         env_str = indent(f"env={self.env}", 4 * " ")
@@ -1284,6 +1285,8 @@ def device_err_msg(device_name, devices_list):
         self.interruptor = None
         self._run_processes()
         self._exclude_private_keys = True
+        self._frames = 0
+        self._iter = -1
 
     @property
     def frames_per_batch_worker(self):
@@ -1471,6 +1474,7 @@ def state_dict(self) -> OrderedDict:
             if msg != "state_dict":
                 raise RuntimeError(f"Expected msg='state_dict', got {msg}")
             state_dict[f"worker{idx}"] = _state_dict
+        state_dict.update({"frames": self._frames, "iter": self._iter})
 
         return state_dict
 
@@ -1488,6 +1492,8 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
             _, msg = self.pipes[idx].recv()
             if msg != "loaded":
                 raise RuntimeError(f"Expected msg='loaded', got {msg}")
+        self._frames = state_dict["frames"]
+        self._iter = state_dict["iter"]
 
 
 @accept_remote_rref_udf_invocation
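
For the multiprocess collectors, the parent's state_dict keeps one sub-dict per worker plus the same two counters, and load_state_dict fans the sub-dicts back out to the workers before restoring the counters on the parent. An illustrative sketch (multi_collector is any MultiSync/MultiaSync collector, not an object from the diff):

sd = multi_collector.state_dict()   # e.g. keys: "worker0", "worker1", ..., "frames", "iter"
multi_collector.load_state_dict(sd) # each worker reloads its sub-dict, then
                                    # _frames and _iter are restored on the parent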
@@ -1639,27 +1645,26 @@ def _queue_len(self) -> int:
         return self.num_workers
 
     def iterator(self) -> Iterator[TensorDictBase]:
-        i = -1
-        frames = 0
+
         self.buffers = {}
         dones = [False for _ in range(self.num_workers)]
         workers_frames = [0 for _ in range(self.num_workers)]
         same_device = None
         self.out_buffer = None
 
-        while not all(dones) and frames < self.total_frames:
+        while not all(dones) and self._frames < self.total_frames:
             _check_for_faulty_process(self.procs)
             if self.update_at_each_batch:
                 self.update_policy_weights_()
 
             for idx in range(self.num_workers):
-                if frames < self.init_random_frames:
+                if self._frames < self.init_random_frames:
                     msg = "continue_random"
                 else:
                     msg = "continue"
                 self.pipes[idx].send((None, msg))
 
-            i += 1
+            self._iter += 1
             max_traj_idx = None
 
             if self.interruptor is not None and self.preemptive_threshold < 1.0:
@@ -1714,10 +1719,10 @@ def iterator(self) -> Iterator[TensorDictBase]:
 
             if self.split_trajs:
                 out = split_trajectories(self.out_buffer, prefix="collector")
-                frames += out.get(("collector", "mask")).sum().item()
+                self._frames += out.get(("collector", "mask")).sum().item()
             else:
                 out = self.out_buffer.clone()
-                frames += prod(out.shape)
+                self._frames += prod(out.shape)
             if self.postprocs:
                 self.postprocs = self.postprocs.to(out.device)
                 out = self.postprocs(out)
@@ -1894,13 +1899,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
             else:
                 self.pipes[i].send((None, "continue"))
         self.running = True
-        i = -1
-        self._frames = 0
 
         workers_frames = [0 for _ in range(self.num_workers)]
         while self._frames < self.total_frames:
             _check_for_faulty_process(self.procs)
-            i += 1
+            self._iter += 1
             idx, j, out = self._get_from_queue()
 
             worker_frames = out.numel()
