Skip to content

Commit e353b20

Browse files
author
Vincent Moens
authored
[BugFix] Fix EXAMPLES.md (#1649)
1 parent 105e861 commit e353b20

File tree

29 files changed

+1538
-314
lines changed

29 files changed

+1538
-314
lines changed

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
7373
optim.batch_size=10 \
7474
collector.frames_per_batch=16 \
7575
collector.env_per_collector=2 \
76-
collector.collector_device=cuda:0 \
76+
collector.device=cuda:0 \
7777
network.device=cuda:0 \
7878
optim.utd_ratio=1 \
7979
replay_buffer.size=120 \
@@ -107,23 +107,24 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
107107
record_frames=4 \
108108
buffer_size=120
109109
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
110-
total_frames=48 \
111-
init_random_frames=10 \
112-
batch_size=10 \
113-
frames_per_batch=16 \
114110
num_workers=4 \
115-
env_per_collector=2 \
116-
collector_device=cuda:0 \
117-
optim_steps_per_batch=1 \
118-
record_video=True \
119-
record_frames=4 \
120-
buffer_size=120
111+
collector.total_frames=48 \
112+
collector.init_random_frames=10 \
113+
collector.frames_per_batch=16 \
114+
collector.env_per_collector=2 \
115+
collector.device=cuda:0 \
116+
buffer.batch_size=10 \
117+
optim.steps_per_batch=1 \
118+
logger.record_video=True \
119+
logger.record_frames=4 \
120+
buffer.size=120 \
121+
logger.backend=
121122
python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
122123
collector.total_frames=48 \
123124
collector.init_random_frames=10 \
124125
collector.frames_per_batch=16 \
125126
collector.env_per_collector=2 \
126-
collector.collector_device=cuda:0 \
127+
collector.device=cuda:0 \
127128
optim.batch_size=10 \
128129
optim.utd_ratio=1 \
129130
replay_buffer.size=120 \
@@ -152,21 +153,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
152153
collector.frames_per_batch=16 \
153154
collector.num_workers=4 \
154155
collector.env_per_collector=2 \
155-
collector.collector_device=cuda:0 \
156+
collector.device=cuda:0 \
157+
collector.device=cuda:0 \
156158
network.device=cuda:0 \
157159
logger.mode=offline \
158160
env.name=Pendulum-v1 \
159161
logger.backend=
160162
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
161-
total_frames=48 \
162-
batch_size=10 \
163-
frames_per_batch=16 \
164-
num_workers=4 \
165-
env_per_collector=2 \
166-
collector_device=cuda:0 \
167-
device=cuda:0 \
168-
mode=offline \
169-
logger=
163+
collector.total_frames=48 \
164+
buffer.batch_size=10 \
165+
collector.frames_per_batch=16 \
166+
collector.env_per_collector=2 \
167+
collector.device=cuda:0 \
168+
network.device=cuda:0 \
169+
logger.mode=offline \
170+
logger.backend=
170171

171172
# With single envs
172173
python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
@@ -188,7 +189,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
188189
optim.batch_size=10 \
189190
collector.frames_per_batch=16 \
190191
collector.env_per_collector=1 \
191-
collector.collector_device=cuda:0 \
192+
collector.device=cuda:0 \
192193
network.device=cuda:0 \
193194
optim.utd_ratio=1 \
194195
replay_buffer.size=120 \
@@ -209,23 +210,24 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
209210
record_frames=4 \
210211
buffer_size=120
211212
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
212-
total_frames=48 \
213-
init_random_frames=10 \
214-
batch_size=10 \
215-
frames_per_batch=16 \
216213
num_workers=2 \
217-
env_per_collector=1 \
218-
collector_device=cuda:0 \
219-
optim_steps_per_batch=1 \
220-
record_video=True \
221-
record_frames=4 \
222-
buffer_size=120
214+
collector.total_frames=48 \
215+
collector.init_random_frames=10 \
216+
collector.frames_per_batch=16 \
217+
collector.env_per_collector=1 \
218+
buffer.batch_size=10 \
219+
collector.device=cuda:0 \
220+
optim.steps_per_batch=1 \
221+
logger.record_video=True \
222+
logger.record_frames=4 \
223+
buffer.size=120 \
224+
logger.backend=
223225
python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
224226
collector.total_frames=48 \
225227
collector.init_random_frames=10 \
226228
collector.frames_per_batch=16 \
227229
collector.env_per_collector=1 \
228-
collector.collector_device=cuda:0 \
230+
collector.device=cuda:0 \
229231
optim.batch_size=10 \
230232
optim.utd_ratio=1 \
231233
network.device=cuda:0 \
@@ -235,24 +237,23 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
235237
env.name=Pendulum-v1 \
236238
logger.backend=
237239
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
238-
total_frames=48 \
239-
batch_size=10 \
240-
frames_per_batch=16 \
241-
num_workers=2 \
242-
env_per_collector=1 \
243-
mode=offline \
244-
device=cuda:0 \
245-
collector_device=cuda:0 \
246-
logger=
240+
collector.total_frames=48 \
241+
collector.frames_per_batch=16 \
242+
collector.env_per_collector=1 \
243+
collector.device=cuda:0 \
244+
network.device=cuda:0 \
245+
buffer.batch_size=10 \
246+
logger.mode=offline \
247+
logger.backend=
247248
python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
248249
collector.total_frames=48 \
249250
collector.init_random_frames=10 \
250-
optim.batch_size=10 \
251251
collector.frames_per_batch=16 \
252252
collector.num_workers=2 \
253253
collector.env_per_collector=1 \
254+
collector.device=cuda:0 \
254255
logger.mode=offline \
255-
collector.collector_device=cuda:0 \
256+
optim.batch_size=10 \
256257
env.name=Pendulum-v1 \
257258
logger.backend=
258259
python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \

examples/EXAMPLES.md

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ python sac.py
1818
```
1919
or similar. Hyperparameters can be easily changed by providing the arguments to hydra:
2020
```
21-
python sac.py frames_per_batch=63
21+
python sac.py collector.frames_per_batch=63
2222
```
2323
# Results
2424

@@ -32,11 +32,11 @@ We average the results over 5 different seeds and plot the standard error.
3232
To reproduce a single run:
3333

3434
```
35-
python sac/sac.py env_name="HalfCheetah-v4" env_task="" env_library="gym"
35+
python sac/sac.py env.name="HalfCheetah-v4" env.task="" env.library="gym"
3636
```
3737

3838
```
39-
python redq/redq.py env_name="HalfCheetah-v4" env_task="" env_library="gym"
39+
python redq/redq.py env.name="HalfCheetah-v4" env.library="gymnasium"
4040
```
4141

4242

@@ -48,39 +48,61 @@ python redq/redq.py env_name="HalfCheetah-v4" env_task="" env_library="gym"
4848
To reproduce a single run:
4949

5050
```
51-
python sac/sac.py env_name="cheetah" env_task="run" env_library="dm_control"
51+
python sac/sac.py env.name="cheetah" env.task="run" env.library="dm_control"
5252
```
5353

5454
```
55-
python redq/redq.py env_name="cheetah" env_task="run" env_library="dm_control"
55+
python redq/redq.py env.name="cheetah" env.task="run" env.library="dm_control"
5656
```
5757

58-
## Gym's Ant-v4
58+
[//]: # (TODO: adapt these scripts)
59+
[//]: # (## Gym's Ant-v4)
5960

60-
<p align="center">
61-
<img src="media/ant_chart.png" width="600px">
62-
</p>
63-
To reproduce a single run:
61+
[//]: # ()
62+
[//]: # (<p align="center">)
6463

65-
```
66-
python sac/sac.py env_name="Ant-v4" env_task="" env_library="gym"
67-
```
64+
[//]: # (<img src="media/ant_chart.png" width="600px">)
6865

69-
```
70-
python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym"
71-
```
66+
[//]: # (</p>)
7267

73-
## Gym's Walker2D-v4
68+
[//]: # (To reproduce a single run:)
7469

75-
<p align="center">
76-
<img src="media/walker2d_chart.png" width="600px">
77-
</p>
78-
To reproduce a single run:
70+
[//]: # ()
71+
[//]: # (```)
7972

