Skip to content

Commit f6495f5

Browse files
authored
Merge pull request #205 from NeuroDiffGym/sb/use_metal_gpus
Support for different GPU device
2 parents 7c108c3 + 3f3a385 commit f6495f5

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

neurodiffeq/utils.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from pathlib import Path
55
import numpy as np
66
import torch
7+
import re
78

89

910
def set_tensor_type(device=None, float_bits=32):
1011
"""Set the default torch tensor type to be used with neurodiffeq.
1112
12-
:param device: Either "cpu" or "cuda" ("gpu"); defaults to "cuda" if available.
13+
:param device: Either "cpu", "cuda" or "cuda:x" ("gpu") where "x" is the device number; defaults to "cuda" if available.
1314
:type device: str
1415
:param float_bits: Length of float numbers. Either 32 (float) or 64 (double); defaults to 32.
1516
:type float_bits: int
@@ -21,22 +22,23 @@ def set_tensor_type(device=None, float_bits=32):
2122
if not isinstance(float_bits, int):
2223
raise ValueError(f"float_bits must be int, got {type(float_bits)}")
2324
if float_bits == 32:
24-
tensor_str = "FloatTensor"
25+
torch.set_default_dtype(torch.float32)
2526
elif float_bits == 64:
26-
tensor_str = "DoubleTensor"
27+
torch.set_default_dtype(torch.float64)
2728
else:
2829
raise ValueError(f"float_bits must be 32 or 64, got {float_bits}")
2930

3031
if device is None:
31-
device = "cuda" if torch.cuda.is_available() else "cpu"
32-
if device == "cpu":
33-
type_string = f"torch.{tensor_str}"
34-
elif device in ["cuda", "gpu"]:
35-
type_string = f"torch.cuda.{tensor_str}"
36-
else:
37-
raise ValueError(f"Unknown device '{device}'; device must be either 'cuda' or 'cpu'")
38-
39-
torch.set_default_tensor_type(type_string)
32+
if torch.cuda.is_available():
33+
device = "cuda"
34+
else:
35+
device = "cpu"
36+
37+
cuda_regex = re.compile(r'cuda(?::\d+)?')
38+
if device != "cpu" and not cuda_regex.match(device):
39+
raise ValueError(f"Unknown device '{device}'; device must be either 'cuda', 'cuda:x' where x is the device number, 'cpu'")
40+
41+
torch.set_default_device(device)
4042

4143

4244
def safe_mkdir(path):

0 commit comments

Comments
 (0)