Skip to content

Commit 1476c31

Browse files
authored
[Refactor] Refactor data collectors constructors (#970)
1 parent 878d023 commit 1476c31

File tree

4 files changed

+307
-198
lines changed

4 files changed

+307
-198
lines changed

.circleci/unittest/linux_examples/scripts/run_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
9999
record_frames=4 \
100100
lr_scheduler=
101101
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
102-
total_frames=48 \
102+
total_frames=200 \
103103
init_random_frames=10 \
104104
batch_size=10 \
105105
frames_per_batch=200 \
@@ -201,7 +201,7 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
201201
record_frames=4 \
202202
lr_scheduler=
203203
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
204-
total_frames=48 \
204+
total_frames=200 \
205205
init_random_frames=10 \
206206
batch_size=10 \
207207
frames_per_batch=200 \

test/test_collector.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def env_fn(seed):
204204
total_frames=20000,
205205
device=_device,
206206
storing_device=_storing_device,
207-
pin_memory=False,
208207
)
209208
for _, d in enumerate(collector):
210209
assert _is_consistent_device_type(
@@ -223,7 +222,6 @@ def env_fn(seed):
223222
total_frames=20000,
224223
device=_device,
225224
storing_device=_storing_device,
226-
pin_memory=False,
227225
)
228226

