@@ -54,18 +54,17 @@ def __len__(self):
54
54
return length
55
55
56
56
57
- # maybe replace use_cuda with explicit device option
58
- def default_converter (arr , use_cuda , pin_memory ):
57
+ def default_converter (arr , device , pin_memory ):
59
58
if isinstance (arr , torch .Tensor ):
60
- if use_cuda :
61
- arr = arr .cuda ( )
59
+ if device != "cpu" :
60
+ arr = arr .to ( device )
62
61
elif pin_memory :
63
62
arr = arr .pin_memory ()
64
63
elif arr .dtype .name != "category" and np .issubdtype (arr .dtype , np .number ):
65
64
if issparse (arr ):
66
65
arr = arr .toarray ()
67
- if use_cuda :
68
- arr = torch .tensor (arr , device = "cuda" )
66
+ if device != "cpu" :
67
+ arr = torch .tensor (arr , device = device )
69
68
else :
70
69
arr = torch .tensor (arr )
71
70
arr = arr .pin_memory () if pin_memory else arr
@@ -114,12 +113,15 @@ class AnnLoader(DataLoader):
114
113
Set to `True` to have the data reshuffled at every epoch.
115
114
use_default_converter
116
115
Use the default converter to convert arrays to pytorch tensors, transfer to
117
- the default cuda device (if `use_cuda=True `), do memory pinning (if `pin_memory=True`).
116
+ the specified device (if `device!=None `), do memory pinning (if `pin_memory=True`).
118
117
If you pass an AnnCollection object with prespecified converters, the default converter
119
118
won't overwrite these converters but will be applied on top of them.
120
119
use_cuda
121
120
Transfer pytorch tensors to the default cuda device after conversion.
122
- Only works if `use_default_converter=True`
121
+ Only works if `use_default_converter=True`. DEPRECATED in favour of `device`.
122
+ device
123
+ The device to which to transfer pytorch tensors after conversion (example: "cuda").
124
+ Only works if `use_default_converter=True`.
123
125
**kwargs
124
126
Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
125
127
arguments for `AnnCollection` initialization.
@@ -132,8 +134,16 @@ def __init__(
132
134
shuffle : bool = False ,
133
135
use_default_converter : bool = True ,
134
136
use_cuda : bool = False ,
137
+ device : str | None = None ,
135
138
** kwargs ,
136
139
):
140
+ if use_cuda :
141
+ warn (
142
+ "Argument use_cuda has been deprecated in favour of `device`. " ,
143
+ FutureWarning ,
144
+ )
145
+ device = "cuda"
146
+
137
147
if isinstance (adatas , AnnData ):
138
148
adatas = [adatas ]
139
149
@@ -171,7 +181,7 @@ def __init__(
171
181
if use_default_converter :
172
182
pin_memory = kwargs .pop ("pin_memory" , False )
173
183
_converter = partial (
174
- default_converter , use_cuda = use_cuda , pin_memory = pin_memory
184
+ default_converter , device = device , pin_memory = pin_memory
175
185
)
176
186
dataset .convert = _convert_on_top (
177
187
dataset .convert , _converter , dict (dataset .attrs_keys , X = [])
0 commit comments