Commit 5b67dd3

Author: Vincent Moens
[Feature] Non-functional objectives (PPO, A2C, Reinforce) (#1804)
1 parent 6769fee · commit 5b67dd3

File tree

16 files changed: +433 -157 lines changed

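The change is uniform across the files below: the actor-critic loss constructors (A2CLoss, ClipPPOLoss, ReinforceLoss) now take actor_network and critic_network keyword arguments where they previously took actor and critic. A minimal, self-contained sketch of the new call signature; the toy observation/action dimensions and module wiring are illustrative assumptions, not part of this commit:

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss

# Toy setup: 4-dim observation, 2-dim continuous action.
policy_net = torch.nn.Sequential(torch.nn.Linear(4, 4), NormalParamExtractor())
policy_module = TensorDictModule(
    policy_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
actor = ProbabilisticActor(
    policy_module,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
critic = ValueOperator(torch.nn.Linear(4, 1), in_keys=["observation"])

# Before this commit: A2CLoss(actor=actor, critic=critic)
loss_module = A2CLoss(actor_network=actor, critic_network=critic)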

benchmarks/test_objectives_benchmarks.py

Lines changed: 3 additions & 3 deletions

@@ -548,7 +548,7 @@ def test_a2c_speed(
     actor(td.clone())
     critic(td.clone())

-    loss = A2CLoss(actor=actor, critic=critic)
+    loss = A2CLoss(actor_network=actor, critic_network=critic)
     advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
     advantage(td)
     loss(td)
@@ -605,7 +605,7 @@ def test_ppo_speed(
     actor(td.clone())
     critic(td.clone())

-    loss = ClipPPOLoss(actor=actor, critic=critic)
+    loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
     advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
     advantage(td)
     loss(td)
@@ -662,7 +662,7 @@ def test_reinforce_speed(
     actor(td.clone())
     critic(td.clone())

-    loss = ReinforceLoss(actor=actor, critic=critic)
+    loss = ReinforceLoss(actor_network=actor, critic_network=critic)
     advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
     advantage(td)
     loss(td)
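These speed tests exercise the full objective path. For orientation, a sketch of what that call pattern produces, reusing the actor, critic, and batch td from the benchmark context and assuming td already carries the keys the loss expects (observation, action, sample_log_prob, and the "next" sub-tensordict); the loss-key names below are the usual ones for ClipPPOLoss and are stated as assumptions:

from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)

advantage(td)         # writes "advantage" and "value_target" into td
loss_vals = loss(td)  # TensorDict of scalar losses
# Typical keys: "loss_objective", "loss_critic", "loss_entropy"
total = sum(v for k, v in loss_vals.items() if k.startswith("loss_"))
total.backward()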

examples/a2c/a2c_atari.py

Lines changed: 2 additions & 2 deletions

@@ -69,8 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_gae=True,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,

examples/a2c/a2c_mujoco.py

Lines changed: 2 additions & 2 deletions

@@ -63,8 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_gae=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,

examples/distributed/collectors/multi_nodes/ray_train.py

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@
 )
 loss_module = ClipPPOLoss(
     actor=policy_module,
-    critic=value_module,
+    critic_network=value_module,
     advantage_key="advantage",
     clip_epsilon=clip_epsilon,
     entropy_bonus=bool(entropy_eps),
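Note that in this file only the critic keyword is renamed; actor=policy_module is left as unchanged context, which suggests the older actor keyword was still accepted (presumably via a deprecation path) when this commit landed.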

examples/impala/impala_multi_node_ray.py

Lines changed: 2 additions & 2 deletions

@@ -114,8 +114,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_adv=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,

examples/impala/impala_multi_node_submitit.py

Lines changed: 2 additions & 2 deletions

@@ -106,8 +106,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_adv=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,

examples/impala/impala_single_node.py

Lines changed: 2 additions & 2 deletions

@@ -84,8 +84,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_adv=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,

examples/multiagent/mappo_ippo.py

Lines changed: 3 additions & 3 deletions

@@ -137,8 +137,8 @@ def train(cfg: "DictConfig"): # noqa: F821

     # Loss
     loss_module = ClipPPOLoss(
-        actor=policy,
-        critic=value_module,
+        actor_network=policy,
+        critic_network=value_module,
         clip_epsilon=cfg.loss.clip_epsilon,
         entropy_coef=cfg.loss.entropy_eps,
         normalize_advantage=False,
@@ -174,7 +174,7 @@ def train(cfg: "DictConfig"): # noqa: F821
         with torch.no_grad():
             loss_module.value_estimator(
                 tensordict_data,
-                params=loss_module.critic_params,
+                params=loss_module.critic_network_params,
                 target_params=loss_module.target_critic_params,
             )
         current_frames = tensordict_data.numel()
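The second hunk shows that the functional parameter containers follow the same rename: critic_params becomes critic_network_params, while the target-parameter attribute is untouched here. A sketch of the updated advantage computation, assuming loss_module and a collected batch tensordict_data as in the surrounding script:

import torch

with torch.no_grad():
    loss_module.value_estimator(
        tensordict_data,
        params=loss_module.critic_network_params,        # was: critic_params
        target_params=loss_module.target_critic_params,  # unchanged in this hunk
    )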

examples/ppo/ppo_atari.py

Lines changed: 2 additions & 2 deletions

@@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_gae=False,
     )
     loss_module = ClipPPOLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         clip_epsilon=cfg.loss.clip_epsilon,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,

examples/ppo/ppo_mujoco.py

Lines changed: 2 additions & 2 deletions

@@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821
     )

     loss_module = ClipPPOLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         clip_epsilon=cfg.loss.clip_epsilon,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