229227
for _, d in enumerate(ccollector):
@@ -265,7 +263,6 @@ def env_fn(seed):
265263
max_frames_per_traj=2000,
266264
total_frames=20000,
267265
device="cpu",
268-
pin_memory=False,
269266
)
270267
for i, d in enumerate(collector):
271268
if i == 0:
@@ -285,7 +282,6 @@ def env_fn(seed):
285282
frames_per_batch=20,
286283
max_frames_per_traj=2000,
287284
total_frames=20000,
288-
pin_memory=False,
289285
)
290286
for i, d in enumerate(ccollector):
291287
if i == 0:
@@ -314,7 +310,7 @@ def make_env():
314310
# env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4))
315311
env.set_seed(0)
316312
collector = SyncDataCollector(
317-
env, total_frames=10000, frames_per_batch=10000, split_trajs=False
313+
env, policy=None, total_frames=10000, frames_per_batch=10000, split_trajs=False
318314
)
319315
for _data in collector:
320316
continue
@@ -370,7 +366,6 @@ def make_env(seed):
370366
max_frames_per_traj=2000,
371367
total_frames=20000,
372368
device="cpu",
373-
pin_memory=False,
374369
reset_when_done=False,
375370
)
376371
for _, d in enumerate(collector): # noqa
@@ -420,7 +415,6 @@ def make_env(seed):
420415
max_frames_per_traj=2000,
421416
total_frames=20000,
422417
device="cpu",
423-
pin_memory=False,
424418
reset_when_done=True,
425419
split_trajs=True,
426420
)
@@ -460,7 +454,6 @@ def make_env(seed):
460454
# frames_per_batch=20,
461455
# max_frames_per_traj=2000,
462456
# total_frames=20000,
463-
# pin_memory=False,
464457
# )
465458
# for i, d in enumerate(ccollector):
466459
# if i == 0:
@@ -507,7 +500,6 @@ def env_fn():
507500
frames_per_batch=frames_per_batch,
508501
max_frames_per_traj=1000,
509502
total_frames=frames_per_batch * 100,
510-
pin_memory=False,
511503
)
512504
ccollector.set_seed(seed)
513505
for i, b in enumerate(ccollector):
@@ -522,7 +514,6 @@ def env_fn():
522514
frames_per_batch=frames_per_batch,
523515
max_frames_per_traj=1000,
524516
total_frames=frames_per_batch * 100,
525-
pin_memory=False,
526517
)
527518
ccollector.set_seed(seed)
528519
for i, b in enumerate(ccollector):
@@ -563,7 +554,6 @@ def env_fn():
563554
frames_per_batch=20,
564555
max_frames_per_traj=20,
565556
total_frames=300,
566-
pin_memory=False,
567557
)
568558
ccollector.set_seed(seed)
569559
for i, data in enumerate(ccollector):
@@ -627,7 +617,6 @@ def env_fn(seed):
627617
max_frames_per_traj=20,
628618
total_frames=200,
629619
device="cpu",
630-
pin_memory=False,
631620
)
632621
collector_iter = iter(collector)
633622
b1 = next(collector_iter)
@@ -683,9 +672,8 @@ def make_frames_per_batch(frames_per_batch):
683672
max_frames_per_traj=2000,
684673
total_frames=2 * num_env * max_frames_per_traj,
685674
device="cpu",
686-
seed=seed,
687-
pin_memory=False,
688675
)
676+
collector1.set_seed(seed)
689677
count = 0
690678
data1 = []
691679
for d in collector1:
@@ -708,9 +696,8 @@ def make_frames_per_batch(frames_per_batch):
708696
max_frames_per_traj=2000,
709697
total_frames=2 * num_env * max_frames_per_traj,
710698
device="cpu",
711-
seed=seed,
712-
pin_memory=False,
713699
)
700+
collector10.set_seed(seed)
714701
count = 0
715702
data10 = []
716703
for d in collector10:
@@ -733,9 +720,8 @@ def make_frames_per_batch(frames_per_batch):
733720
max_frames_per_traj=2000,
734721
total_frames=2 * num_env * max_frames_per_traj,
735722
device="cpu",
736-
seed=seed,
737-
pin_memory=False,
738723
)
724+
collector20.set_seed(seed)
739725
count = 0
740726
data20 = []
741727
for d in collector20:
@@ -902,6 +888,7 @@ def make_env():
902888
"create_env_fn": make_env,
903889
"policy": policy_explore,
904890
"frames_per_batch": 30,
891+
"total_frames": -1,
905892
}
906893
if collector_class is not SyncDataCollector:
907894
collector_kwargs["create_env_fn"] = [
@@ -1045,7 +1032,6 @@ def env_fn(seed):
10451032
total_frames=20000,
10461033
device=device,
10471034
storing_device=storing_device,
1048-
pin_memory=False,
10491035
)
10501036
batch = next(collector.iterator())
10511037
assert batch.device == torch.device(storing_device)
@@ -1068,7 +1054,6 @@ def env_fn(seed):
10681054
storing_devices=[
10691055
storing_device,
10701056
],
1071-
pin_memory=False,
10721057
)
10731058
batch = next(collector.iterator())
10741059
assert batch.device == torch.device(storing_device)
@@ -1091,7 +1076,6 @@ def env_fn(seed):
10911076
storing_devices=[
10921077
storing_device,
10931078
],
1094-
pin_memory=False,
10951079
)
10961080
batch = next(collector.iterator())
10971081
assert batch.device == torch.device(storing_device)
@@ -1117,7 +1101,12 @@ def env_maker(self):
11171101
return lambda: GymEnv(PENDULUM_VERSIONED)
11181102

11191103
def _create_collector_kwargs(self, env_maker, collector_class, policy):
1120-
collector_kwargs = {"create_env_fn": env_maker, "policy": policy}
1104+
collector_kwargs = {
1105+
"create_env_fn": env_maker,
1106+
"policy": policy,
1107+
"frames_per_batch": 200,
1108+
"total_frames": -1,
1109+
}
11211110

11221111
if collector_class is not SyncDataCollector:
11231112
collector_kwargs["create_env_fn"] = [
@@ -1216,6 +1205,7 @@ def test_initial_obs_consistency(env_class, seed=1):
12161205
policy=policy,
12171206
frames_per_batch=((max_steps - 3) * 2 + 2) * num_envs, # at least two episodes
12181207
split_trajs=False,
1208+
total_frames=-1,
12191209
)
12201210
for _d in collector:
12211211
break

0 commit comments

Comments
 (0)