Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 214e3bc

Browse files
committed
Merge branch 'develop'
2 parents bdf2bfb + d14192d commit 214e3bc

File tree

4 files changed

+1053
-1051
lines changed

4 files changed

+1053
-1051
lines changed

ecml_tools/data.py

Lines changed: 127 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212
import re
1313
import warnings
14-
from functools import cached_property
14+
from functools import cached_property, wraps
1515
from pathlib import PurePath
1616

1717
import numpy as np
@@ -20,12 +20,47 @@
2020

2121
import ecml_tools
2222

23+
from .indexing import (
24+
apply_index_to_slices_changes,
25+
expand_list_indexing,
26+
index_to_slices,
27+
length_to_slices,
28+
update_tuple,
29+
)
30+
2331
LOG = logging.getLogger(__name__)
2432

2533
__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]
2634

2735
DEBUG_ZARR_LOADING = int(os.environ.get("DEBUG_ZARR_LOADING", "0"))
2836

37+
# Indentation depth for the nested debug-indexing trace output.
DEPTH = 0


def _debug_indexing(method):
    """Decorator that traces tuple-indexing calls on *method*.

    Prints the index on the way in and the result's shape on the way out,
    indented by the current nesting depth so chained dataset wrappers
    produce a readable call tree. Only tuple indexes are traced.
    """

    @wraps(method)
    def wrapper(self, index):
        global DEPTH
        if isinstance(index, tuple):
            print(" " * DEPTH, "->", self, method.__name__, index)
        DEPTH += 1
        result = method(self, index)
        DEPTH -= 1
        if isinstance(index, tuple):
            # Be defensive: result may not be array-like, so don't assume
            # it has a .shape attribute (the original crashed here).
            print(" " * DEPTH, "<-", self, method.__name__, getattr(result, "shape", result))
        return result

    return wrapper
54+
55+
56+
# Indexing trace is noisy and slow; enable it via the DEBUG_INDEXING
# environment variable (same pattern as DEBUG_ZARR_LOADING above)
# instead of editing this hard-coded toggle.
if int(os.environ.get("DEBUG_INDEXING", "0")):
    debug_indexing = _debug_indexing
else:

    def debug_indexing(x):
        """No-op decorator used when indexing tracing is disabled."""
        return x
63+
2964

3065
def debug_zarr_loading(on_off):
3166
global DEBUG_ZARR_LOADING
@@ -190,11 +225,19 @@ def metadata_specific(self, **kwargs):
190225
def __repr__(self):
191226
return self.__class__.__name__ + "()"
192227

228+
@debug_indexing
229+
@expand_list_indexing
193230
def _get_tuple(self, n):
194-
raise NotImplementedError(f"Tuple not supported: {n} (class {self.__class__.__name__})")
231+
raise NotImplementedError(
232+
f"Tuple not supported: {n} (class {self.__class__.__name__})"
233+
)
195234

196235

197236
class Source:
237+
"""
238+
Class used to follow the provenance of a data point.
239+
"""
240+
198241
def __init__(self, dataset, index, source=None, info=None):
199242
self.dataset = dataset
200243
self.index = index
@@ -340,31 +383,11 @@ def __init__(self, path):
340383
def __len__(self):
341384
return self.data.shape[0]
342385

386+
@debug_indexing
387+
@expand_list_indexing
343388
def __getitem__(self, n):
344-
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
345-
return self._getitem_extended(n)
346-
347389
return self.data[n]
348390

349-
def _getitem_extended(self, index):
350-
"""
351-
Allows to use slices, lists, and tuples to select data from the dataset.
352-
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
353-
"""
354-
355-
if not isinstance(index, tuple):
356-
return self[index]
357-
358-
shape = self.data.shape
359-
360-
axes = []
361-
data = []
362-
for n in self._unwind(index[0], index[1:], shape, 0, axes):
363-
data.append(self.data[n])
364-
365-
assert len(axes) == 1, axes # Not implemented for more than one axis
366-
return np.concatenate(data, axis=axes[0])
367-
368391
def _unwind(self, index, rest, shape, axis, axes):
369392
if not isinstance(index, (int, slice, list, tuple)):
370393
try:
@@ -377,7 +400,7 @@ def _unwind(self, index, rest, shape, axis, axes):
377400
if isinstance(index, (list, tuple)):
378401
axes.append(axis) # Dimension of the concatenation
379402
for i in index:
380-
yield from self._unwind(i, rest, shape, axis, axes)
403+
yield from self._unwind((slice(i, i + 1),), rest, shape, axis, axes)
381404
return
382405

383406
if len(rest) == 0:
@@ -635,6 +658,23 @@ class Concat(Combined):
635658
def __len__(self):
636659
return sum(len(i) for i in self.datasets)
637660

661+
@debug_indexing
662+
@expand_list_indexing
663+
def _get_tuple(self, index):
664+
index, changes = index_to_slices(index, self.shape)
665+
print(index, changes)
666+
lengths = [d.shape[0] for d in self.datasets]
667+
slices = length_to_slices(index[0], lengths)
668+
print("slies", slices)
669+
result = [
670+
d[update_tuple(index, 0, i)[0]]
671+
for (d, i) in zip(self.datasets, slices)
672+
if i is not None
673+
]
674+
result = np.concatenate(result, axis=0)
675+
return apply_index_to_slices_changes(result, changes)
676+
677+
@debug_indexing
638678
def __getitem__(self, n):
639679
if isinstance(n, tuple):
640680
return self._get_tuple(n)
@@ -649,24 +689,14 @@ def __getitem__(self, n):
649689
k += 1
650690
return self.datasets[k][n]
651691

