1919from  itertools  import  accumulate , chain , pairwise 
2020from  math  import  floor 
2121from  pathlib  import  Path 
22- from  typing  import  TYPE_CHECKING , NamedTuple 
22+ from  typing  import  TYPE_CHECKING , Generic ,  NamedTuple ,  TypeVar 
2323
2424import  h5py 
2525import  numpy  as  np 
3333from  ..compat  import  (
3434    CSArray ,
3535    CSMatrix ,
36+     CupyArray ,
3637    CupyCSCMatrix ,
3738    CupyCSMatrix ,
3839    CupyCSRMatrix ,
5354
5455SCIPY_1_15  =  Version (scipy .__version__ ) >=  Version ("1.15rc0" )
5556
57+ DenseType  =  TypeVar ("DenseType" , np .ndarray , CupyArray )
5658
class CompressedVectors(NamedTuple, Generic[DenseType]):
    """Triple of dense arrays describing one or more compressed sparse vectors.

    The three fields share one array flavour (``DenseType``), so a value is
    either all-NumPy or all-CuPy as far as the type checker is concerned.

    NOTE(review): combining ``NamedTuple`` with ``Generic`` in the bases is
    only accepted by Python 3.11+; earlier interpreters raise ``TypeError``
    at class-creation time. Confirm the package's minimum supported version.
    """

    # Stored values of the selected vectors, concatenated.
    data: DenseType
    # Minor-axis positions matching ``data`` element-for-element.
    indices: DenseType
    # Offsets into ``data``/``indices`` delimiting each vector; starts at 0.
    indptr: DenseType
6164
6265
6366def  slice_len (s : slice , l : int ) ->  int :
@@ -153,7 +156,11 @@ def _get_contiguous_compressed_slice(self, s: slice) -> CompressedVectors:
153156        new_data : np .ndarray  =  self .data [start :stop ]
154157        new_indices : np .ndarray  =  self .indices [start :stop ]
155158
156-         return  CompressedVectors (new_data , new_indices , new_indptr )
159+         return  CompressedVectors (
160+             new_data ,
161+             new_indices ,
162+             new_indptr  if  isinstance (new_indptr , np .ndarray ) else  CupyArray (new_indptr ),
163+         )
157164
158165    def  get_compressed_vectors (self , row_idxs : Iterable [int ]) ->  CompressedVectors :
159166        indptr_slices  =  [slice (* (self .indptr [i  : i  +  2 ])) for  i  in  row_idxs ]
@@ -172,7 +179,11 @@ def get_compressed_vectors(self, row_idxs: Iterable[int]) -> CompressedVectors:
172179        indptr  =  np .array (
173180            list (accumulate (chain ((0 ,), (s .stop  -  s .start  for  s  in  indptr_slices ))))
174181        )
175-         return  CompressedVectors (data , indices , indptr )
182+         return  CompressedVectors (
183+             data ,
184+             indices ,
185+             indptr  if  isinstance (indptr , np .ndarray ) else  CupyArray (indptr ),
186+         )
176187
177188    def  get_compressed_vectors_for_slices (
178189        self , slices : Iterable [slice ]
@@ -194,19 +205,33 @@ def get_compressed_vectors_for_slices(
194205        offsets  =  accumulate (chain ([indptr_limits [0 ].start ], gaps ))
195206        start_indptr  =  indptr_indices [0 ] -  next (offsets )
196207        if  len (slices ) <  2 :  # there is only one slice so no need to concatenate 
197-             return  CompressedVectors (data , indices , start_indptr )
208+             return  CompressedVectors (
209+                 data ,
210+                 indices ,
211+                 start_indptr 
212+                 if  isinstance (start_indptr , np .ndarray )
213+                 else  CupyArray (start_indptr ),
214+             )
198215        end_indptr  =  np .concatenate (
199216            [s [1 :] -  o  for  s , o  in  zip (indptr_indices [1 :], offsets )]
200217        )
201218        indptr  =  np .concatenate ([start_indptr , end_indptr ])
202-         return  CompressedVectors (data , indices , indptr )
219+         return  CompressedVectors (
220+             data ,
221+             indices ,
222+             indptr  if  isinstance (indptr , np .ndarray ) else  CupyArray (indptr ),
223+         )
203224
204225    def  get_compressed_vector (self , idx : int ) ->  CompressedVectors :
205226        s  =  slice (* (self .indptr [idx  : idx  +  2 ]))
206227        data : np .ndarray  =  self .data [s ]
207228        indices : np .ndarray  =  self .indices [s ]
208-         indptr : np .ndarray  =  [0 , len (data )]
209-         return  CompressedVectors (data , indices , indptr )
229+         indptr : np .ndarray  =  np .array ([0 , len (data )])
230+         return  CompressedVectors (
231+             data ,
232+             indices ,
233+             indptr  if  isinstance (indptr , np .ndarray ) else  CupyArray (indptr ),
234+         )
210235
211236    def  __getitem__ (self , key ):
212237        if  isinstance (key , tuple ):
@@ -546,6 +571,8 @@ def _indptr(self) -> np.ndarray:
546571        It should therefore fit into memory, so we cache it for faster access. 
547572        """ 
548573        arr  =  self .group ["indptr" ][...]
574+         if  isinstance (arr , CupyArray ):
575+             arr  =  arr .get ()
549576        return  arr 
550577
551578    @cached_property  
0 commit comments