We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c797855 commit 0436851Copy full SHA for 0436851
torchrl/objectives/value/advantages.py
@@ -1284,8 +1284,18 @@ def __init__(
1284
skip_existing=skip_existing,
1285
device=device,
1286
)
1287
- self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1288
- self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
+ self.register_buffer(
+ "gamma",
1289
+ gamma.to(self._device)
1290
+ if isinstance(gamma, Tensor)
1291
+ else torch.tensor(gamma, device=self._device),
1292
+ )
1293
1294
+ "lmbda",
1295
+ lmbda.to(self._device)
1296
+ if isinstance(lmbda, Tensor)
1297
+ else torch.tensor(lmbda, device=self._device),
1298
1299
self.average_gae = average_gae
1300
self.vectorized = vectorized
1301
self.time_dim = time_dim
0 commit comments