Skip to content

Commit eac05cf

Browse files
committed
(fix): bring to gpu as needed
1 parent 0c08653 commit eac05cf

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

src/anndata/_core/sparse_dataset.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from itertools import accumulate, chain, pairwise
2020
from math import floor
2121
from pathlib import Path
22-
from typing import TYPE_CHECKING, NamedTuple
22+
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
2323

2424
import h5py
2525
import numpy as np
@@ -33,6 +33,7 @@
3333
from ..compat import (
3434
CSArray,
3535
CSMatrix,
36+
CupyArray,
3637
CupyCSCMatrix,
3738
CupyCSMatrix,
3839
CupyCSRMatrix,
@@ -53,11 +54,20 @@
5354

5455
SCIPY_1_15 = Version(scipy.__version__) >= Version("1.15rc0")
5556

57+
DenseType = TypeVar("DenseType", np.ndarray, CupyArray)
5658

57-
class CompressedVectors(NamedTuple):
58-
data: np.ndarray
59-
indices: np.ndarray
60-
indptr: np.ndarray
59+
60+
class CompressedVectors(NamedTuple, Generic[DenseType]):
61+
data: DenseType
62+
indices: DenseType
63+
indptr: DenseType
64+
65+
def __new__(
66+
cls, data: DenseType, indices: DenseType, indptr: np.ndarray, *args, **kwargs
67+
):
68+
if isinstance(data, CupyArray):
69+
indptr = CupyArray(indptr)
70+
return NamedTuple.__new__(cls, data, indices, indptr, *args, **kwargs)
6171

6272

6373
def slice_len(s: slice, l: int) -> int:
@@ -546,6 +556,8 @@ def _indptr(self) -> np.ndarray:
546556
It should therefore fit into memory, so we cache it for faster access.
547557
"""
548558
arr = self.group["indptr"][...]
559+
if isinstance(arr, CupyArray):
560+
arr = arr.get()
549561
return arr
550562

551563
@cached_property

0 commit comments

Comments
 (0)