Commit d30599e

Author: Vincent Moens (committed)
[BugFix] action_spec_unbatched whenever necessary
ghstack-source-id: ec87794
Pull Request resolved: #2592
1 parent: a47b32c

25 files changed: +191, −116 lines

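The gist of the fix: on a batched env (a SerialEnv or ParallelEnv, or any env with a non-empty batch_size), `action_spec` carries the env's batch dimensions, whereas `action_spec_unbatched` returns the spec of a single sub-env; bounds fed to a TanhNormal head should use the latter. A minimal sketch of the difference, assuming gymnasium's Pendulum-v1 is installed (not part of this commit):

from torchrl.envs import GymEnv, SerialEnv

# Two copies of Pendulum-v1 give the env a batch_size of (2,).
env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
print(env.action_spec.shape)            # torch.Size([2, 1]): batch dim included
print(env.action_spec_unbatched.shape)  # torch.Size([1]): per-env shape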

examples/distributed/collectors/multi_nodes/ray_train.py

Lines changed: 2 additions & 2 deletions

@@ -85,8 +85,8 @@
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "low": env.action_spec.space.low,
-        "high": env.action_spec.space.high,
+        "low": env.action_spec_unbatched.space.low,
+        "high": env.action_spec_unbatched.space.high,
     },
     return_log_prob=True,
 )

sota-implementations/a2c/utils_atari.py

Lines changed: 2 additions & 2 deletions

@@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low.to(device),
-        "high": proof_environment.action_spec.space.high.to(device),
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
     }
 
     # Define input keys

sota-implementations/a2c/utils_mujoco.py

Lines changed: 2 additions & 2 deletions

@@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low.to(device),
-        "high": proof_environment.action_spec.space.high.to(device),
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
         "tanh_loc": False,
         "safe_tanh": True,
     }

sota-implementations/cql/utils.py

Lines changed: 1 addition & 1 deletion

@@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg):
 def make_cql_model(cfg, train_env, eval_env, device="cpu"):
     model_cfg = cfg.model
 
-    action_spec = train_env.action_spec
+    action_spec = train_env.action_spec_unbatched
 
     actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
     in_keys = ["observation"]

sota-implementations/crossq/utils.py

Lines changed: 1 addition & 3 deletions

@@ -147,9 +147,7 @@ def make_crossQ_agent(cfg, train_env, device):
     """Make CrossQ agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     actor_net_kwargs = {
         "num_cells": cfg.network.actor_hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],

sota-implementations/decision_transformer/utils.py

Lines changed: 1 addition & 1 deletion

@@ -393,7 +393,7 @@ def make_dt_model(cfg):
         make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1
     )
 
-    action_spec = proof_environment.action_spec
+    action_spec = proof_environment.action_spec_unbatched
     for key, value in proof_environment.observation_spec.items():
         if key == "observation":
             state_dim = value.shape[-1]

sota-implementations/dreamer/dreamer.py

Lines changed: 9 additions & 1 deletion

@@ -20,7 +20,7 @@
 )
 
 # mixed precision training
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.utils import clip_grad_norm_
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type

@@ -321,6 +321,14 @@ def compile_rssms(module):
 
     t_collect_init = time.time()
 
+    test_env.close()
+    train_env.close()
+    collector.shutdown()
+
+    del test_env
+    del train_env
+    del collector
+
 
 if __name__ == "__main__":
     main()
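Two unrelated cleanups ride along in this file: the envs and collector are now explicitly closed at the end of training, and the import moves off the deprecated `torch.cuda.amp.GradScaler` alias. A minimal sketch of the replacement API (recent PyTorch; the device is now passed explicitly):

from torch.amp import GradScaler

# Same class as before; "cuda" is an argument instead of part of the import path.
scaler = GradScaler("cuda")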

sota-implementations/gail/ppo_utils.py

Lines changed: 2 additions & 2 deletions

@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low,
+        "high": proof_environment.action_spec_unbatched.space.high,
         "tanh_loc": False,
     }
 
sota-implementations/iql/utils.py

Lines changed: 1 addition & 3 deletions

@@ -195,9 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
     model_cfg = cfg.model
 
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)
 
     out_keys = ["loc", "scale"]

sota-implementations/multiagent/iql.py

Lines changed: 3 additions & 3 deletions

@@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821
     # Policy
     net = MultiAgentMLP(
         n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
-        n_agent_outputs=env.action_spec.space.n,
+        n_agent_outputs=env.full_action_spec["agents", "action"].space.n,
         n_agents=env.n_agents,
         centralised=False,
         share_params=cfg.model.shared_parameters,

@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)

@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
         eps_end=0,
         annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
         action_key=env.action_key,
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         ),
     )
 
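`unbatched_action_spec` was the older spelling; `full_action_spec_unbatched` makes explicit that it is the full composite action spec with the batch dims stripped, which can then be indexed per agent group as in the first hunk. A minimal sketch on a plain single-agent env (assumes gymnasium's CartPole-v1; the multi-agent script keys into ("agents", "action") instead of "action"):

from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))
print(env.full_action_spec.shape)            # torch.Size([2]): composite carries the batch dims
print(env.full_action_spec_unbatched.shape)  # torch.Size([]): per-env composite
print(env.full_action_spec_unbatched["action"].space.n)  # 2: number of discrete actions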