@@ -19,7 +19,7 @@
 from itertools import accumulate, chain, pairwise
 from math import floor
 from pathlib import Path
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
 
 import h5py
 import numpy as np
@@ -33,6 +33,7 @@
 from ..compat import (
     CSArray,
     CSMatrix,
+    CupyArray,
     CupyCSCMatrix,
     CupyCSMatrix,
     CupyCSRMatrix,
@@ -53,11 +54,13 @@
 
 SCIPY_1_15 = Version(scipy.__version__) >= Version("1.15rc0")
 
+DenseType = TypeVar("DenseType", np.ndarray, CupyArray)
 
-class CompressedVectors(NamedTuple):
-    data: np.ndarray
-    indices: np.ndarray
-    indptr: np.ndarray
+
+class CompressedVectors(NamedTuple, Generic[DenseType]):
+    data: DenseType
+    indices: DenseType
+    indptr: DenseType
 
 
 def slice_len(s: slice, l: int) -> int:
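The hunk above makes `CompressedVectors` generic over the dense array type it holds, so one instance carries either all-NumPy or all-CuPy arrays. A minimal, self-contained sketch of that pattern follows; the class and variable names are illustrative only, and the CuPy side is stubbed out so the snippet runs without a GPU.

```python
# Sketch only (not part of the diff): a NamedTuple parametrized by a
# constrained TypeVar, mirroring the DenseType pattern above.
# Note: subclassing both NamedTuple and Generic requires Python >= 3.11.
from typing import Generic, NamedTuple, TypeVar

import numpy as np


class FakeCupyArray:  # placeholder for the real CupyArray alias from ..compat
    ...


DenseT = TypeVar("DenseT", np.ndarray, FakeCupyArray)


class Vectors(NamedTuple, Generic[DenseT]):
    data: DenseT
    indices: DenseT
    indptr: DenseT


# All three fields share one concrete array type per instance:
v: Vectors[np.ndarray] = Vectors(
    data=np.array([1.0, 2.0]),
    indices=np.array([0, 3]),
    indptr=np.array([0, 2]),
)
```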
@@ -153,7 +156,11 @@ def _get_contiguous_compressed_slice(self, s: slice) -> CompressedVectors:
         new_data: np.ndarray = self.data[start:stop]
         new_indices: np.ndarray = self.indices[start:stop]
 
-        return CompressedVectors(new_data, new_indices, new_indptr)
+        return CompressedVectors(
+            new_data,
+            new_indices,
+            new_indptr if isinstance(new_indptr, np.ndarray) else CupyArray(new_indptr),
+        )
 
     def get_compressed_vectors(self, row_idxs: Iterable[int]) -> CompressedVectors:
         indptr_slices = [slice(*(self.indptr[i : i + 2])) for i in row_idxs]
@@ -172,7 +179,11 @@ def get_compressed_vectors(self, row_idxs: Iterable[int]) -> CompressedVectors:
         indptr = np.array(
             list(accumulate(chain((0,), (s.stop - s.start for s in indptr_slices))))
         )
-        return CompressedVectors(data, indices, indptr)
+        return CompressedVectors(
+            data,
+            indices,
+            indptr if isinstance(indptr, np.ndarray) else CupyArray(indptr),
+        )
 
     def get_compressed_vectors_for_slices(
         self, slices: Iterable[slice]
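The `indptr` rebuild in `get_compressed_vectors` above relies on `accumulate`/`chain`: the new pointer array is the running sum of each selected row's length, seeded with a leading zero. A small worked example of just that idiom, with made-up slice values for illustration:

```python
from itertools import accumulate, chain

import numpy as np

# Hypothetical per-row slices into data/indices (illustration only).
indptr_slices = [slice(3, 6), slice(10, 12), slice(20, 20)]

# Running sum of row lengths, seeded with 0 -> a valid CSR-style indptr.
indptr = np.array(
    list(accumulate(chain((0,), (s.stop - s.start for s in indptr_slices))))
)
print(indptr)  # [0 3 5 5]: rows of length 3, 2 and 0
```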
@@ -194,19 +205,33 @@ def get_compressed_vectors_for_slices(
         offsets = accumulate(chain([indptr_limits[0].start], gaps))
         start_indptr = indptr_indices[0] - next(offsets)
         if len(slices) < 2:  # there is only one slice so no need to concatenate
-            return CompressedVectors(data, indices, start_indptr)
+            return CompressedVectors(
+                data,
+                indices,
+                start_indptr
+                if isinstance(start_indptr, np.ndarray)
+                else CupyArray(start_indptr),
+            )
         end_indptr = np.concatenate(
             [s[1:] - o for s, o in zip(indptr_indices[1:], offsets)]
         )
         indptr = np.concatenate([start_indptr, end_indptr])
-        return CompressedVectors(data, indices, indptr)
+        return CompressedVectors(
+            data,
+            indices,
+            indptr if isinstance(indptr, np.ndarray) else CupyArray(indptr),
+        )
 
     def get_compressed_vector(self, idx: int) -> CompressedVectors:
         s = slice(*(self.indptr[idx : idx + 2]))
         data: np.ndarray = self.data[s]
         indices: np.ndarray = self.indices[s]
-        indptr: np.ndarray = [0, len(data)]
-        return CompressedVectors(data, indices, indptr)
+        indptr: np.ndarray = np.array([0, len(data)])
+        return CompressedVectors(
+            data,
+            indices,
+            indptr if isinstance(indptr, np.ndarray) else CupyArray(indptr),
+        )
 
     def __getitem__(self, key):
         if isinstance(key, tuple):
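For context on what the `(data, indices, indptr)` triple built throughout this hunk represents: it is exactly the compressed-row constructor input of a SciPy sparse matrix. A short sketch with made-up values, not taken from the PR:

```python
import numpy as np
from scipy import sparse

data = np.array([1.0, 2.0, 3.0])
indices = np.array([0, 2, 1])
indptr = np.array([0, 2, 3])  # row 0 holds two values, row 1 holds one

mat = sparse.csr_matrix((data, indices, indptr), shape=(2, 3))
print(mat.toarray())
# [[1. 0. 2.]
#  [0. 3. 0.]]
```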
@@ -546,6 +571,8 @@ def _indptr(self) -> np.ndarray:
         It should therefore fit into memory, so we cache it for faster access.
         """
         arr = self.group["indptr"][...]
+        if isinstance(arr, CupyArray):
+            arr = arr.get()
         return arr
 
     @cached_property
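The cached `_indptr` above copies a device-resident array back to host memory via CuPy's `.get()`, so downstream code can keep treating it as a plain `numpy.ndarray`. A hedged sketch of that conversion, guarded because it requires CuPy and a CUDA-capable GPU:

```python
import numpy as np

try:
    import cupy as cp  # only usable with a CUDA-capable GPU
except ImportError:
    cp = None

if cp is not None:
    device_arr = cp.arange(5)
    host_arr = device_arr.get()  # device -> host copy, returns numpy.ndarray
    assert isinstance(host_arr, np.ndarray)
```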