Skip to content

Commit 0436851

Browse files
louisfauryLouis Faury
authored andcommitted
[BugFix] GAE warning when gamma/lmbda are tensors (#2838)
Co-authored-by: Louis Faury <louis.faury@helsing.ai> (cherry picked from commit d561115)
1 parent c797855 commit 0436851

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

torchrl/objectives/value/advantages.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,8 +1284,18 @@ def __init__(
12841284
skip_existing=skip_existing,
12851285
device=device,
12861286
)
1287-
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1288-
self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
1287+
self.register_buffer(
1288+
"gamma",
1289+
gamma.to(self._device)
1290+
if isinstance(gamma, Tensor)
1291+
else torch.tensor(gamma, device=self._device),
1292+
)
1293+
self.register_buffer(
1294+
"lmbda",
1295+
lmbda.to(self._device)
1296+
if isinstance(lmbda, Tensor)
1297+
else torch.tensor(lmbda, device=self._device),
1298+
)
12891299
self.average_gae = average_gae
12901300
self.vectorized = vectorized
12911301
self.time_dim = time_dim

0 commit comments

Comments
 (0)