692+
@debug_indexing
652693
def _get_slice(self, s):
653694
result = []
654695

655-
start, stop, step = s.indices(self._len)
656-
657-
for d in self.datasets:
658-
length = d._len
659-
660-
result.append(d[start:stop:step])
661-
662-
start -= length
663-
while start < 0:
664-
start += step
696+
lengths = [d.shape[0] for d in self.datasets]
697+
slices = length_to_slices(s, lengths)
665698

666-
stop -= length
667-
668-
if start > stop:
669-
break
699+
result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
670700

671701
return np.concatenate(result)
672702

@@ -716,9 +746,25 @@ def shape(self):
716746
assert False not in result, result
717747
return result
718748

749+
@debug_indexing
750+
@expand_list_indexing
751+
def _get_tuple(self, index):
752+
index, changes = index_to_slices(index, self.shape)
753+
lengths = [d.shape[self.axis] for d in self.datasets]
754+
slices = length_to_slices(index[self.axis], lengths)
755+
result = [
756+
d[update_tuple(index, self.axis, i)[0]]
757+
for (d, i) in zip(self.datasets, slices)
758+
if i is not None
759+
]
760+
result = np.concatenate(result, axis=self.axis)
761+
return apply_index_to_slices_changes(result, changes)
762+
763+
@debug_indexing
719764
def _get_slice(self, s):
720765
return np.stack([self[i] for i in range(*s.indices(self._len))])
721766

767+
@debug_indexing
722768
def __getitem__(self, n):
723769
if isinstance(n, tuple):
724770
return self._get_tuple(n)
@@ -769,9 +815,23 @@ def check_same_variables(self, d1, d2):
769815
def __len__(self):
770816
return len(self.datasets[0])
771817

818+
@debug_indexing
819+
@expand_list_indexing
820+
def _get_tuple(self, index):
821+
index, changes = index_to_slices(index, self.shape)
822+
index, previous = update_tuple(index, 1, slice(None))
823+
824+
# TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
825+
result = [d[index] for d in self.datasets]
826+
827+
result = np.concatenate(result, axis=1)
828+
return apply_index_to_slices_changes(result[:, previous], changes)
829+
830+
@debug_indexing
772831
def _get_slice(self, s):
773832
return np.stack([self[i] for i in range(*s.indices(self._len))])
774833

834+
@debug_indexing
775835
def __getitem__(self, n):
776836
if isinstance(n, tuple):
777837
return self._get_tuple(n)
@@ -857,10 +917,14 @@ def __init__(self, dataset, indices):
857917

858918
self.dataset = dataset
859919
self.indices = list(indices)
920+
self.slice = _make_slice_or_index_from_list_or_tuple(self.indices)
921+
assert isinstance(self.slice, slice)
922+
print("SUBSET", self.slice)
860923

861924
# Forward other properties to the super dataset
862925
super().__init__(dataset)
863926

927+
@debug_indexing
864928
def __getitem__(self, n):
865929
if isinstance(n, tuple):
866930
return self._get_tuple(n)
@@ -871,25 +935,23 @@ def __getitem__(self, n):
871935
n = self.indices[n]
872936
return self.dataset[n]
873937

938+
@debug_indexing
874939
def _get_slice(self, s):
875940
# TODO: check if the indices can be simplified to a slice
876941
# the time checking maybe be longer than the time saved
877942
# using a slice
878943
indices = [self.indices[i] for i in range(*s.indices(self._len))]
879944
return np.stack([self.dataset[i] for i in indices])
880945

946+
@debug_indexing
947+
@expand_list_indexing
881948
def _get_tuple(self, n):
882-
first, rest = n[0], n[1:]
883-
884-
if isinstance(first, int):
885-
return self.dataset[(self.indices[first],) + rest]
886-
887-
if isinstance(first, slice):
888-
indices = tuple(self.indices[i] for i in range(*first.indices(self._len)))
889-
indices = _make_slice_or_index_from_list_or_tuple(indices)
890-
return self.dataset[(indices,) + rest]
891-
892-
raise NotImplementedError(f"Only int and slice supported not {type(first)}")
949+
index, changes = index_to_slices(n, self.shape)
950+
index, previous = update_tuple(index, 0, self.slice)
951+
result = self.dataset[index]
952+
result = result[previous]
953+
result = apply_index_to_slices_changes(result, changes)
954+
return result
893955

894956
def __len__(self):
895957
return len(self.indices)
@@ -929,12 +991,24 @@ def __init__(self, dataset, indices):
929991
# Forward other properties to the main dataset
930992
super().__init__(dataset)
931993

994+
@debug_indexing
995+
@expand_list_indexing
996+
def _get_tuple(self, index):
997+
index, changes = index_to_slices(index, self.shape)
998+
index, previous = update_tuple(index, 1, slice(None))
999+
result = self.dataset[index]
1000+
result = result[:, self.indices]
1001+
result = result[:, previous]
1002+
result = apply_index_to_slices_changes(result, changes)
1003+
return result
1004+
1005+
@debug_indexing
9321006
def __getitem__(self, n):
933-
# if isinstance(n, tuple):
934-
# return self._get_tuple(n)
1007+
if isinstance(n, tuple):
1008+
return self._get_tuple(n)
9351009

9361010
row = self.dataset[n]
937-
if isinstance(n, (slice, tuple)):
1011+
if isinstance(n, slice):
9381012
return row[:, self.indices]
9391013

9401014
return row[self.indices]

0 commit comments

Comments
 (0)