11
11
import os
12
12
import re
13
13
import warnings
14
- from functools import cached_property
14
+ from functools import cached_property , wraps
15
15
from pathlib import PurePath
16
16
17
17
import numpy as np
20
20
21
21
import ecml_tools
22
22
23
+ from .indexing import (
24
+ apply_index_to_slices_changes ,
25
+ expand_list_indexing ,
26
+ index_to_slices ,
27
+ length_to_slices ,
28
+ update_tuple ,
29
+ )
30
+
23
31
LOG = logging .getLogger (__name__ )
24
32
25
33
__all__ = ["open_dataset" , "open_zarr" , "debug_zarr_loading" ]
26
34
27
35
DEBUG_ZARR_LOADING = int (os .environ .get ("DEBUG_ZARR_LOADING" , "0" ))
28
36
37
# Current nesting depth of traced indexing calls (controls print indentation).
DEPTH = 0


def _debug_indexing(method):
    """Decorator that traces tuple-index calls on dataset classes.

    Prints an indented "->" line before and a "<-" line after every call
    whose index is a tuple, using the module-level DEPTH counter so nested
    dataset wrappers indent their traces.  Non-tuple indexes run silently
    but still contribute to the nesting depth.
    """

    @wraps(method)
    def wrapper(self, index):
        global DEPTH
        traced = isinstance(index, tuple)
        if traced:
            print(" " * DEPTH, "->", self, method.__name__, index)
        DEPTH += 1
        result = method(self, index)
        DEPTH -= 1
        if traced:
            # NOTE: assumes the wrapped method returns an array-like with
            # a .shape attribute (numpy/zarr) — debug use only.
            print(" " * DEPTH, "<-", self, method.__name__, result.shape)
        return result

    return wrapper
54
+
55
+
56
# Flip to True to enable verbose tracing of indexing calls via
# _debug_indexing.  A named flag documents intent better than the
# previous hard-coded ``if True:`` toggle, which required editing
# control flow to switch behavior.
DEBUG_INDEXING = False

if DEBUG_INDEXING:
    debug_indexing = _debug_indexing
else:

    def debug_indexing(x):
        """Identity decorator used when indexing trace output is disabled."""
        return x
63
+
29
64
30
65
def debug_zarr_loading (on_off ):
31
66
global DEBUG_ZARR_LOADING
@@ -190,11 +225,19 @@ def metadata_specific(self, **kwargs):
190
225
def __repr__ (self ):
191
226
return self .__class__ .__name__ + "()"
192
227
228
    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, n):
        # Base-class fallback: subclasses that support multi-dimensional
        # (tuple) indexing override this method; reaching here means the
        # concrete dataset class does not implement it.
        raise NotImplementedError(
            f"Tuple not supported: {n} (class {self.__class__.__name__})"
        )
