Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit af5a27c

Browse files
committed
Improved indexing
1 parent a92070c commit af5a27c

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

ecml_tools/data.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import re
13+
import warnings
1314
from functools import cached_property
1415
from pathlib import PurePath
1516

@@ -23,14 +24,31 @@
2324

2425
__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]
2526

26-
DEBUG_ZARR_LOADING = False
27+
DEBUG_ZARR_LOADING = int(os.environ.get("DEBUG_ZARR_LOADING", "0"))
2728

2829

2930
def debug_zarr_loading(on_off):
3031
global DEBUG_ZARR_LOADING
3132
DEBUG_ZARR_LOADING = on_off
3233

3334

35+
def _make_slice_or_index_from_list_or_tuple(indices):
36+
"""
37+
Convert a list or tuple of indices to a slice or an index, if possible.
38+
"""
39+
if len(indices) == 1:
40+
return indices[0]
41+
42+
step = indices[1] - indices[0]
43+
44+
if step > 0 and all(
45+
indices[i] - indices[i - 1] == step for i in range(1, len(indices))
46+
):
47+
return slice(indices[0], indices[-1] + step, step)
48+
49+
return indices
50+
51+
3452
class Dataset:
3553
arguments = {}
3654

@@ -173,6 +191,7 @@ def __repr__(self):
173191
return self.__class__.__name__ + "()"
174192

175193
def _get_tuple(self, n):
194+
warnings.warn(f"Naive tuple indexing used with {self}, likely to be slow.")
176195
first, rest = n[0], n[1:]
177196
return self[first][rest]
178197

@@ -267,18 +286,25 @@ def __getitem__(self, key):
267286

268287
class DebugStore(ReadOnlyStore):
269288
def __init__(self, store):
289+
assert not isinstance(store, DebugStore)
270290
self.store = store
271291

272292
def __getitem__(self, key):
273-
print("GET", key)
293+
# print()
294+
print("GET", key, self)
295+
# traceback.print_stack(file=sys.stdout)
274296
return self.store[key]
275297

276298
def __len__(self):
277299
return len(self.store)
278300

279301
def __iter__(self):
302+
warnings.warn("DebugStore: iterating over the store")
280303
return iter(self.store)
281304

305+
def __contains__(self, key):
306+
return key in self.store
307+
282308

283309
def open_zarr(path):
284310
try:
@@ -317,11 +343,11 @@ def __len__(self):
317343
return self.data.shape[0]
318344

319345
def __getitem__(self, n):
320-
try:
321-
return self.data[n]
322-
except IndexError:
346+
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
323347
return self._getitem_extended(n)
324348

349+
return self.data[n]
350+
325351
def _getitem_extended(self, index):
326352
"""
327353
Allows to use slices, lists, and tuples to select data from the dataset.
@@ -336,14 +362,12 @@ def _getitem_extended(self, index):
336362
axes = []
337363
data = []
338364
for n in self._unwind(index[0], index[1:], shape, 0, axes):
339-
data.append(self[n])
365+
data.append(self.data[n])
340366

341-
assert len(axes) == 1, axes
367+
assert len(axes) == 1, axes # Not implemented for more than one axis
342368
return np.concatenate(data, axis=axes[0])
343369

344370
def _unwind(self, index, rest, shape, axis, axes):
345-
# print(' ' * axis, '====>', index, '+', rest)
346-
347371
if not isinstance(index, (int, slice, list, tuple)):
348372
try:
349373
# NumPy arrays, TensorFlow tensors, etc.
@@ -856,6 +880,19 @@ def _get_slice(self, s):
856880
indices = [self.indices[i] for i in range(*s.indices(self._len))]
857881
return np.stack([self.dataset[i] for i in indices])
858882

883+
def _get_tuple(self, n):
884+
first, rest = n[0], n[1:]
885+
886+
if isinstance(first, int):
887+
return self.dataset[(self.indices[first],) + rest]
888+
889+
if isinstance(first, slice):
890+
indices = tuple(self.indices[i] for i in range(*first.indices(self._len)))
891+
indices = _make_slice_or_index_from_list_or_tuple(indices)
892+
return self.dataset[(indices,) + rest]
893+
894+
raise NotImplementedError(f"Only int and slice supported not {type(first)}")
895+
859896
def __len__(self):
860897
return len(self.indices)
861898

tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def slices(ds, start=None, end=None, step=None):
193193
for i, n in enumerate(range(start, end, step)):
194194
assert (s[i] == ds[n]).all()
195195

196-
x = ds[0:10, :, 0]
196+
ds[0:10, :, 0]
197197

198198
if ds.shape[2] > 1:
199199
ds[0:10, :, np.array([1, 0])]

0 commit comments

Comments
 (0)