4
4
from pathlib import Path
5
5
import numpy as np
6
6
import torch
7
+ import re
7
8
8
9
9
10
def set_tensor_type (device = None , float_bits = 32 ):
10
11
"""Set the default torch tensor type to be used with neurodiffeq.
11
12
12
- :param device: Either "cpu" or "cuda" ("gpu"); defaults to "cuda" if available.
13
+ :param device: Either "cpu", "cuda" or "cuda:x " ("gpu") where "x" is the device number ; defaults to "cuda" if available.
13
14
:type device: str
14
15
:param float_bits: Length of float numbers. Either 32 (float) or 64 (double); defaults to 32.
15
16
:type float_bits: int
@@ -21,22 +22,23 @@ def set_tensor_type(device=None, float_bits=32):
21
22
if not isinstance (float_bits , int ):
22
23
raise ValueError (f"float_bits must be int, got { type (float_bits )} " )
23
24
if float_bits == 32 :
24
- tensor_str = "FloatTensor"
25
+ torch . set_default_dtype ( torch . float32 )
25
26
elif float_bits == 64 :
26
- tensor_str = "DoubleTensor"
27
+ torch . set_default_dtype ( torch . float64 )
27
28
else :
28
29
raise ValueError (f"float_bits must be 32 or 64, got { float_bits } " )
29
30
30
31
if device is None :
31
- device = "cuda" if torch .cuda .is_available () else "cpu"
32
- if device == "cpu" :
33
- type_string = f"torch.{ tensor_str } "
34
- elif device in ["cuda" , "gpu" ]:
35
- type_string = f"torch.cuda.{ tensor_str } "
36
- else :
37
- raise ValueError (f"Unknown device '{ device } '; device must be either 'cuda' or 'cpu'" )
38
-
39
- torch .set_default_tensor_type (type_string )
32
+ if torch .cuda .is_available ():
33
+ device = "cuda"
34
+ else :
35
+ device = "cpu"
36
+
37
+ cuda_regex = re .compile (r'cuda(?::\d+)?' )
38
+ if device != "cpu" and not cuda_regex .match (device ):
39
+ raise ValueError (f"Unknown device '{ device } '; device must be either 'cuda', 'cuda:x' where x is the device number, 'cpu'" )
40
+
41
+ torch .set_default_device (device )
40
42
41
43
42
44
def safe_mkdir (path ):
0 commit comments