diff --git a/test/test_collector.py b/test/test_collector.py
index 21862cf8297..ba77dcf1662 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -545,6 +545,7 @@ def env_fn():
         frames_per_batch=frames_per_batch,
         max_frames_per_traj=1000,
         total_frames=frames_per_batch * 100,
+        cat_dim=-1,
     )
     ccollector.set_seed(seed)
     for i, b in enumerate(ccollector):
@@ -800,7 +801,10 @@ def test_collector_vecnorm_envcreator(static_seed):
     policy = RandomPolicy(env_make.action_spec)
     num_data_collectors = 2
     c = MultiSyncDataCollector(
-        [env_make] * num_data_collectors, policy=policy, total_frames=int(1e6)
+        [env_make] * num_data_collectors,
+        policy=policy,
+        total_frames=int(1e6),
+        cat_dim=-1,
     )
 
     init_seed = 0
@@ -856,11 +860,13 @@ def create_env():
     collector_class = (
         MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
    )
+    kwargs = {"cat_dim": -1} if not use_async else {}
     collector = collector_class(
         [create_env] * 3,
         policy=policy,
         devices=[torch.device("cuda:0")] * 3,
         storing_devices=[torch.device("cuda:0")] * 3,
+        **kwargs,
     )
     # collect state_dict
     state_dict = collector.state_dict()
@@ -933,6 +939,8 @@ def make_env():
         collector_kwargs["create_env_fn"] = [
             collector_kwargs["create_env_fn"] for _ in range(3)
         ]
+        if collector_class is MultiSyncDataCollector:
+            collector_kwargs["cat_dim"] = -1
 
     collector = collector_class(**collector_kwargs)
     collector._exclude_private_keys = exclude
@@ -1016,6 +1024,8 @@ def test_collector_output_keys(
         collector_kwargs["create_env_fn"] = [
             collector_kwargs["create_env_fn"] for _ in range(num_envs)
         ]
+        if collector_class is MultiSyncDataCollector:
+            collector_kwargs["cat_dim"] = -1
 
     collector = collector_class(**collector_kwargs)
 
@@ -1093,6 +1103,7 @@ def env_fn(seed):
         storing_devices=[
             storing_device,
         ],
+        cat_dim=-1,
     )
     batch = next(collector.iterator())
     assert batch.device == torch.device(storing_device)
@@ -1151,6 +1162,8 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy):
             collector_kwargs["create_env_fn"] = [
                 collector_kwargs["create_env_fn"] for _ in range(self.num_envs)
             ]
+            if collector_class is MultiSyncDataCollector:
+                collector_kwargs["cat_dim"] = -1
 
         return collector_kwargs
 
@@ -1324,12 +1337,14 @@ def env_fn(seed):
         storing_devices="cpu",
         split_trajs=False,
         preemptive_threshold=0.0,  # stop after one iteration
+        cat_dim=-1,
     )
 
     for batch in collector:
         trajectory_ids = batch["collector"]["traj_ids"]
         trajectory_ids_mask = trajectory_ids != -1  # valid frames mask
-        assert trajectory_ids[trajectory_ids_mask].numel() < frames_per_batch
+        assert trajectory_ids_mask.all()
+        assert trajectory_ids.numel() < frames_per_batch
 
 
 def test_maxframes_error():
@@ -1398,6 +1413,7 @@ def test_multi_collector_nested_env_consistency(self, seed=1):
             frames_per_batch=20,
             total_frames=100,
             device="cpu",
+            cat_dim=-1,
         )
         for i, d in enumerate(ccollector):
             if i == 0:
@@ -1411,8 +1427,8 @@ def test_multi_collector_nested_env_consistency(self, seed=1):
             assert_allclose_td(d1, d2)
         ccollector.shutdown()
 
-        assert_allclose_td(c1, d1)
-        assert_allclose_td(c2, d2)
+        assert_allclose_td(c1, d1.select(*c1.keys(True, True)))
+        assert_allclose_td(c2, d2.select(*c1.keys(True, True)))
 
     @pytest.mark.parametrize("nested_obs_action", [True, False])
     @pytest.mark.parametrize("nested_done", [True, False])
@@ -1544,6 +1560,7 @@ def test_multi_collector_het_env_consistency(
             frames_per_batch=frames_per_batch,
             total_frames=100,
             device="cpu",
+            cat_dim=-1,
         )
         for i, d in enumerate(ccollector):
             if i == 0:
