From 5636edf88e5c060dc341615ed248091f2efc03dd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 10 Jan 2025 09:17:08 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/objectives/cql.py | 2 +- torchrl/objectives/crossq.py | 2 +- torchrl/objectives/decision_transformer.py | 2 +- torchrl/objectives/deprecated.py | 2 +- torchrl/objectives/ppo.py | 2 +- torchrl/objectives/redq.py | 2 +- torchrl/objectives/sac.py | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 6e056589a8c..894c8db5212 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -323,7 +323,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 22e84673641..8bd37f38c39 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -306,7 +306,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a0d193acbfc..16e7b5212a1 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -103,7 +103,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 2a4124c80de..d4df68c6cb6 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -195,7 +195,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device)) self.register_buffer( diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 079a1efa92c..ad204bfe3e5 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -375,7 +375,7 @@ def __init__( try: device = next(self.parameters()).device except (AttributeError, StopIteration): - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) if critic_coef is not None: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index e234df1a512..68eafb834e6 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -309,7 +309,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) self.register_buffer( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index eae6b7feb34..66431b9c9a5 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -383,7 +383,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 @@ -1102,7 +1102,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha):