4
4
from functools import partial
5
5
from math import ceil
6
6
from typing import TYPE_CHECKING
7
+ from warnings import warn
7
8
8
9
import numpy as np
9
10
from scipy .sparse import issparse
@@ -54,18 +55,17 @@ def __len__(self):
54
55
return length
55
56
56
57
57
- # maybe replace use_cuda with explicit device option
58
- def default_converter (arr , use_cuda , pin_memory ):
58
+ def default_converter (arr , device , pin_memory ):
59
59
if isinstance (arr , torch .Tensor ):
60
- if use_cuda :
61
- arr = arr .cuda ( )
60
+ if device != "cpu" :
61
+ arr = arr .to ( device )
62
62
elif pin_memory :
63
63
arr = arr .pin_memory ()
64
64
elif arr .dtype .name != "category" and np .issubdtype (arr .dtype , np .number ):
65
65
if issparse (arr ):
66
66
arr = arr .toarray ()
67
- if use_cuda :
68
- arr = torch .tensor (arr , device = "cuda" )
67
+ if device != "cpu" :
68
+ arr = torch .tensor (arr , device = device )
69
69
else :
70
70
arr = torch .tensor (arr )
71
71
arr = arr .pin_memory () if pin_memory else arr
@@ -114,12 +114,15 @@ class AnnLoader(DataLoader):
114
114
Set to `True` to have the data reshuffled at every epoch.
115
115
use_default_converter
116
116
Use the default converter to convert arrays to pytorch tensors, transfer to
117
- the default cuda device (if `use_cuda=True `), do memory pinning (if `pin_memory=True`).
117
+ the specified device (if `device!=None `), do memory pinning (if `pin_memory=True`).
118
118
If you pass an AnnCollection object with prespecified converters, the default converter
119
119
won't overwrite these converters but will be applied on top of them.
120
120
use_cuda
121
121
Transfer pytorch tensors to the default cuda device after conversion.
122
- Only works if `use_default_converter=True`
122
+ Only works if `use_default_converter=True`. DEPRECATED in favour of `device`.
123
+ device
124
+ The device to which to transfer pytorch tensors after conversion (example: "cuda").
125
+ Only works if `use_default_converter=True`.
123
126
**kwargs
124
127
Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
125
128
arguments for `AnnCollection` initialization.
@@ -132,8 +135,16 @@ def __init__(
132
135
shuffle : bool = False ,
133
136
use_default_converter : bool = True ,
134
137
use_cuda : bool = False ,
138
+ device : str | None = None ,
135
139
** kwargs ,
136
140
):
141
+ if use_cuda :
142
+ warn (
143
+ "Argument use_cuda has been deprecated in favour of `device`. " ,
144
+ FutureWarning ,
145
+ )
146
+ device = "cuda"
147
+
137
148
if isinstance (adatas , AnnData ):
138
149
adatas = [adatas ]
139
150
@@ -171,7 +182,7 @@ def __init__(
171
182
if use_default_converter :
172
183
pin_memory = kwargs .pop ("pin_memory" , False )
173
184
_converter = partial (
174
- default_converter , use_cuda = use_cuda , pin_memory = pin_memory
185
+ default_converter , device = device , pin_memory = pin_memory
175
186
)
176
187
dataset .convert = _convert_on_top (
177
188
dataset .convert , _converter , dict (dataset .attrs_keys , X = [])
0 commit comments