Skip to content

Commit 6ca216e

Browse files
Authored and committed by Vincent Moens
[Test] Fix device in PPO tests
ghstack-source-id: efe21e7
Pull-Request-resolved: #2971
1 parent 96e7a9c commit 6ca216e

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

test/test_cost.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8783,6 +8783,7 @@ def test_ppo(
87838783
value,
87848784
loss_critic_type="l2",
87858785
functional=functional,
8786+
device=device,
87868787
)
87878788
if composite_action_dist:
87888789
loss_fn.set_keys(
@@ -8883,6 +8884,7 @@ def test_ppo_composite_no_aggregate(
88838884
value,
88848885
loss_critic_type="l2",
88858886
functional=functional,
8887+
device=device,
88868888
)
88878889
loss_fn.set_keys(
88888890
action=("action", "action1"),
@@ -8943,9 +8945,19 @@ def test_ppo_state_dict(
89438945
device=device, composite_action_dist=composite_action_dist
89448946
)
89458947
value = self._create_mock_value(device=device)
8946-
loss_fn = loss_class(actor, value, loss_critic_type="l2")
8948+
loss_fn = loss_class(
8949+
actor,
8950+
value,
8951+
loss_critic_type="l2",
8952+
device=device,
8953+
)
89478954
sd = loss_fn.state_dict()
8948-
loss_fn2 = loss_class(actor, value, loss_critic_type="l2")
8955+
loss_fn2 = loss_class(
8956+
actor,
8957+
value,
8958+
loss_critic_type="l2",
8959+
device=device,
8960+
)
89498961
loss_fn2.load_state_dict(sd)
89508962

89518963
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -8993,6 +9005,7 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
89939005
value,
89949006
loss_critic_type="l2",
89959007
separate_losses=True,
9008+
device=device,
89969009
)
89979010

89989011
if advantage is not None:
@@ -9100,6 +9113,7 @@ def test_ppo_shared_seq(
91009113
loss_critic_type="l2",
91019114
separate_losses=separate_losses,
91029115
entropy_coef=0.0,
9116+
device=device,
91039117
)
91049118

91059119
loss_fn2 = loss_class(
@@ -9108,6 +9122,7 @@ def test_ppo_shared_seq(
91089122
loss_critic_type="l2",
91099123
separate_losses=separate_losses,
91109124
entropy_coef=0.0,
9125+
device=device,
91119126
)
91129127

91139128
if advantage is not None:
@@ -9202,7 +9217,12 @@ def test_ppo_diff(
92029217
else:
92039218
raise NotImplementedError
92049219

9205-
loss_fn = loss_class(actor, value, loss_critic_type="l2")
9220+
loss_fn = loss_class(
9221+
actor,
9222+
value,
9223+
loss_critic_type="l2",
9224+
device=device,
9225+
)
92069226

92079227
params = TensorDict.from_module(loss_fn, as_module=True)
92089228

@@ -9595,6 +9615,7 @@ def test_ppo_value_clipping(
95959615
value,
95969616
loss_critic_type="l2",
95979617
clip_value=clip_value,
9618+
device=device,
95989619
)
95999620

96009621
else:
@@ -9603,6 +9624,7 @@ def test_ppo_value_clipping(
96039624
value,
96049625
loss_critic_type="l2",
96059626
clip_value=clip_value,
9627+
device=device,
96069628
)
96079629
advantage(td)
96089630
if composite_action_dist:

torchrl/objectives/ppo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,9 @@ def __init__(
440440
raise ValueError(
441441
f"clip_value must be a float or a scalar tensor, got {clip_value}."
442442
)
443-
self.register_buffer("clip_value", clip_value)
443+
self.register_buffer("clip_value", clip_value.to(device))
444+
else:
445+
self.clip_value = None
444446
try:
445447
log_prob_keys = self.actor_network.log_prob_keys
446448
action_keys = self.actor_network.dist_sample_keys

0 commit comments

Comments (0)