
Commit 74fdadc

[Feature] Simplifying collector envs (#870)

1 parent: b876ce6

File tree

13 files changed: +104 -133 lines


docs/source/reference/collectors.rst

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ avoid OOM errors. Finally, the choice of the batch size and passing device (ie t
 device where the data will be stored while waiting to be passed to the collection
 worker) may also impact the memory management. The key parameters to control are
 :obj:`devices` which controls the execution devices (ie the device of the policy)
-and :obj:`passing_devices` which will control the device where the environment and
+and :obj:`storing_devices` which will control the device where the environment and
 data are stored during a rollout. A good heuristic is usually to use the same device
 for storage and compute, which is the default behaviour when only the `devices` argument
 is being passed.
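
For reference, a minimal usage sketch of the renamed argument on a single-process collector, mirroring the test changes in this commit. This assumes torchrl's SyncDataCollector; `make_env` and `policy` are placeholders, not part of this commit.

    import torch
    from torchrl.collectors import SyncDataCollector

    # Sketch only: `make_env` and `policy` are hypothetical placeholders.
    collector = SyncDataCollector(
        make_env,
        policy=policy,
        frames_per_batch=200,
        max_frames_per_traj=2000,
        total_frames=20000,
        device="cuda:0",       # execution device (where the policy runs)
        storing_device="cpu",  # renamed from `passing_device`: where rollout data is stored
        pin_memory=False,
    )
    for batch in collector:
        # collected batches live on the storing device
        assert batch.device == torch.device("cpu")
        break
    collector.shutdown()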

docs/source/reference/trainers.rst

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ Trainer and hooks
     LogReward
     OptimizerHook
     Recorder
-    ReplayBuffer
+    ReplayBufferTrainer
     RewardNormalizer
     SelectKeys
     Trainer

test/test_collector.py

Lines changed: 23 additions & 23 deletions
@@ -131,9 +131,9 @@ def make_policy(env):


 def _is_consistent_device_type(
-    device_type, policy_device_type, passing_device_type, tensordict_device_type
+    device_type, policy_device_type, storing_device_type, tensordict_device_type
 ):
-    if passing_device_type is None:
+    if storing_device_type is None:
         if device_type is None:
             if policy_device_type is None:
                 return tensordict_device_type == "cpu"
@@ -142,7 +142,7 @@ def _is_consistent_device_type(

         return tensordict_device_type == device_type

-    return tensordict_device_type == passing_device_type
+    return tensordict_device_type == storing_device_type


 @pytest.mark.skipif(
@@ -152,12 +152,12 @@ def _is_consistent_device_type(
 @pytest.mark.parametrize("num_env", [1, 2])
 @pytest.mark.parametrize("device", ["cuda", "cpu", None])
 @pytest.mark.parametrize("policy_device", ["cuda", "cpu", None])
-@pytest.mark.parametrize("passing_device", ["cuda", "cpu", None])
+@pytest.mark.parametrize("storing_device", ["cuda", "cpu", None])
 def test_output_device_consistency(
-    num_env, device, policy_device, passing_device, seed=40
+    num_env, device, policy_device, storing_device, seed=40
 ):
     if (
-        device == "cuda" or policy_device == "cuda" or passing_device == "cuda"
+        device == "cuda" or policy_device == "cuda" or storing_device == "cuda"
     ) and not torch.cuda.is_available():
         pytest.skip("cuda is not available")

@@ -169,7 +169,7 @@ def test_output_device_consistency(

     _device = "cuda:0" if device == "cuda" else device
     _policy_device = "cuda:0" if policy_device == "cuda" else policy_device
-    _passing_device = "cuda:0" if passing_device == "cuda" else passing_device
+    _storing_device = "cuda:0" if storing_device == "cuda" else storing_device

     if num_env == 1:

@@ -201,12 +201,12 @@ def env_fn(seed):
         max_frames_per_traj=2000,
         total_frames=20000,
         device=_device,
-        passing_device=_passing_device,
+        storing_device=_storing_device,
         pin_memory=False,
     )
     for _, d in enumerate(collector):
         assert _is_consistent_device_type(
-            device, policy_device, passing_device, d.device.type
+            device, policy_device, storing_device, d.device.type
         )
         break

@@ -220,13 +220,13 @@ def env_fn(seed):
         max_frames_per_traj=2000,
         total_frames=20000,
         device=_device,
-        passing_device=_passing_device,
+        storing_device=_storing_device,
         pin_memory=False,
     )

     for _, d in enumerate(ccollector):
         assert _is_consistent_device_type(
-            device, policy_device, passing_device, d.device.type
+            device, policy_device, storing_device, d.device.type
         )
         break

@@ -833,7 +833,7 @@ def create_env():
         [create_env] * 3,
         policy=policy,
         devices=[torch.device("cuda:0")] * 3,
-        passing_devices=[torch.device("cuda:0")] * 3,
+        storing_devices=[torch.device("cuda:0")] * 3,
     )
     # collect state_dict
     state_dict = collector.state_dict()
@@ -1010,13 +1010,13 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe


 @pytest.mark.parametrize("device", ["cuda", "cpu"])
-@pytest.mark.parametrize("passing_device", ["cuda", "cpu"])
+@pytest.mark.parametrize("storing_device", ["cuda", "cpu"])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
-def test_collector_device_combinations(device, passing_device):
+def test_collector_device_combinations(device, storing_device):
     if (
         _os_is_windows
         and _python_is_3_10
-        and passing_device == "cuda"
+        and storing_device == "cuda"
         and device == "cuda"
     ):
         pytest.skip("Windows fatal exception: access violation in torch.storage")
@@ -1036,11 +1036,11 @@ def env_fn(seed):
         max_frames_per_traj=2000,
         total_frames=20000,
         device=device,
-        passing_device=passing_device,
+        storing_device=storing_device,
         pin_memory=False,
     )
     batch = next(collector.iterator())
-    assert batch.device == torch.device(passing_device)
+    assert batch.device == torch.device(storing_device)
     collector.shutdown()

     collector = MultiSyncDataCollector(
@@ -1057,13 +1057,13 @@ def env_fn(seed):
         devices=[
             device,
         ],
-        passing_devices=[
-            passing_device,
+        storing_devices=[
+            storing_device,
         ],
         pin_memory=False,
     )
     batch = next(collector.iterator())
-    assert batch.device == torch.device(passing_device)
+    assert batch.device == torch.device(storing_device)
     collector.shutdown()

     collector = MultiaSyncDataCollector(
@@ -1080,13 +1080,13 @@ def env_fn(seed):
         devices=[
             device,
         ],
-        passing_devices=[
-            passing_device,
+        storing_devices=[
+            storing_device,
         ],
         pin_memory=False,
     )
     batch = next(collector.iterator())
-    assert batch.device == torch.device(passing_device)
+    assert batch.device == torch.device(storing_device)
     collector.shutdown()

test/test_libs.py

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ def test_collector_run(self, env_lib, env_args, env_kwargs, device):
         reset_at_each_iter=False,
         split_trajs=True,
         devices=[device, device],
-        passing_devices=[device, device],
+        storing_devices=[device, device],
         update_at_each_batch=False,
         init_with_lag=False,
         exploration_mode="random",
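
The multi-process collectors take the plural arguments `devices`/`storing_devices`, one entry per worker, as the hunks above show. A minimal sketch of the renamed call, assuming torchrl's MultiSyncDataCollector; `env_fn` and `policy` are placeholders, not part of this commit.

    import torch
    from torchrl.collectors import MultiSyncDataCollector

    # Sketch only: `env_fn` and `policy` are hypothetical placeholders.
    collector = MultiSyncDataCollector(
        [env_fn, env_fn],
        policy=policy,
        frames_per_batch=200,
        max_frames_per_traj=2000,
        total_frames=20000,
        devices=["cuda:0", "cuda:0"],    # one execution device per worker
        storing_devices=["cpu", "cpu"],  # renamed from `passing_devices`
        pin_memory=False,
    )
    batch = next(collector.iterator())
    # batches are delivered on the storing device
    assert batch.device == torch.device("cpu")
    collector.shutdown()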
