Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit ba7b650

Browse files
committed
Indexing test
1 parent af5a27c commit ba7b650

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

ecml_tools/data.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,7 @@ def __repr__(self):
191191
return self.__class__.__name__ + "()"
192192

193193
def _get_tuple(self, n):
194-
warnings.warn(f"Naive tuple indexing used with {self}, likely to be slow.")
195-
first, rest = n[0], n[1:]
196-
return self[first][rest]
194+
raise NotImplementedError(f"Tuple not supported: {n} (class {self.__class__.__name__})")
197195

198196

199197
class Source:
@@ -932,11 +930,11 @@ def __init__(self, dataset, indices):
932930
super().__init__(dataset)
933931

934932
def __getitem__(self, n):
935-
if isinstance(n, tuple):
936-
return self._get_tuple(n)
933+
# if isinstance(n, tuple):
934+
# return self._get_tuple(n)
937935

938936
row = self.dataset[n]
939-
if isinstance(n, slice):
937+
if isinstance(n, (slice, tuple)):
940938
return row[:, self.indices]
941939

942940
return row[self.indices]

tests/test_data.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,15 @@ def same_stats(ds1, ds2, vars1, vars2=None):
173173
).all()
174174

175175

176+
class IndexTester:
177+
def __init__(self, ds):
178+
self.ds = ds
179+
self.np = ds[:] # Numpy array
180+
181+
def __getitem__(self, index):
182+
assert (self.ds[index] == self.np[index]).all()
183+
184+
176185
def slices(ds, start=None, end=None, step=None):
177186
if start is None:
178187
start = 5
@@ -181,22 +190,18 @@ def slices(ds, start=None, end=None, step=None):
181190
if step is None:
182191
step = len(ds) // 10
183192

184-
s = ds[start:end:step]
185-
186-
assert s[0].shape == ds[0].shape, (
187-
s.shape,
188-
ds.shape,
189-
len(list(range(start, end, step))),
190-
list(range(start, end, step)),
191-
)
193+
t = IndexTester(ds)
192194

193-
for i, n in enumerate(range(start, end, step)):
194-
assert (s[i] == ds[n]).all()
195+
t[start:end:step]
196+
t[start:end]
197+
t[start:]
198+
t[:end]
199+
t[::step]
195200

196-
ds[0:10, :, 0]
201+
t[0:10, :, 0]
197202

198-
if ds.shape[2] > 1:
199-
ds[0:10, :, np.array([1, 0])]
203+
if ds.shape[2] > 1: # Ensemble dimension
204+
t[0:10, :, (0, 1)]
200205

201206

202207
def make_row(args, ensemble=False, grid=False):

0 commit comments

Comments
 (0)