Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit a92070c

Browse files
committed
Advanced indexing
1 parent 33fe1ee commit a92070c

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

ecml_tools/data.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,53 @@ def __len__(self):
317317
return self.data.shape[0]
318318

319319
def __getitem__(self, n):
320-
return self.data[n]
320+
try:
321+
return self.data[n]
322+
except IndexError:
323+
return self._getitem_extended(n)
324+
325+
def _getitem_extended(self, index):
326+
"""
327+
Allows to use slices, lists, and tuples to select data from the dataset.
328+
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
329+
"""
330+
331+
if not isinstance(index, tuple):
332+
return self[index]
333+
334+
shape = self.data.shape
335+
336+
axes = []
337+
data = []
338+
for n in self._unwind(index[0], index[1:], shape, 0, axes):
339+
data.append(self[n])
340+
341+
assert len(axes) == 1, axes
342+
return np.concatenate(data, axis=axes[0])
343+
344+
def _unwind(self, index, rest, shape, axis, axes):
345+
# print(' ' * axis, '====>', index, '+', rest)
346+
347+
if not isinstance(index, (int, slice, list, tuple)):
348+
try:
349+
# NumPy arrays, TensorFlow tensors, etc.
350+
index = tuple(index.tolist())
351+
assert not isinstance(index, bool), "Mask not supported"
352+
except AttributeError:
353+
pass
354+
355+
if isinstance(index, (list, tuple)):
356+
axes.append(axis) # Dimension of the concatenation
357+
for i in index:
358+
yield from self._unwind(i, rest, shape, axis, axes)
359+
return
360+
361+
if len(rest) == 0:
362+
yield (index,)
363+
return
364+
365+
for n in self._unwind(rest[0], rest[1:], shape, axis + 1, axes):
366+
yield (index,) + n
321367

322368
@cached_property
323369
def chunks(self):

0 commit comments

Comments
 (0)