Skip to content

Commit d561115

Browse files
louisfauryLouis Faury
andauthored
[BugFix] GAE warning when gamma/lmbda are tensors (#2838)
Co-authored-by: Louis Faury <louis.faury@helsing.ai>
1 parent 73c7b0a commit d561115

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

torchrl/objectives/value/advantages.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,8 +1281,18 @@ def __init__(
12811281
skip_existing=skip_existing,
12821282
device=device,
12831283
)
1284-
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1285-
self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
1284+
self.register_buffer(
1285+
"gamma",
1286+
gamma.to(self._device)
1287+
if isinstance(gamma, Tensor)
1288+
else torch.tensor(gamma, device=self._device),
1289+
)
1290+
self.register_buffer(
1291+
"lmbda",
1292+
lmbda.to(self._device)
1293+
if isinstance(lmbda, Tensor)
1294+
else torch.tensor(lmbda, device=self._device),
1295+
)
12861296
self.average_gae = average_gae
12871297
self.vectorized = vectorized
12881298
self.time_dim = time_dim

0 commit comments

Comments
 (0)