195
234
196
235
197
236
class Source :
237
+ """
238
+ Class used to follow the provenance of a data point.
239
+ """
240
+
198
241
def __init__ (self , dataset , index , source = None , info = None ):
199
242
self .dataset = dataset
200
243
self .index = index
@@ -340,31 +383,11 @@ def __init__(self, path):
340
383
def __len__ (self ):
341
384
return self .data .shape [0 ]
342
385
386
+ @debug_indexing
387
+ @expand_list_indexing
343
388
def __getitem__ (self , n ):
344
- if isinstance (n , tuple ) and any (not isinstance (i , (int , slice )) for i in n ):
345
- return self ._getitem_extended (n )
346
-
347
389
return self .data [n ]
348
390
349
- def _getitem_extended (self , index ):
350
- """
351
- Allows to use slices, lists, and tuples to select data from the dataset.
352
- Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
353
- """
354
-
355
- if not isinstance (index , tuple ):
356
- return self [index ]
357
-
358
- shape = self .data .shape
359
-
360
- axes = []
361
- data = []
362
- for n in self ._unwind (index [0 ], index [1 :], shape , 0 , axes ):
363
- data .append (self .data [n ])
364
-
365
- assert len (axes ) == 1 , axes # Not implemented for more than one axis
366
- return np .concatenate (data , axis = axes [0 ])
367
-
368
391
def _unwind (self , index , rest , shape , axis , axes ):
369
392
if not isinstance (index , (int , slice , list , tuple )):
370
393
try :
@@ -377,7 +400,7 @@ def _unwind(self, index, rest, shape, axis, axes):
377
400
if isinstance (index , (list , tuple )):
378
401
axes .append (axis ) # Dimension of the concatenation
379
402
for i in index :
380
- yield from self ._unwind (i , rest , shape , axis , axes )
403
+ yield from self ._unwind (( slice ( i , i + 1 ),) , rest , shape , axis , axes )
381
404
return
382
405
383
406
if len (rest ) == 0 :
@@ -635,6 +658,23 @@ class Concat(Combined):
635
658
def __len__ (self ):
636
659
return sum (len (i ) for i in self .datasets )
637
660
661
+ @debug_indexing
662
+ @expand_list_indexing
663
+ def _get_tuple (self , index ):
664
+ index , changes = index_to_slices (index , self .shape )
665
+ print (index , changes )
666
+ lengths = [d .shape [0 ] for d in self .datasets ]
667
+ slices = length_to_slices (index [0 ], lengths )
668
+ print ("slies" , slices )
669
+ result = [
670
+ d [update_tuple (index , 0 , i )[0 ]]
671
+ for (d , i ) in zip (self .datasets , slices )
672
+ if i is not None
673
+ ]
674
+ result = np .concatenate (result , axis = 0 )
675
+ return apply_index_to_slices_changes (result , changes )
676
+
677
+ @debug_indexing
638
678
def __getitem__ (self , n ):
639
679
if isinstance (n , tuple ):
640
680
return self ._get_tuple (n )
@@ -649,24 +689,14 @@ def __getitem__(self, n):
649
689
k += 1
650
690
return self .datasets [k ][n ]
651
691
692
+ @debug_indexing
652
693
def _get_slice (self , s ):
653
694
result = []
654
695
655
- start , stop , step = s .indices (self ._len )
656
-
657
- for d in self .datasets :
658
- length = d ._len
659
-
660
- result .append (d [start :stop :step ])
661
-
662
- start -= length
663
- while start < 0 :
664
- start += step
696
+ lengths = [d .shape [0 ] for d in self .datasets ]
697
+ slices = length_to_slices (s , lengths )
665
698
666
- stop -= length
667
-
668
- if start > stop :
669
- break
699
+ result = [d [i ] for (d , i ) in zip (self .datasets , slices ) if i is not None ]
670
700
671
701
return np .concatenate (result )
672
702
@@ -716,9 +746,25 @@ def shape(self):
716
746
assert False not in result , result
717
747
return result
718
748
749
+ @debug_indexing
750
+ @expand_list_indexing
751
+ def _get_tuple (self , index ):
752
+ index , changes = index_to_slices (index , self .shape )
753
+ lengths = [d .shape [self .axis ] for d in self .datasets ]
754
+ slices = length_to_slices (index [self .axis ], lengths )
755
+ result = [
756
+ d [update_tuple (index , self .axis , i )[0 ]]
757
+ for (d , i ) in zip (self .datasets , slices )
758
+ if i is not None
759
+ ]
760
+ result = np .concatenate (result , axis = self .axis )
761
+ return apply_index_to_slices_changes (result , changes )
762
+
763
+ @debug_indexing
719
764
def _get_slice (self , s ):
720
765
return np .stack ([self [i ] for i in range (* s .indices (self ._len ))])
721
766
767
+ @debug_indexing
722
768
def __getitem__ (self , n ):
723
769
if isinstance (n , tuple ):
724
770
return self ._get_tuple (n )
@@ -769,9 +815,23 @@ def check_same_variables(self, d1, d2):
769
815
def __len__ (self ):
770
816
return len (self .datasets [0 ])
771
817
818
+ @debug_indexing
819
+ @expand_list_indexing
820
+ def _get_tuple (self , index ):
821
+ index , changes = index_to_slices (index , self .shape )
822
+ index , previous = update_tuple (index , 1 , slice (None ))
823
+
824
+ # TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
825
+ result = [d [index ] for d in self .datasets ]
826
+
827
+ result = np .concatenate (result , axis = 1 )
828
+ return apply_index_to_slices_changes (result [:, previous ], changes )
829
+
830
+ @debug_indexing
772
831
def _get_slice (self , s ):
773
832
return np .stack ([self [i ] for i in range (* s .indices (self ._len ))])
774
833
834
+ @debug_indexing
775
835
def __getitem__ (self , n ):
776
836
if isinstance (n , tuple ):
777
837
return self ._get_tuple (n )
@@ -857,10 +917,14 @@ def __init__(self, dataset, indices):
857
917
858
918
self .dataset = dataset
859
919
self .indices = list (indices )
920
+ self .slice = _make_slice_or_index_from_list_or_tuple (self .indices )
921
+ assert isinstance (self .slice , slice )
922
+ print ("SUBSET" , self .slice )
860
923
861
924
# Forward other properties to the super dataset
862
925
super ().__init__ (dataset )
863
926
927
+ @debug_indexing
864
928
def __getitem__ (self , n ):
865
929
if isinstance (n , tuple ):
866
930
return self ._get_tuple (n )
@@ -871,25 +935,23 @@ def __getitem__(self, n):
871
935
n = self .indices [n ]
872
936
return self .dataset [n ]
873
937
938
+ @debug_indexing
874
939
def _get_slice (self , s ):
875
940
# TODO: check if the indices can be simplified to a slice
876
941
# the time checking maybe be longer than the time saved
877
942
# using a slice
878
943
indices = [self .indices [i ] for i in range (* s .indices (self ._len ))]
879
944
return np .stack ([self .dataset [i ] for i in indices ])
880
945
946
+ @debug_indexing
947
+ @expand_list_indexing
881
948
def _get_tuple (self , n ):
882
- first , rest = n [0 ], n [1 :]
883
-
884
- if isinstance (first , int ):
885
- return self .dataset [(self .indices [first ],) + rest ]
886
-
887
- if isinstance (first , slice ):
888
- indices = tuple (self .indices [i ] for i in range (* first .indices (self ._len )))
889
- indices = _make_slice_or_index_from_list_or_tuple (indices )
890
- return self .dataset [(indices ,) + rest ]
891
-
892
- raise NotImplementedError (f"Only int and slice supported not { type (first )} " )
949
+ index , changes = index_to_slices (n , self .shape )
950
+ index , previous = update_tuple (index , 0 , self .slice )
951
+ result = self .dataset [index ]
952
+ result = result [previous ]
953
+ result = apply_index_to_slices_changes (result , changes )
954
+ return result
893
955
894
956
def __len__ (self ):
895
957
return len (self .indices )
@@ -929,12 +991,24 @@ def __init__(self, dataset, indices):
929
991
# Forward other properties to the main dataset
930
992
super ().__init__ (dataset )
931
993
994
+ @debug_indexing
995
+ @expand_list_indexing
996
+ def _get_tuple (self , index ):
997
+ index , changes = index_to_slices (index , self .shape )
998
+ index , previous = update_tuple (index , 1 , slice (None ))
999
+ result = self .dataset [index ]
1000
+ result = result [:, self .indices ]
1001
+ result = result [:, previous ]
1002
+ result = apply_index_to_slices_changes (result , changes )
1003
+ return result
1004
+
1005
+ @debug_indexing
932
1006
def __getitem__ (self , n ):
933
- # if isinstance(n, tuple):
934
- # return self._get_tuple(n)
1007
+ if isinstance (n , tuple ):
1008
+ return self ._get_tuple (n )
935
1009
936
1010
row = self .dataset [n ]
937
- if isinstance (n , ( slice , tuple ) ):
1011
+ if isinstance (n , slice ):
938
1012
return row [:, self .indices ]
939
1013
940
1014
return row [self .indices ]
0 commit comments