Skip to content

Commit a12700b

Browse files
committed
(fix): bring to gpu as needed
1 parent 0c08653 commit a12700b

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

src/anndata/_core/sparse_dataset.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from itertools import accumulate, chain, pairwise
2020
from math import floor
2121
from pathlib import Path
22-
from typing import TYPE_CHECKING, NamedTuple
22+
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
2323

2424
import h5py
2525
import numpy as np
@@ -33,6 +33,7 @@
3333
from ..compat import (
3434
CSArray,
3535
CSMatrix,
36+
CupyArray,
3637
CupyCSCMatrix,
3738
CupyCSMatrix,
3839
CupyCSRMatrix,
@@ -53,11 +54,13 @@
5354

5455
SCIPY_1_15 = Version(scipy.__version__) >= Version("1.15rc0")
5556

57+
DenseType = TypeVar("DenseType", np.ndarray, CupyArray)
5658

57-
class CompressedVectors(NamedTuple):
58-
data: np.ndarray
59-
indices: np.ndarray
60-
indptr: np.ndarray
59+
60+
class CompressedVectors(NamedTuple, Generic[DenseType]):
61+
data: DenseType
62+
indices: DenseType
63+
indptr: DenseType
6164

6265

6366
def slice_len(s: slice, l: int) -> int:
@@ -153,7 +156,11 @@ def _get_contiguous_compressed_slice(self, s: slice) -> CompressedVectors:
153156
new_data: np.ndarray = self.data[start:stop]
154157
new_indices: np.ndarray = self.indices[start:stop]
155158

156-
return CompressedVectors(new_data, new_indices, new_indptr)
159+
return CompressedVectors(
160+
new_data,
161+
new_indices,
162+
new_indptr if isinstance(new_indptr, np.ndarray) else CupyArray(new_indptr),
163+
)
157164

158165
def get_compressed_vectors(self, row_idxs: Iterable[int]) -> CompressedVectors:
159166
indptr_slices = [slice(*(self.indptr[i : i + 2])) for i in row_idxs]
@@ -172,7 +179,11 @@ def get_compressed_vectors(self, row_idxs: Iterable[int]) -> CompressedVectors:
172179
indptr = np.array(
173180
list(accumulate(chain((0,), (s.stop - s.start for s in indptr_slices))))
174181
)
175-
return CompressedVectors(data, indices, indptr)
182+
return CompressedVectors(
183+
data,
184+
indices,
185+
indptr if isinstance(indptr, np.ndarray) else CupyArray(indptr),
186+
)
176187

177188
def get_compressed_vectors_for_slices(
178189
self, slices: Iterable[slice]
@@ -194,19 +205,33 @@ def get_compressed_vectors_for_slices(
194205
offsets = accumulate(chain([indptr_limits[0].start], gaps))
195206
start_indptr = indptr_indices[0] - next(offsets)
196207
if len(slices) < 2: # there is only one slice so no need to concatenate
197-
return CompressedVectors(data, indices, start_indptr)
208+
return CompressedVectors(
209+
data,
210+
indices,
211+
start_indptr
212+
if isinstance(start_indptr, np.ndarray)
213+
else CupyArray(start_indptr),
214+
)
198215
end_indptr = np.concatenate(
199216
[s[1:] - o for s, o in zip(indptr_indices[1:], offsets)]
200217
)
201218
indptr = np.concatenate([start_indptr, end_indptr])
202-
return CompressedVectors(data, indices, indptr)
219+
return CompressedVectors(
220+
data,
221+
indices,
222+
indptr if isinstance(indptr, np.ndarray) else CupyArray(indptr),
223+
)
203224

204225
def get_compressed_vector(self, idx: int) -> CompressedVectors:
205226
s = slice(*(self.indptr[idx : idx + 2]))
206227
data: np.ndarray = self.data[s]
207228
indices: np.ndarray = self.indices[s]
208-
indptr: np.ndarray = [0, len(data)]
209-
return CompressedVectors(data, indices, indptr)
229+
indptr: np.ndarray = np.array([0, len(data)])
230+
return CompressedVectors(
231+
data,
232+
indices,
233+
indptr if isinstance(indptr, np.ndarray) else CupyArray(indptr),
234+
)
210235

211236
def __getitem__(self, key):
212237
if isinstance(key, tuple):
@@ -546,6 +571,8 @@ def _indptr(self) -> np.ndarray:
546571
It should therefore fit into memory, so we cache it for faster access.
547572
"""
548573
arr = self.group["indptr"][...]
574+
if isinstance(arr, CupyArray):
575+
arr = arr.get()
549576
return arr
550577

551578
@cached_property

0 commit comments

Comments
 (0)