80-
```
81-
python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym"
82-
```
73+
[//]: # (python sac/sac.py env.name="Ant-v4" env.task="" env.library="gym")
8374

84-
```
85-
python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym"
86-
```
75+
[//]: # (```)
76+
77+
[//]: # ()
78+
[//]: # (``` )
79+
80+
[//]: # (python redq/redq.py env.name="Ant-v4" env.task="" env.library="gym")
81+
82+
[//]: # (```)
83+
84+
[//]: # ()
85+
[//]: # (## Gym's Walker2d-v4)
86+
87+
[//]: # ()
88+
[//]: # (<p align="center">)
89+
90+
[//]: # (<img src="media/walker2d_chart.png" width="600px">)
91+
92+
[//]: # (</p>)
93+
94+
[//]: # (To reproduce a single run:)
95+
96+
[//]: # ()
97+
[//]: # (```)
98+
99+
[//]: # (python sac/sac.py env.name="Walker2d-v4" env.task="" env.library="gym")
100+
101+
[//]: # (```)
102+
103+
[//]: # ()
104+
[//]: # (``` )
105+
106+
[//]: # (python redq/redq.py env.name="Walker2d-v4" env.task="" env.library="gym")
107+
108+
[//]: # (```)

examples/cql/cql_offline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727

2828

29-
@hydra.main(config_path=".", config_name="offline_config")
29+
@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
3030
def main(cfg: "DictConfig"): # noqa: F821
3131
exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
3232
logger = None

examples/cql/cql_online.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828

2929

30-
@hydra.main(config_path=".", config_name="online_config")
30+
@hydra.main(version_base="1.1", config_path=".", config_name="online_config")
3131
def main(cfg: "DictConfig"): # noqa: F821
3232
exp_name = generate_exp_name("CQL-online", cfg.env.exp_name)
3333
logger = None

examples/cql/online_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ collector:
1818
multi_step: 0
1919
init_random_frames: 1000
2020
env_per_collector: 1
21-
collector_device: cpu
21+
device: cpu
2222
max_frames_per_traj: 200
2323

2424
# logger

examples/cql/utils.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
1313
from torchrl.data.replay_buffers import SamplerWithoutReplacement
1414
from torchrl.envs import (
15+
CatTensors,
1516
Compose,
17+
DMControlEnv,
1618
DoubleToFloat,
1719
EnvCreator,
1820
ParallelEnv,
1921
RewardScaling,
2022
TransformedEnv,
2123
)
22-
from torchrl.envs.libs.gym import GymEnv
24+
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
2325
from torchrl.envs.utils import ExplorationType, set_exploration_type
2426
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
2527
from torchrl.objectives import CQLLoss, SoftUpdate
@@ -32,8 +34,21 @@
3234
# -----------------
3335

3436

35-
def env_maker(task, frame_skip=1, device="cpu", from_pixels=False):
36-
return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels)
37+
def env_maker(cfg, device="cpu"):
38+
lib = cfg.env.library
39+
if lib in ("gym", "gymnasium"):
40+
with set_gym_backend(lib):
41+
return GymEnv(
42+
cfg.env.name,
43+
device=device,
44+
)
45+
elif lib == "dm_control":
46+
env = DMControlEnv(cfg.env.name, cfg.env.task)
47+
return TransformedEnv(
48+
env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
49+
)
50+
else:
51+
raise NotImplementedError(f"Unknown lib {lib}.")
3752

3853

3954
def apply_env_transforms(env, reward_scaling=1.0):
@@ -51,7 +66,7 @@ def make_environment(cfg, num_envs=1):
5166
"""Make environments for training and evaluation."""
5267
parallel_env = ParallelEnv(
5368
num_envs,
54-
EnvCreator(lambda: env_maker(task=cfg.env.name)),
69+
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
5570
)
5671
parallel_env.set_seed(cfg.env.seed)
5772

@@ -60,7 +75,7 @@ def make_environment(cfg, num_envs=1):
6075
eval_env = TransformedEnv(
6176
ParallelEnv(
6277
num_envs,
63-
EnvCreator(lambda: env_maker(task=cfg.env.name)),
78+
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
6479
),
6580
train_env.transform.clone(),
6681
)
@@ -80,7 +95,7 @@ def make_collector(cfg, train_env, actor_model_explore):
8095
frames_per_batch=cfg.collector.frames_per_batch,
8196
max_frames_per_traj=cfg.collector.max_frames_per_traj,
8297
total_frames=cfg.collector.total_frames,
83-
device=cfg.collector.collector_device,
98+
device=cfg.collector.device,
8499
)
85100
collector.set_seed(cfg.env.seed)
86101
return collector

examples/ddpg/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ collector:
1414
frames_per_batch: 1000
1515
init_env_steps: 1000
1616
reset_at_each_iter: False
17-
collector_device: cpu
17+
device: cpu
1818
env_per_collector: 1
1919

2020

0 commit comments

Comments
 (0)