@@ -8783,6 +8783,7 @@ def test_ppo(
             value,
             loss_critic_type="l2",
             functional=functional,
+            device=device,
         )
         if composite_action_dist:
             loss_fn.set_keys(
@@ -8883,6 +8884,7 @@ def test_ppo_composite_no_aggregate(
             value,
             loss_critic_type="l2",
             functional=functional,
+            device=device,
         )
         loss_fn.set_keys(
             action=("action", "action1"),
@@ -8943,9 +8945,19 @@ def test_ppo_state_dict(
             device=device, composite_action_dist=composite_action_dist
         )
         value = self._create_mock_value(device=device)
-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )
         sd = loss_fn.state_dict()
-        loss_fn2 = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn2 = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )
         loss_fn2.load_state_dict(sd)

     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -8993,6 +9005,7 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
             value,
             loss_critic_type="l2",
             separate_losses=True,
+            device=device,
         )

         if advantage is not None:
@@ -9100,6 +9113,7 @@ def test_ppo_shared_seq(
             loss_critic_type="l2",
             separate_losses=separate_losses,
             entropy_coef=0.0,
+            device=device,
         )

         loss_fn2 = loss_class(
@@ -9108,6 +9122,7 @@ def test_ppo_shared_seq(
             loss_critic_type="l2",
             separate_losses=separate_losses,
             entropy_coef=0.0,
+            device=device,
         )

         if advantage is not None:
@@ -9202,7 +9217,12 @@ def test_ppo_diff(
         else:
             raise NotImplementedError

-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )

         params = TensorDict.from_module(loss_fn, as_module=True)
@@ -9595,6 +9615,7 @@ def test_ppo_value_clipping(
                     value,
                     loss_critic_type="l2",
                     clip_value=clip_value,
+                    device=device,
                 )

         else:
@@ -9603,6 +9624,7 @@ def test_ppo_value_clipping(
                 value,
                 loss_critic_type="l2",
                 clip_value=clip_value,
+                device=device,
             )
             advantage(td)
             if composite_action_dist: