Skip to content

Commit 156a668

Browse files
author
Vincent Moens
authored
[Feature] serial_for_single arg in batched envs (#1846)
1 parent 9da61f2 commit 156a668

File tree

22 files changed

+150
-17
lines changed

22 files changed

+150
-17
lines changed

examples/a2c/utils_atari.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def make_base_env(
6161

6262
def make_parallel_env(env_name, num_envs, device, is_test=False):
6363
env = ParallelEnv(
64-
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device))
64+
num_envs,
65+
EnvCreator(lambda: make_base_env(env_name, device=device)),
66+
serial_for_single=True,
6567
)
6668
env = TransformedEnv(env)
6769
env.append_transform(ToTensorImage())

examples/cql/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
8080
parallel_env = ParallelEnv(
8181
train_num_envs,
8282
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
83+
serial_for_single=True,
8384
)
8485
parallel_env.set_seed(cfg.env.seed)
8586

@@ -89,6 +90,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
8990
ParallelEnv(
9091
eval_num_envs,
9192
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
93+
serial_for_single=True,
9294
),
9395
train_env.transform.clone(),
9496
)

examples/ddpg/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def make_environment(cfg):
7676
parallel_env = ParallelEnv(
7777
cfg.collector.env_per_collector,
7878
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
79+
serial_for_single=True,
7980
)
8081
parallel_env.set_seed(cfg.env.seed)
8182

@@ -87,6 +88,7 @@ def make_environment(cfg):
8788
ParallelEnv(
8889
cfg.collector.env_per_collector,
8990
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
91+
serial_for_single=True,
9092
),
9193
train_env.transform.clone(),
9294
)

examples/decision_transformer/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def make_env():
142142
return make_base_env(env_cfg)
143143

144144
env = make_transformed_env(
145-
ParallelEnv(num_envs, EnvCreator(make_env)),
145+
ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True),
146146
env_cfg,
147147
obs_loc,
148148
obs_std,

examples/discrete_sac/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def make_environment(cfg):
7777
parallel_env = ParallelEnv(
7878
cfg.collector.env_per_collector,
7979
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
80+
serial_for_single=True,
8081
)
8182
parallel_env.set_seed(cfg.env.seed)
8283

@@ -88,6 +89,7 @@ def make_environment(cfg):
8889
ParallelEnv(
8990
cfg.collector.env_per_collector,
9091
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
92+
serial_for_single=True,
9193
),
9294
train_env.transform.clone(),
9395
)

examples/distributed/collectors/single_machine/generic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ def gym_make():
100100
if args.worker_parallelism == "collector" or num_workers == 1:
101101
action_spec = make_env().action_spec
102102
else:
103-
make_env = ParallelEnv(num_workers, make_env)
103+
make_env = ParallelEnv(
104+
num_workers,
105+
make_env,
106+
serial_for_single=True,
107+
)
104108
action_spec = make_env.action_spec
105109

106110
if args.worker_parallelism == "collector" and num_workers > 1:

examples/distributed/collectors/single_machine/rpc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ def gym_make():
9696
if num_workers == 1:
9797
action_spec = make_env().action_spec
9898
else:
99-
make_env = ParallelEnv(num_workers, make_env)
99+
make_env = ParallelEnv(
100+
num_workers,
101+
make_env,
102+
serial_for_single=True,
103+
)
100104
action_spec = make_env.action_spec
101105

102106
collector = RPCDataCollector(

examples/distributed/collectors/single_machine/sync.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ def gym_make():
9595
if args.worker_parallelism == "collector" or num_workers == 1:
9696
action_spec = make_env().action_spec
9797
else:
98-
make_env = ParallelEnv(num_workers, make_env)
98+
make_env = ParallelEnv(
99+
num_workers,
100+
make_env,
101+
serial_for_single=True,
102+
)
99103
action_spec = make_env.action_spec
100104

101105
if args.worker_parallelism == "collector" and num_workers > 1:

examples/dreamer/dreamer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def parallel_env_constructor(
270270
create_env_kwargs=None,
271271
pin_memory=cfg.pin_memory,
272272
device=cfg.collector_device,
273+
serial_for_single=True,
273274
)
274275
if batch_transform:
275276
kwargs.update(

examples/iql/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
8484
parallel_env = ParallelEnv(
8585
train_num_envs,
8686
EnvCreator(lambda: env_maker(cfg)),
87+
serial_for_single=True,
8788
)
8889
parallel_env.set_seed(cfg.env.seed)
8990

@@ -93,6 +94,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
9394
ParallelEnv(
9495
eval_num_envs,
9596
EnvCreator(lambda: env_maker(cfg)),
97+
serial_for_single=True,
9698
),
9799
train_env.transform.clone(),
98100
)

0 commit comments

Comments
 (0)