11
11
import os
12
12
import re
13
13
import warnings
14
- from functools import cached_property
14
+ from functools import cached_property , wraps
15
15
from pathlib import PurePath
16
16
17
17
import numpy as np
22
22
23
23
from .indexing import (
24
24
apply_index_to_slices_changes ,
25
+ expand_list_indexing ,
25
26
index_to_slices ,
26
27
length_to_slices ,
27
28
update_tuple ,
37
38
38
39
39
40
def _debug_indexing (method ):
41
+ @wraps (method )
40
42
def wrapper (self , index ):
41
43
global DEPTH
42
44
if isinstance (index , tuple ):
@@ -224,6 +226,7 @@ def __repr__(self):
224
226
return self .__class__ .__name__ + "()"
225
227
226
228
@debug_indexing
229
+ @expand_list_indexing
227
230
def _get_tuple (self , n ):
228
231
raise NotImplementedError (
229
232
f"Tuple not supported: { n } (class { self .__class__ .__name__ } )"
@@ -381,30 +384,10 @@ def __len__(self):
381
384
return self .data .shape [0 ]
382
385
383
386
@debug_indexing
387
+ @expand_list_indexing
384
388
def __getitem__ (self , n ):
385
- if isinstance (n , tuple ) and any (not isinstance (i , (int , slice )) for i in n ):
386
- return self ._getitem_extended (n )
387
-
388
389
return self .data [n ]
389
390
390
- def _getitem_extended (self , index ):
391
- """
392
- Allows to use slices, lists, and tuples to select data from the dataset.
393
- Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
394
- """
395
-
396
- assert False , index
397
-
398
- shape = self .data .shape
399
-
400
- axes = []
401
- data = []
402
- for n in self ._unwind (index [0 ], index [1 :], shape , 0 , axes ):
403
- data .append (self .data [n ])
404
-
405
- assert len (axes ) == 1 , axes # Not implemented for more than one axis
406
- return np .concatenate (data , axis = axes [0 ])
407
-
408
391
def _unwind (self , index , rest , shape , axis , axes ):
409
392
if not isinstance (index , (int , slice , list , tuple )):
410
393
try :
@@ -676,28 +659,20 @@ def __len__(self):
676
659
return sum (len (i ) for i in self .datasets )
677
660
678
661
@debug_indexing
662
+ @expand_list_indexing
679
663
def _get_tuple (self , index ):
680
664
index , changes = index_to_slices (index , self .shape )
681
- result = []
682
-
683
- first , rest = index [0 ], index [1 :]
684
- start , stop , step = first .start , first .stop , first .step
685
-
686
- for d in self .datasets :
687
- length = d ._len
688
-
689
- result .append (d [(slice (start , stop , step ),) + rest ])
690
-
691
- start -= length
692
- while start < 0 :
693
- start += step
694
-
695
- stop -= length
696
-
697
- if start > stop :
698
- break
699
-
700
- return apply_index_to_slices_changes (np .concatenate (result , axis = 0 ), changes )
665
+ print (index , changes )
666
+ lengths = [d .shape [0 ] for d in self .datasets ]
667
+ slices = length_to_slices (index [0 ], lengths )
668
+ print ("slies" , slices )
669
+ result = [
670
+ d [update_tuple (index , 0 , i )[0 ]]
671
+ for (d , i ) in zip (self .datasets , slices )
672
+ if i is not None
673
+ ]
674
+ result = np .concatenate (result , axis = 0 )
675
+ return apply_index_to_slices_changes (result , changes )
701
676
702
677
@debug_indexing
703
678
def __getitem__ (self , n ):
@@ -718,21 +693,10 @@ def __getitem__(self, n):
718
693
def _get_slice (self , s ):
719
694
result = []
720
695
721
- start , stop , step = s .indices (self ._len )
722
-
723
- for d in self .datasets :
724
- length = d ._len
725
-
726
- result .append (d [start :stop :step ])
727
-
728
- start -= length
729
- while start < 0 :
730
- start += step
731
-
732
- stop -= length
696
+ lengths = [d .shape [0 ] for d in self .datasets ]
697
+ slices = length_to_slices (s , lengths )
733
698
734
- if start > stop :
735
- break
699
+ result = [d [i ] for (d , i ) in zip (self .datasets , slices ) if i is not None ]
736
700
737
701
return np .concatenate (result )
738
702
@@ -783,13 +747,15 @@ def shape(self):
783
747
return result
784
748
785
749
@debug_indexing
750
+ @expand_list_indexing
786
751
def _get_tuple (self , index ):
787
752
index , changes = index_to_slices (index , self .shape )
788
753
lengths = [d .shape [self .axis ] for d in self .datasets ]
789
754
slices = length_to_slices (index [self .axis ], lengths )
790
- before = index [: self .axis ]
791
755
result = [
792
- d [before + (i ,)] for (d , i ) in zip (self .datasets , slices ) if i is not None
756
+ d [update_tuple (index , self .axis , i )[0 ]]
757
+ for (d , i ) in zip (self .datasets , slices )
758
+ if i is not None
793
759
]
794
760
result = np .concatenate (result , axis = self .axis )
795
761
return apply_index_to_slices_changes (result , changes )
@@ -850,6 +816,7 @@ def __len__(self):
850
816
return len (self .datasets [0 ])
851
817
852
818
@debug_indexing
819
+ @expand_list_indexing
853
820
def _get_tuple (self , index ):
854
821
index , changes = index_to_slices (index , self .shape )
855
822
index , previous = update_tuple (index , 1 , slice (None ))
@@ -977,6 +944,7 @@ def _get_slice(self, s):
977
944
return np .stack ([self .dataset [i ] for i in indices ])
978
945
979
946
@debug_indexing
947
+ @expand_list_indexing
980
948
def _get_tuple (self , n ):
981
949
index , changes = index_to_slices (n , self .shape )
982
950
index , previous = update_tuple (index , 0 , self .slice )
@@ -1024,6 +992,7 @@ def __init__(self, dataset, indices):
1024
992
super ().__init__ (dataset )
1025
993
1026
994
@debug_indexing
995
+ @expand_list_indexing
1027
996
def _get_tuple (self , index ):
1028
997
index , changes = index_to_slices (index , self .shape )
1029
998
index , previous = update_tuple (index , 1 , slice (None ))
0 commit comments