Skip to content

Commit 91f66fb

Browse files
committed
Allow for specifying a tensor device in AnnLoader
1 parent b953702 commit 91f66fb

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

anndata/experimental/pytorch/_annloader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ def __len__(self):
5757

5858
def default_converter(arr, device, pin_memory):
5959
if isinstance(arr, torch.Tensor):
60-
if device != "cpu":
60+
if device is not None:
6161
arr = arr.to(device)
6262
elif pin_memory:
6363
arr = arr.pin_memory()
6464
elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
6565
if issparse(arr):
6666
arr = arr.toarray()
67-
if device != "cpu":
67+
if device is not None:
6868
arr = torch.tensor(arr, device=device)
6969
else:
7070
arr = torch.tensor(arr)

docs/release-notes/0.10.4.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
### 0.10.4 {small}`the future`
22

3+
```{rubric} New features
4+
```
5+
* `AnnLoader` now accepts a `device` argument to specify the device to load the data to {pr}`1240` {user}`austinv11`
6+
37
```{rubric} Bugfix
48
```
59
* Only try to use `Categorical.map(na_action=…)` in actually supported Pandas ≥2.1 {pr}`1226` {user}`flying-sheep`
@@ -10,3 +14,7 @@
1014

1115
```{rubric} Performance
1216
```
17+
18+
```{rubric} Deprecations
19+
```
20+
* `AnnLoader(use_cuda=…)` is deprecated in favour of `AnnLoader(device=…)` {pr}`1240` {user}`austinv11

0 commit comments

Comments
 (0)