@@ -1557,6 +1574,8 @@ def test_multi_collector_het_env_consistency(
             assert_allclose_td(d1, d2)
         ccollector.shutdown()
 
-        assert_allclose_td(c1, d1)
-        assert_allclose_td(c2, d2)
+        assert_allclose_td(c1, d1.select(*c1.keys(True, True)))
+        assert_allclose_td(c2, d2.select(*c1.keys(True, True)))
 
 
 class TestMultiKeyEnvsCollector:
@@ -1619,6 +1636,7 @@ def test_multi_collector_consistency(
             frames_per_batch=frames_per_batch,
             total_frames=100,
             device="cpu",
+            cat_dim=-1,
         )
         for i, d in enumerate(ccollector):
             if i == 0:
@@ -1632,8 +1650,8 @@ def test_multi_collector_consistency(
             assert_allclose_td(d1, d2)
         ccollector.shutdown()
 
-        assert_allclose_td(c1, d1)
-        assert_allclose_td(c2, d2)
+        assert_allclose_td(c1, d1.select(*c1.keys(True, True)))
+        assert_allclose_td(c2, d2.select(*c1.keys(True, True)))
 
 
 @pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index b06a9b15252..bfa91fedcda 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -349,6 +349,117 @@ def __repr__(self) -> str:
         return string
 
 
+def DataCollector(
+    create_env_fn: Sequence[Callable[[], EnvBase]],
+    policy: Optional[
+        Union[
+            TensorDictModule,
+            Callable[[TensorDictBase], TensorDictBase],
+        ]
+    ],
+    *,
+    num_workers: int = None,
+    sync: bool = True,
+    frames_per_batch: int = 200,
+    total_frames: Optional[int] = -1,
+    device: DEVICE_TYPING = None,
+    storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None,
+    create_env_kwargs: Optional[Sequence[dict]] = None,
+    max_frames_per_traj: int = -1,
+    init_random_frames: int = -1,
+    reset_at_each_iter: bool = False,
+    postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
+    split_trajs: Optional[bool] = None,
+    exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
+    exploration_mode=None,
+    reset_when_done: bool = True,
+    preemptive_threshold: float = None,
+    update_at_each_batch: bool = False,
+):
+    if not isinstance(create_env_fn, EnvBase) and not callable(create_env_fn):
+        if num_workers is not None and num_workers != len(create_env_fn):
+            raise TypeError(
+                "The number of workers provided does not match the number of environment constructors."
+            )
+        else:
+            num_workers = len(create_env_fn)
+    elif num_workers is not None and num_workers > 0:
+        create_env_fn = [create_env_fn] * num_workers
+    from torchrl.envs import EnvCreator
+
+    if num_workers and any(
+        not isinstance(func, (EnvCreator, EnvBase)) for func in create_env_fn
+    ):
+        create_env_fn = [
+            func if isinstance(func, (EnvCreator, EnvBase)) else EnvCreator(func)
+            for func in create_env_fn
+        ]
+    if num_workers:
+        if sync:
+            return MultiSyncDataCollector(
+                create_env_fn=create_env_fn,
+                policy=policy,
+                num_workers=num_workers,
+                sync=sync,
+                frames_per_batch=frames_per_batch,
+                total_frames=total_frames,
+                device=device,
+                storing_device=storing_device,
+                create_env_kwargs=create_env_kwargs,
+                max_frames_per_traj=max_frames_per_traj,
+                init_random_frames=init_random_frames,
+                reset_at_each_iter=reset_at_each_iter,
+                postproc=postproc,
+                split_trajs=split_trajs,
+                exploration_type=exploration_type,
+                exploration_mode=exploration_mode,
+                reset_when_done=reset_when_done,
+                preemptive_threshold=preemptive_threshold,
+                update_at_each_batch=update_at_each_batch,
+                cat_dim=-1,
+            )
+        else:
+            return MultiaSyncDataCollector(
+                create_env_fn=create_env_fn,
+                policy=policy,
+                num_workers=num_workers,
+                sync=sync,
+                frames_per_batch=frames_per_batch,
+                total_frames=total_frames,
+                device=device,
+                storing_device=storing_device,
+                create_env_kwargs=create_env_kwargs,
+                max_frames_per_traj=max_frames_per_traj,
+                init_random_frames=init_random_frames,
+                reset_at_each_iter=reset_at_each_iter,
+                postproc=postproc,
+                split_trajs=split_trajs,
+                exploration_type=exploration_type,
+                exploration_mode=exploration_mode,
+                reset_when_done=reset_when_done,
+                preemptive_threshold=preemptive_threshold,
+                update_at_each_batch=update_at_each_batch,
+            )
+    else:
+        return SyncDataCollector(
+            create_env_fn=create_env_fn,
+            policy=policy,
+            frames_per_batch=frames_per_batch,
+            total_frames=total_frames,
+            device=device,
+            storing_device=storing_device,
+            create_env_kwargs=create_env_kwargs,
+            max_frames_per_traj=max_frames_per_traj,
+            init_random_frames=init_random_frames,
+            reset_at_each_iter=reset_at_each_iter,
+            postproc=postproc,
+            split_trajs=split_trajs,
+            exploration_type=exploration_type,
+            exploration_mode=exploration_mode,
+            reset_when_done=reset_when_done,
+        )
+
+
 @accept_remote_rref_udf_invocation
 class SyncDataCollector(DataCollectorBase):
     """Generic data collector for RL problems. Requires an environment constructor and a policy.
@@ -480,6 +591,9 @@ class SyncDataCollector(DataCollectorBase):
 
     """
 
+    def __new__(cls, *args, **kwargs):
+        return DataCollectorBase.__new__(cls)
+
     def __init__(
         self,
         create_env_fn: Union[
@@ -1098,6 +1212,9 @@ class _MultiDataCollector(DataCollectorBase):
             that will be allowed to finished collecting their rollout before the rest are forced to end early.
     """
 
+    def __new__(cls, *args, **kwargs):
+        return DataCollectorBase.__new__(cls)
+
     def __init__(
         self,
         create_env_fn: Sequence[Callable[[], EnvBase]],
@@ -1590,6 +1707,20 @@ class MultiSyncDataCollector(_MultiDataCollector):
 
     __doc__ += _MultiDataCollector.__doc__
 
+    def __init__(self, *args, cat_dim=None, **kwargs):
+        if cat_dim is None:
+            warnings.warn(
+                "The cat_dim argument wasn't passed to the sync data collector. "
+                "The previous default value was 0 and will remain so "
+                "until v0.3. From v0.3, the default will be -1 (the time dimension). In v0.4, "
+                "this argument will likely be deprecated. "
+ "To remove this warning, set the cat_dim to -1.", + category=DeprecationWarning, + ) + cat_dim = 0 + self._cat_dim = cat_dim + super().__init__(*args, **kwargs) + # for RPC def next(self): return super().next() @@ -1672,7 +1803,7 @@ def iterator(self) -> Iterator[TensorDictBase]: while self.queue_out.qsize() < int(self.num_workers): continue - for _ in range(self.num_workers): + for idx in range(self.num_workers): new_data, j = self.queue_out.get() if j == 0: data, idx = new_data @@ -1683,14 +1814,16 @@ def iterator(self) -> Iterator[TensorDictBase]: if workers_frames[idx] >= self.total_frames: dones[idx] = True - # we have to correct the traj_ids to make sure that they don't overlap - for idx in range(self.num_workers): + traj_ids = self.buffers[idx].get(("collector", "traj_ids")) - if max_traj_idx is not None: - traj_ids[traj_ids != -1] += max_traj_idx - # out_tensordicts_shared[idx].set("traj_ids", traj_ids) - max_traj_idx = traj_ids.max().item() + 1 - # out = out_tensordicts_shared[idx] + preempt_mask = traj_ids != -1 + if preempt_mask.all(): + if max_traj_idx is not None: + traj_ids += max_traj_idx + else: + if max_traj_idx is not None: + traj_ids[traj_ids != -1] += max_traj_idx + max_traj_idx = traj_ids.max().item() if same_device is None: prev_device = None same_device = True @@ -1702,12 +1835,12 @@ def iterator(self) -> Iterator[TensorDictBase]: if same_device: self.out_buffer = torch.cat( - list(self.buffers.values()), 0, out=self.out_buffer + list(self.buffers.values()), self._cat_dim, out=self.out_buffer ) else: self.out_buffer = torch.cat( [item.cpu() for item in self.buffers.values()], - 0, + self._cat_dim, out=self.out_buffer, ) @@ -1715,7 +1848,45 @@ def iterator(self) -> Iterator[TensorDictBase]: out = split_trajectories(self.out_buffer, prefix="collector") frames += out.get(("collector", "mask")).sum().item() else: - out = self.out_buffer.clone() + traj_ids = self.out_buffer.get(("collector", "traj_ids")) + cat_dim = self._cat_dim + if cat_dim < 0: + cat_dim = self.out_buffer.ndim + cat_dim + truncated = None + if cat_dim == self.out_buffer.ndim - 1: + idx = (slice(None),) * (self.out_buffer.ndim - 1) + truncated = ( + traj_ids[idx + (slice(None, -1),)] + != traj_ids[idx + (slice(1),)] + ) + truncated = torch.cat( + [ + truncated, + torch.ones_like(truncated[idx + (slice(-1, None),)]), + ], + self._cat_dim, + ) + valid_mask = traj_ids != -1 + shape = tuple( + s if i != cat_dim else -1 + for i, s in enumerate(self.out_buffer.shape) + ) + if not valid_mask.all(): + out = self.out_buffer[valid_mask] + if truncated is not None: + out.set( + ("next", "truncated"), truncated[valid_mask].unsqueeze(-1) + ) + out = out.reshape(shape) + out.names = self.out_buffer.names + else: + out = self.out_buffer.clone() + if truncated is not None: + out.set( + ("next", "truncated"), + truncated[valid_mask].reshape(*shape, 1), + ) + frames += prod(out.shape) if self.postprocs: self.postprocs = self.postprocs.to(out.device)