Skip to content

Commit dd8547e

Browse files
committed
Allow for specifying a tensor device in AnnLoader
1 parent 3e340e1 commit dd8547e

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

anndata/experimental/pytorch/_annloader.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,17 @@ def __len__(self):
5454
return length
5555

5656

57-
# maybe replace use_cuda with explicit device option
58-
def default_converter(arr, use_cuda, pin_memory):
57+
def default_converter(arr, device, pin_memory):
5958
if isinstance(arr, torch.Tensor):
60-
if use_cuda:
61-
arr = arr.cuda()
59+
if device != "cpu":
60+
arr = arr.to(device)
6261
elif pin_memory:
6362
arr = arr.pin_memory()
6463
elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
6564
if issparse(arr):
6665
arr = arr.toarray()
67-
if use_cuda:
68-
arr = torch.tensor(arr, device="cuda")
66+
if device != "cpu":
67+
arr = torch.tensor(arr, device=device)
6968
else:
7069
arr = torch.tensor(arr)
7170
arr = arr.pin_memory() if pin_memory else arr
@@ -114,12 +113,15 @@ class AnnLoader(DataLoader):
114113
Set to `True` to have the data reshuffled at every epoch.
115114
use_default_converter
116115
Use the default converter to convert arrays to pytorch tensors, transfer to
117-
the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`).
116+
the specified device (if `device!=None`), do memory pinning (if `pin_memory=True`).
118117
If you pass an AnnCollection object with prespecified converters, the default converter
119118
won't overwrite these converters but will be applied on top of them.
120119
use_cuda
121120
Transfer pytorch tensors to the default cuda device after conversion.
122-
Only works if `use_default_converter=True`
121+
Only works if `use_default_converter=True`. DEPRECATED in favour of `device`.
122+
device
123+
The device to which to transfer pytorch tensors after conversion (example: "cuda").
124+
Only works if `use_default_converter=True`.
123125
**kwargs
124126
Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
125127
arguments for `AnnCollection` initialization.
@@ -132,8 +134,16 @@ def __init__(
132134
shuffle: bool = False,
133135
use_default_converter: bool = True,
134136
use_cuda: bool = False,
137+
device: str | None = None,
135138
**kwargs,
136139
):
140+
if use_cuda:
141+
warn(
142+
"Argument use_cuda has been deprecated in favour of `device`. ",
143+
FutureWarning,
144+
)
145+
device = "cuda"
146+
137147
if isinstance(adatas, AnnData):
138148
adatas = [adatas]
139149

@@ -171,7 +181,7 @@ def __init__(
171181
if use_default_converter:
172182
pin_memory = kwargs.pop("pin_memory", False)
173183
_converter = partial(
174-
default_converter, use_cuda=use_cuda, pin_memory=pin_memory
184+
default_converter, device=device, pin_memory=pin_memory
175185
)
176186
dataset.convert = _convert_on_top(
177187
dataset.convert, _converter, dict(dataset.attrs_keys, X=[])

0 commit comments

Comments
 (0)