10
10
import logging
11
11
import os
12
12
import re
13
+ import warnings
13
14
from functools import cached_property
14
15
from pathlib import PurePath
15
16
23
24
24
25
__all__ = ["open_dataset" , "open_zarr" , "debug_zarr_loading" ]
25
26
26
- DEBUG_ZARR_LOADING = False
27
+ DEBUG_ZARR_LOADING = int ( os . environ . get ( "DEBUG_ZARR_LOADING" , "0" ))
27
28
28
29
29
30
def debug_zarr_loading (on_off ):
30
31
global DEBUG_ZARR_LOADING
31
32
DEBUG_ZARR_LOADING = on_off
32
33
33
34
35
+ def _make_slice_or_index_from_list_or_tuple (indices ):
36
+ """
37
+ Convert a list or tuple of indices to a slice or an index, if possible.
38
+ """
39
+ if len (indices ) == 1 :
40
+ return indices [0 ]
41
+
42
+ step = indices [1 ] - indices [0 ]
43
+
44
+ if step > 0 and all (
45
+ indices [i ] - indices [i - 1 ] == step for i in range (1 , len (indices ))
46
+ ):
47
+ return slice (indices [0 ], indices [- 1 ] + step , step )
48
+
49
+ return indices
50
+
51
+
34
52
class Dataset :
35
53
arguments = {}
36
54
@@ -173,6 +191,7 @@ def __repr__(self):
173
191
return self .__class__ .__name__ + "()"
174
192
175
193
def _get_tuple (self , n ):
194
+ warnings .warn (f"Naive tuple indexing used with { self } , likely to be slow." )
176
195
first , rest = n [0 ], n [1 :]
177
196
return self [first ][rest ]
178
197
@@ -267,18 +286,25 @@ def __getitem__(self, key):
267
286
268
287
class DebugStore (ReadOnlyStore ):
269
288
def __init__ (self , store ):
289
+ assert not isinstance (store , DebugStore )
270
290
self .store = store
271
291
272
292
def __getitem__ (self , key ):
273
- print ("GET" , key )
293
+ # print()
294
+ print ("GET" , key , self )
295
+ # traceback.print_stack(file=sys.stdout)
274
296
return self .store [key ]
275
297
276
298
def __len__ (self ):
277
299
return len (self .store )
278
300
279
301
def __iter__ (self ):
302
+ warnings .warn ("DebugStore: iterating over the store" )
280
303
return iter (self .store )
281
304
305
+ def __contains__ (self , key ):
306
+ return key in self .store
307
+
282
308
283
309
def open_zarr (path ):
284
310
try :
@@ -317,11 +343,11 @@ def __len__(self):
317
343
return self .data .shape [0 ]
318
344
319
345
def __getitem__ (self , n ):
320
- try :
321
- return self .data [n ]
322
- except IndexError :
346
+ if isinstance (n , tuple ) and any (not isinstance (i , (int , slice )) for i in n ):
323
347
return self ._getitem_extended (n )
324
348
349
+ return self .data [n ]
350
+
325
351
def _getitem_extended (self , index ):
326
352
"""
327
353
Allows to use slices, lists, and tuples to select data from the dataset.
@@ -336,14 +362,12 @@ def _getitem_extended(self, index):
336
362
axes = []
337
363
data = []
338
364
for n in self ._unwind (index [0 ], index [1 :], shape , 0 , axes ):
339
- data .append (self [n ])
365
+ data .append (self . data [n ])
340
366
341
- assert len (axes ) == 1 , axes
367
+ assert len (axes ) == 1 , axes # Not implemented for more than one axis
342
368
return np .concatenate (data , axis = axes [0 ])
343
369
344
370
def _unwind (self , index , rest , shape , axis , axes ):
345
- # print(' ' * axis, '====>', index, '+', rest)
346
-
347
371
if not isinstance (index , (int , slice , list , tuple )):
348
372
try :
349
373
# NumPy arrays, TensorFlow tensors, etc.
@@ -856,6 +880,19 @@ def _get_slice(self, s):
856
880
indices = [self .indices [i ] for i in range (* s .indices (self ._len ))]
857
881
return np .stack ([self .dataset [i ] for i in indices ])
858
882
883
+ def _get_tuple (self , n ):
884
+ first , rest = n [0 ], n [1 :]
885
+
886
+ if isinstance (first , int ):
887
+ return self .dataset [(self .indices [first ],) + rest ]
888
+
889
+ if isinstance (first , slice ):
890
+ indices = tuple (self .indices [i ] for i in range (* first .indices (self ._len )))
891
+ indices = _make_slice_or_index_from_list_or_tuple (indices )
892
+ return self .dataset [(indices ,) + rest ]
893
+
894
+ raise NotImplementedError (f"Only int and slice supported not { type (first )} " )
895
+
859
896
def __len__ (self ):
860
897
return len (self .indices )
861
898
0 commit comments