Skip to content

Commit b953702

Browse files
committed
Allow for specifying a tensor device in AnnLoader
1 parent 3e340e1 commit b953702

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

anndata/experimental/pytorch/_annloader.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import partial
55
from math import ceil
66
from typing import TYPE_CHECKING
7+
from warnings import warn
78

89
import numpy as np
910
from scipy.sparse import issparse
@@ -54,18 +55,17 @@ def __len__(self):
5455
return length
5556

5657

57-
# maybe replace use_cuda with explicit device option
58-
def default_converter(arr, use_cuda, pin_memory):
58+
def default_converter(arr, device, pin_memory):
5959
if isinstance(arr, torch.Tensor):
60-
if use_cuda:
61-
arr = arr.cuda()
60+
if device != "cpu":
61+
arr = arr.to(device)
6262
elif pin_memory:
6363
arr = arr.pin_memory()
6464
elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
6565
if issparse(arr):
6666
arr = arr.toarray()
67-
if use_cuda:
68-
arr = torch.tensor(arr, device="cuda")
67+
if device != "cpu":
68+
arr = torch.tensor(arr, device=device)
6969
else:
7070
arr = torch.tensor(arr)
7171
arr = arr.pin_memory() if pin_memory else arr
@@ -114,12 +114,15 @@ class AnnLoader(DataLoader):
114114
Set to `True` to have the data reshuffled at every epoch.
115115
use_default_converter
116116
Use the default converter to convert arrays to pytorch tensors, transfer to
117-
the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`).
117+
the specified device (if `device!=None`), do memory pinning (if `pin_memory=True`).
118118
If you pass an AnnCollection object with prespecified converters, the default converter
119119
won't overwrite these converters but will be applied on top of them.
120120
use_cuda
121121
Transfer pytorch tensors to the default cuda device after conversion.
122-
Only works if `use_default_converter=True`
122+
Only works if `use_default_converter=True`. DEPRECATED in favour of `device`.
123+
device
124+
The device to which to transfer pytorch tensors after conversion (example: "cuda").
125+
Only works if `use_default_converter=True`.
123126
**kwargs
124127
Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
125128
arguments for `AnnCollection` initialization.
@@ -132,8 +135,16 @@ def __init__(
132135
shuffle: bool = False,
133136
use_default_converter: bool = True,
134137
use_cuda: bool = False,
138+
device: str | None = None,
135139
**kwargs,
136140
):
141+
if use_cuda:
142+
warn(
143+
"Argument use_cuda has been deprecated in favour of `device`. ",
144+
FutureWarning,
145+
)
146+
device = "cuda"
147+
137148
if isinstance(adatas, AnnData):
138149
adatas = [adatas]
139150

@@ -171,7 +182,7 @@ def __init__(
171182
if use_default_converter:
172183
pin_memory = kwargs.pop("pin_memory", False)
173184
_converter = partial(
174-
default_converter, use_cuda=use_cuda, pin_memory=pin_memory
185+
default_converter, device=device, pin_memory=pin_memory
175186
)
176187
dataset.convert = _convert_on_top(
177188
dataset.convert, _converter, dict(dataset.attrs_keys, X=[])

0 commit comments

Comments
 (0)