@@ -204,7 +204,6 @@ def env_fn(seed):
204
204
total_frames = 20000 ,
205
205
device = _device ,
206
206
storing_device = _storing_device ,
207
- pin_memory = False ,
208
207
)
209
208
for _ , d in enumerate (collector ):
210
209
assert _is_consistent_device_type (
@@ -223,7 +222,6 @@ def env_fn(seed):
223
222
total_frames = 20000 ,
224
223
device = _device ,
225
224
storing_device = _storing_device ,
226
- pin_memory = False ,
227
225
)
228
226
229
227
for _ , d in enumerate (ccollector ):
@@ -265,7 +263,6 @@ def env_fn(seed):
265
263
max_frames_per_traj = 2000 ,
266
264
total_frames = 20000 ,
267
265
device = "cpu" ,
268
- pin_memory = False ,
269
266
)
270
267
for i , d in enumerate (collector ):
271
268
if i == 0 :
@@ -285,7 +282,6 @@ def env_fn(seed):
285
282
frames_per_batch = 20 ,
286
283
max_frames_per_traj = 2000 ,
287
284
total_frames = 20000 ,
288
- pin_memory = False ,
289
285
)
290
286
for i , d in enumerate (ccollector ):
291
287
if i == 0 :
@@ -314,7 +310,7 @@ def make_env():
314
310
# env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4))
315
311
env .set_seed (0 )
316
312
collector = SyncDataCollector (
317
- env , total_frames = 10000 , frames_per_batch = 10000 , split_trajs = False
313
+ env , policy = None , total_frames = 10000 , frames_per_batch = 10000 , split_trajs = False
318
314
)
319
315
for _data in collector :
320
316
continue
@@ -370,7 +366,6 @@ def make_env(seed):
370
366
max_frames_per_traj = 2000 ,
371
367
total_frames = 20000 ,
372
368
device = "cpu" ,
373
- pin_memory = False ,
374
369
reset_when_done = False ,
375
370
)
376
371
for _ , d in enumerate (collector ): # noqa
@@ -420,7 +415,6 @@ def make_env(seed):
420
415
max_frames_per_traj = 2000 ,
421
416
total_frames = 20000 ,
422
417
device = "cpu" ,
423
- pin_memory = False ,
424
418
reset_when_done = True ,
425
419
split_trajs = True ,
426
420
)
@@ -460,7 +454,6 @@ def make_env(seed):
460
454
# frames_per_batch=20,
461
455
# max_frames_per_traj=2000,
462
456
# total_frames=20000,
463
- # pin_memory=False,
464
457
# )
465
458
# for i, d in enumerate(ccollector):
466
459
# if i == 0:
@@ -507,7 +500,6 @@ def env_fn():
507
500
frames_per_batch = frames_per_batch ,
508
501
max_frames_per_traj = 1000 ,
509
502
total_frames = frames_per_batch * 100 ,
510
- pin_memory = False ,
511
503
)
512
504
ccollector .set_seed (seed )
513
505
for i , b in enumerate (ccollector ):
@@ -522,7 +514,6 @@ def env_fn():
522
514
frames_per_batch = frames_per_batch ,
523
515
max_frames_per_traj = 1000 ,
524
516
total_frames = frames_per_batch * 100 ,
525
- pin_memory = False ,
526
517
)
527
518
ccollector .set_seed (seed )
528
519
for i , b in enumerate (ccollector ):
@@ -563,7 +554,6 @@ def env_fn():
563
554
frames_per_batch = 20 ,
564
555
max_frames_per_traj = 20 ,
565
556
total_frames = 300 ,
566
- pin_memory = False ,
567
557
)
568
558
ccollector .set_seed (seed )
569
559
for i , data in enumerate (ccollector ):
@@ -627,7 +617,6 @@ def env_fn(seed):
627
617
max_frames_per_traj = 20 ,
628
618
total_frames = 200 ,
629
619
device = "cpu" ,
630
- pin_memory = False ,
631
620
)
632
621
collector_iter = iter (collector )
633
622
b1 = next (collector_iter )
@@ -683,9 +672,8 @@ def make_frames_per_batch(frames_per_batch):
683
672
max_frames_per_traj = 2000 ,
684
673
total_frames = 2 * num_env * max_frames_per_traj ,
685
674
device = "cpu" ,
686
- seed = seed ,
687
- pin_memory = False ,
688
675
)
676
+ collector1 .set_seed (seed )
689
677
count = 0
690
678
data1 = []
691
679
for d in collector1 :
@@ -708,9 +696,8 @@ def make_frames_per_batch(frames_per_batch):
708
696
max_frames_per_traj = 2000 ,
709
697
total_frames = 2 * num_env * max_frames_per_traj ,
710
698
device = "cpu" ,
711
- seed = seed ,
712
- pin_memory = False ,
713
699
)
700
+ collector10 .set_seed (seed )
714
701
count = 0
715
702
data10 = []
716
703
for d in collector10 :
@@ -733,9 +720,8 @@ def make_frames_per_batch(frames_per_batch):
733
720
max_frames_per_traj = 2000 ,
734
721
total_frames = 2 * num_env * max_frames_per_traj ,
735
722
device = "cpu" ,
736
- seed = seed ,
737
- pin_memory = False ,
738
723
)
724
+ collector20 .set_seed (seed )
739
725
count = 0
740
726
data20 = []
741
727
for d in collector20 :
@@ -902,6 +888,7 @@ def make_env():
902
888
"create_env_fn" : make_env ,
903
889
"policy" : policy_explore ,
904
890
"frames_per_batch" : 30 ,
891
+ "total_frames" : - 1 ,
905
892
}
906
893
if collector_class is not SyncDataCollector :
907
894
collector_kwargs ["create_env_fn" ] = [
@@ -1045,7 +1032,6 @@ def env_fn(seed):
1045
1032
total_frames = 20000 ,
1046
1033
device = device ,
1047
1034
storing_device = storing_device ,
1048
- pin_memory = False ,
1049
1035
)
1050
1036
batch = next (collector .iterator ())
1051
1037
assert batch .device == torch .device (storing_device )
@@ -1068,7 +1054,6 @@ def env_fn(seed):
1068
1054
storing_devices = [
1069
1055
storing_device ,
1070
1056
],
1071
- pin_memory = False ,
1072
1057
)
1073
1058
batch = next (collector .iterator ())
1074
1059
assert batch .device == torch .device (storing_device )
@@ -1091,7 +1076,6 @@ def env_fn(seed):
1091
1076
storing_devices = [
1092
1077
storing_device ,
1093
1078
],
1094
- pin_memory = False ,
1095
1079
)
1096
1080
batch = next (collector .iterator ())
1097
1081
assert batch .device == torch .device (storing_device )
@@ -1117,7 +1101,12 @@ def env_maker(self):
1117
1101
return lambda : GymEnv (PENDULUM_VERSIONED )
1118
1102
1119
1103
def _create_collector_kwargs (self , env_maker , collector_class , policy ):
1120
- collector_kwargs = {"create_env_fn" : env_maker , "policy" : policy }
1104
+ collector_kwargs = {
1105
+ "create_env_fn" : env_maker ,
1106
+ "policy" : policy ,
1107
+ "frames_per_batch" : 200 ,
1108
+ "total_frames" : - 1 ,
1109
+ }
1121
1110
1122
1111
if collector_class is not SyncDataCollector :
1123
1112
collector_kwargs ["create_env_fn" ] = [
@@ -1216,6 +1205,7 @@ def test_initial_obs_consistency(env_class, seed=1):
1216
1205
policy = policy ,
1217
1206
frames_per_batch = ((max_steps - 3 ) * 2 + 2 ) * num_envs , # at least two episodes
1218
1207
split_trajs = False ,
1208
+ total_frames = - 1 ,
1219
1209
)
1220
1210
for _d in collector :
1221
1211
break
0 commit comments