We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 08c0570 commit 6cee33eCopy full SHA for 6cee33e
torchrl/objectives/ppo.py
@@ -80,7 +80,9 @@ def __init__(
80
)
81
self.register_buffer("gamma", torch.tensor(gamma, device=self.device))
82
self.loss_critic_type = loss_critic_type
83
- self.advantage_module = advantage_module.to(self.device)
+ self.advantage_module = advantage_module
84
+ if self.advantage_module is not None:
85
+ self.advantage_module = advantage_module.to(self.device)
86
87
def reset(self) -> None:
88
pass
0 commit comments