Skip to content

Commit 705ecc2

Browse files
author
Vincent Moens
committed
[BugFix] Fix get_default_device calls in older PT versions
ghstack-source-id: fd3a739 Pull Request resolved: #2586
1 parent 236d38f commit 705ecc2

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

benchmarks/test_objectives_benchmarks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353
@pytest.fixture(scope="module", autouse=True)
5454
def set_default_device():
55-
cur_device = torch.get_default_device()
55+
cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
5656
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
5757
torch.set_default_device(device)
5858
yield

torchrl/envs/libs/vmas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,9 @@ def _build_env(
803803
num_envs=num_envs,
804804
device=self.device
805805
if self.device is not None
806-
else torch.get_default_device(),
806+
else getattr(
807+
torch, "get_default_device", lambda: torch.device("cpu")
808+
)(),
807809
continuous_actions=continuous_actions,
808810
max_steps=max_steps,
809811
seed=seed,

torchrl/objectives/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def _get_default_device(net):
596596
for p in net.parameters():
597597
return p.device
598598
else:
599-
return torch.get_default_device()
599+
return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
600600

601601

602602
def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:

torchrl/objectives/value/advantages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(
219219
):
220220
super().__init__()
221221
if device is None:
222-
device = torch.get_default_device()
222+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
223223
# this is saved for tracking only and should not be used to cast anything else than buffers during
224224
# init.
225225
self._device = device

0 commit comments

Comments
 (0)