We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 73c7b0a commit d561115Copy full SHA for d561115
torchrl/objectives/value/advantages.py
@@ -1281,8 +1281,18 @@ def __init__(
1281
skip_existing=skip_existing,
1282
device=device,
1283
)
1284
- self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1285
- self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
+ self.register_buffer(
+ "gamma",
1286
+ gamma.to(self._device)
1287
+ if isinstance(gamma, Tensor)
1288
+ else torch.tensor(gamma, device=self._device),
1289
+ )
1290
1291
+ "lmbda",
1292
+ lmbda.to(self._device)
1293
+ if isinstance(lmbda, Tensor)
1294
+ else torch.tensor(lmbda, device=self._device),
1295
1296
self.average_gae = average_gae
1297
self.vectorized = vectorized
1298
self.time_dim = time_dim
0 commit comments