Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 9a54d4f

Browse files
committed
update
1 parent a731eb1 commit 9a54d4f

File tree

3 files changed

+81
-76
lines changed

3 files changed

+81
-76
lines changed

ecml_tools/data.py

Lines changed: 27 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212
import re
1313
import warnings
14-
from functools import cached_property
14+
from functools import cached_property, wraps
1515
from pathlib import PurePath
1616

1717
import numpy as np
@@ -22,6 +22,7 @@
2222

2323
from .indexing import (
2424
apply_index_to_slices_changes,
25+
expand_list_indexing,
2526
index_to_slices,
2627
length_to_slices,
2728
update_tuple,
@@ -37,6 +38,7 @@
3738

3839

3940
def _debug_indexing(method):
41+
@wraps(method)
4042
def wrapper(self, index):
4143
global DEPTH
4244
if isinstance(index, tuple):
@@ -224,6 +226,7 @@ def __repr__(self):
224226
return self.__class__.__name__ + "()"
225227

226228
@debug_indexing
229+
@expand_list_indexing
227230
def _get_tuple(self, n):
228231
raise NotImplementedError(
229232
f"Tuple not supported: {n} (class {self.__class__.__name__})"
@@ -381,30 +384,10 @@ def __len__(self):
381384
return self.data.shape[0]
382385

383386
@debug_indexing
387+
@expand_list_indexing
384388
def __getitem__(self, n):
385-
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
386-
return self._getitem_extended(n)
387-
388389
return self.data[n]
389390

390-
def _getitem_extended(self, index):
391-
"""
392-
Allows to use slices, lists, and tuples to select data from the dataset.
393-
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
394-
"""
395-
396-
assert False, index
397-
398-
shape = self.data.shape
399-
400-
axes = []
401-
data = []
402-
for n in self._unwind(index[0], index[1:], shape, 0, axes):
403-
data.append(self.data[n])
404-
405-
assert len(axes) == 1, axes # Not implemented for more than one axis
406-
return np.concatenate(data, axis=axes[0])
407-
408391
def _unwind(self, index, rest, shape, axis, axes):
409392
if not isinstance(index, (int, slice, list, tuple)):
410393
try:
@@ -676,28 +659,20 @@ def __len__(self):
676659
return sum(len(i) for i in self.datasets)
677660

678661
@debug_indexing
662+
@expand_list_indexing
679663
def _get_tuple(self, index):
680664
index, changes = index_to_slices(index, self.shape)
681-
result = []
682-
683-
first, rest = index[0], index[1:]
684-
start, stop, step = first.start, first.stop, first.step
685-
686-
for d in self.datasets:
687-
length = d._len
688-
689-
result.append(d[(slice(start, stop, step),) + rest])
690-
691-
start -= length
692-
while start < 0:
693-
start += step
694-
695-
stop -= length
696-
697-
if start > stop:
698-
break
699-
700-
return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)
665+
print(index, changes)
666+
lengths = [d.shape[0] for d in self.datasets]
667+
slices = length_to_slices(index[0], lengths)
668+
print("slies", slices)
669+
result = [
670+
d[update_tuple(index, 0, i)[0]]
671+
for (d, i) in zip(self.datasets, slices)
672+
if i is not None
673+
]
674+
result = np.concatenate(result, axis=0)
675+
return apply_index_to_slices_changes(result, changes)
701676

702677
@debug_indexing
703678
def __getitem__(self, n):
@@ -718,21 +693,10 @@ def __getitem__(self, n):
718693
def _get_slice(self, s):
719694
result = []
720695

721-
start, stop, step = s.indices(self._len)
722-
723-
for d in self.datasets:
724-
length = d._len
725-
726-
result.append(d[start:stop:step])
727-
728-
start -= length
729-
while start < 0:
730-
start += step
731-
732-
stop -= length
696+
lengths = [d.shape[0] for d in self.datasets]
697+
slices = length_to_slices(s, lengths)
733698

734-
if start > stop:
735-
break
699+
result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
736700

737701
return np.concatenate(result)
738702

@@ -783,13 +747,15 @@ def shape(self):
783747
return result
784748

785749
@debug_indexing
750+
@expand_list_indexing
786751
def _get_tuple(self, index):
787752
index, changes = index_to_slices(index, self.shape)
788753
lengths = [d.shape[self.axis] for d in self.datasets]
789754
slices = length_to_slices(index[self.axis], lengths)
790-
before = index[: self.axis]
791755
result = [
792-
d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
756+
d[update_tuple(index, self.axis, i)[0]]
757+
for (d, i) in zip(self.datasets, slices)
758+
if i is not None
793759
]
794760
result = np.concatenate(result, axis=self.axis)
795761
return apply_index_to_slices_changes(result, changes)
@@ -850,6 +816,7 @@ def __len__(self):
850816
return len(self.datasets[0])
851817

852818
@debug_indexing
819+
@expand_list_indexing
853820
def _get_tuple(self, index):
854821
index, changes = index_to_slices(index, self.shape)
855822
index, previous = update_tuple(index, 1, slice(None))
@@ -977,6 +944,7 @@ def _get_slice(self, s):
977944
return np.stack([self.dataset[i] for i in indices])
978945

979946
@debug_indexing
947+
@expand_list_indexing
980948
def _get_tuple(self, n):
981949
index, changes = index_to_slices(n, self.shape)
982950
index, previous = update_tuple(index, 0, self.slice)
@@ -1024,6 +992,7 @@ def __init__(self, dataset, indices):
1024992
super().__init__(dataset)
1025993

1026994
@debug_indexing
995+
@expand_list_indexing
1027996
def _get_tuple(self, index):
1028997
index, changes = index_to_slices(index, self.shape)
1029998
index, previous = update_tuple(index, 1, slice(None))

ecml_tools/indexing.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# nor does it submit to any jurisdiction.
77

88

9+
from functools import wraps
10+
911
import numpy as np
1012

1113

@@ -109,21 +111,53 @@ def length_to_slices(index, lengths):
109111
return result
110112

111113

112-
class IndexTester:
113-
def __init__(self, shape):
114-
self.shape = shape
114+
def _as_tuples(index):
115+
def _(i):
116+
if hasattr(i, "tolist"):
117+
# NumPy arrays, TensorFlow tensors, etc.
118+
i = i.tolist()
119+
assert not isinstance(i[0], bool), "Mask not supported"
120+
return tuple(i)
121+
122+
if isinstance(i, list):
123+
return tuple(i)
124+
125+
return i
126+
127+
return tuple(_(i) for i in index)
128+
129+
130+
def expand_list_indexing(method):
131+
"""
132+
Allows to use slices, lists, and tuples to select data from the dataset.
133+
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
134+
"""
135+
136+
@wraps(method)
137+
def wrapper(self, index):
138+
if not isinstance(index, tuple):
139+
return method(self, index)
140+
141+
if not any(isinstance(i, (list, tuple)) for i in index):
142+
return method(self, index)
143+
144+
which = []
145+
for i, idx in enumerate(index):
146+
if isinstance(idx, (list, tuple)):
147+
which.append(i)
148+
149+
assert which, "No list index found"
115150

116-
def __getitem__(self, index):
117-
return index_to_slices(index, self.shape)
151+
if len(which) > 1:
152+
raise IndexError("Only one list index is allowed")
118153

154+
which = which[0]
155+
index = _as_tuples(index)
156+
result = []
157+
for i in index[which]:
158+
index, _ = update_tuple(index, which, slice(i, i + 1))
159+
result.append(method(self, index))
119160

120-
if __name__ == "__main__":
121-
t = IndexTester((1000, 8, 10, 20000))
122-
i = t[0, 1, 2, 3]
123-
print(i)
161+
return np.concatenate(result, axis=which)
124162

125-
# print(t[0])
126-
# print(t[0, 1, 2, 3])
127-
# print(t[0:10])
128-
# print(t[...])
129-
# print(t[:-1])
163+
return wrapper

tests/test_data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,16 @@ def indexing(ds):
201201
t[0:10, 0:3, 0]
202202
t[:, :, :]
203203

204-
# t[:, (1, 3), :]
205-
# t[:, (1, 3)]
204+
t[:, (1, 3), :]
205+
t[:, (1, 3)]
206206

207207
t[0]
208208
t[0, :]
209209
t[0, 0, :]
210210
t[0, 0, 0, :]
211211

212-
# if ds.shape[2] > 1: # Ensemble dimension
213-
# t[0:10, :, (0, 1)]
212+
if ds.shape[2] > 1: # Ensemble dimension
213+
t[0:10, :, (0, 1)]
214214

215215

216216
def slices(ds, start=None, end=None, step=None):
@@ -1134,6 +1134,8 @@ def test_ensemble_1():
11341134
)
11351135
ds = test.ds
11361136

1137+
ds[0:10,:,(1,2)]
1138+
11371139
assert isinstance(ds, Ensemble)
11381140
assert len(ds) == 365 * 1 * 4
11391141
assert len([row for row in ds]) == len(ds)

0 commit comments

Comments
 (0)