Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit a69d94b

Browse files
committed
Work on indexing
1 parent ca53f34 commit a69d94b

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

ecml_tools/data.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ def metadata_specific(self, **kwargs):
165165
def __repr__(self):
    """Return the class name followed by an empty pair of parentheses."""
    class_name = self.__class__.__name__
    return f"{class_name}()"
167167

168+
def _get_tuple(self, n):
    """Resolve a tuple (multi-axis) index.

    The first element of ``n`` is resolved through ``self[...]`` so that the
    object's own integer/slice handling is used; the remaining elements are
    then applied to that result.

    Args:
        n: Non-empty tuple of indices, one per axis, starting at the first.

    Returns:
        Whatever indexing the intermediate result with the remaining
        elements of ``n`` returns.

    Raises:
        IndexError: If ``n`` is an empty tuple.
    """
    first, rest = n[0], n[1:]
    selected = self[first]
    if isinstance(first, slice):
        # A slice keeps the leading axis in ``selected``. Skip over it so
        # that ``rest`` still addresses the axes that follow the first one
        # (otherwise ``obj[0:10, :, 0]`` would index the second axis with 0).
        rest = (slice(None),) + rest
    return selected[rest]
171+
168172

169173
class Source:
170174
def __init__(self, dataset, index, source=None, info=None):
@@ -287,13 +291,17 @@ def __len__(self):
287291
def __getitem__(self, n):
    """Return whatever ``self.data`` yields when indexed with ``n``."""
    store = self.data
    return store[n]
289293

294+
@cached_property
def chunks(self):
    """Chunk layout reported by the ``data`` member of ``self.z``."""
    array = self.z.data
    return array.chunks
297+
290298
@cached_property
def shape(self):
    """Shape reported by ``self.data``."""
    array = self.data
    return array.shape
293301

294302
@cached_property
295303
def dtype(self):
296-
return self.data.dtype
304+
return self.z.data.dtype
297305

298306
@cached_property
299307
def dates(self):
@@ -361,7 +369,11 @@ def end_of_statistics_date(self):
361369
return self.dates[-1]
362370

363371
def metadata_specific(self):
364-
return super().metadata_specific(attrs=dict(self.z.attrs))
372+
return super().metadata_specific(
373+
attrs=dict(self.z.attrs),
374+
chunks=self.chunks,
375+
dtype=str(self.dtype),
376+
)
365377

366378
def source(self, index):
    """Build a ``Source`` for ``index``, passing ``self.path`` as its info."""
    origin = self.path
    return Source(self, index, info=origin)
@@ -528,6 +540,9 @@ def __len__(self):
528540
return sum(len(i) for i in self.datasets)
529541

530542
def __getitem__(self, n):
543+
if isinstance(n, tuple):
544+
return self._get_tuple(n)
545+
531546
if isinstance(n, slice):
532547
return self._get_slice(n)
533548

@@ -609,8 +624,12 @@ def _get_slice(self, s):
609624
return np.stack([self[i] for i in range(*s.indices(self._len))])
610625

611626
def __getitem__(self, n):
    """Index the combined datasets.

    Tuple and slice indices are delegated to ``_get_tuple`` and
    ``_get_slice``. Any other index is applied to every member of
    ``self.datasets`` and the pieces are concatenated along
    ``self.axis - 1``.
    """
    if isinstance(n, tuple):
        result = self._get_tuple(n)
    elif isinstance(n, slice):
        result = self._get_slice(n)
    else:
        pieces = [member[n] for member in self.datasets]
        # NOTE(review): presumably the "- 1" accounts for the leading axis
        # consumed by indexing each member with ``n`` - confirm.
        result = np.concatenate(pieces, axis=self.axis - 1)
    return result
615634

616635

@@ -658,8 +677,12 @@ def _get_slice(self, s):
658677
return np.stack([self[i] for i in range(*s.indices(self._len))])
659678

660679
def __getitem__(self, n):
    """Index the combined datasets.

    Tuple and slice indices are delegated to ``_get_tuple`` and
    ``_get_slice``. Any other index is applied to every member of
    ``self.datasets`` and the pieces are joined with ``np.concatenate``
    (default axis).
    """
    if isinstance(n, tuple):
        return self._get_tuple(n)

    if not isinstance(n, slice):
        pieces = [member[n] for member in self.datasets]
        return np.concatenate(pieces)

    return self._get_slice(n)
664687

665688
@cached_property
@@ -743,8 +766,12 @@ def __init__(self, dataset, indices):
743766
super().__init__(dataset)
744767

745768
def __getitem__(self, n):
    """Index through the ``self.indices`` remapping.

    Tuple and slice indices are delegated to ``_get_tuple`` and
    ``_get_slice``. Any other index is first translated via
    ``self.indices`` and the translated value is used to index
    ``self.dataset``.
    """
    if isinstance(n, tuple):
        return self._get_tuple(n)
    if isinstance(n, slice):
        return self._get_slice(n)

    target = self.indices[n]
    wrapped = self.dataset
    return wrapped[target]
750777

@@ -794,9 +821,13 @@ def __init__(self, dataset, indices):
794821
super().__init__(dataset)
795822

796823
def __getitem__(self, n):
    """Index the wrapped dataset and keep only the entries in ``self.indices``.

    Tuple indices are delegated to ``_get_tuple``. Otherwise ``n`` is
    applied to ``self.dataset``; ``self.indices`` then selects along the
    second axis of the result when ``n`` is a slice, and along the first
    axis otherwise.
    """
    if isinstance(n, tuple):
        return self._get_tuple(n)

    row = self.dataset[n]
    # A slice keeps the leading axis, so the selection moves one axis over.
    selector = (slice(None), self.indices) if isinstance(n, slice) else self.indices
    return row[selector]
801832

802833
@cached_property

tests/test_data.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,54 @@ def create_zarr(
6666
ensembles = ensemble if ensemble is not None else 1
6767
values = grids if grids is not None else VALUES
6868

69-
data = np.zeros((len(dates), len(vars), ensembles, values))
69+
data = np.zeros(shape=(len(dates), len(vars), ensembles, values))
7070

7171
for i, date in enumerate(dates):
7272
for j, var in enumerate(vars):
7373
for e in range(ensembles):
7474
data[i, j, e] = _(date.astype(object), var, k, e, values)
7575

76-
root.data = data
77-
root.dates = dates
78-
root.latitudes = np.array([x + values for x in range(values)])
79-
root.longitudes = np.array([x + values for x in range(values)])
76+
root.create_dataset(
77+
"data",
78+
data=data,
79+
dtype=data.dtype,
80+
chunks=data.shape,
81+
)
82+
root.create_dataset(
83+
"dates",
84+
data=dates,
85+
)
86+
root.create_dataset(
87+
"latitudes",
88+
data=np.array([x + values for x in range(values)]),
89+
)
90+
root.create_dataset(
91+
"longitudes",
92+
data=np.array([x + values for x in range(values)]),
93+
)
8094

8195
root.attrs["frequency"] = frequency
8296
root.attrs["resolution"] = resolution
8397
root.attrs["name_to_index"] = {k: i for i, k in enumerate(vars)}
8498

8599
root.attrs["data_request"] = {"grid": 1, "area": "g", "param_level": {}}
86100

87-
root.mean = np.mean(data, axis=0)
88-
root.stdev = np.std(data, axis=0)
89-
root.maximum = np.max(data, axis=0)
90-
root.minimum = np.min(data, axis=0)
101+
root.create_dataset(
102+
"mean",
103+
data=np.mean(data, axis=0),
104+
)
105+
root.create_dataset(
106+
"stdev",
107+
data=np.std(data, axis=0),
108+
)
109+
root.create_dataset(
110+
"maximum",
111+
data=np.max(data, axis=0),
112+
)
113+
root.create_dataset(
114+
"minimum",
115+
data=np.min(data, axis=0),
116+
)
91117

92118
return root
93119

@@ -167,6 +193,11 @@ def slices(ds, start=None, end=None, step=None):
167193
for i, n in enumerate(range(start, end, step)):
168194
assert (s[i] == ds[n]).all()
169195

196+
x = ds[0:10, :, 0]
197+
198+
if ds.shape[2] > 1:
199+
ds[0:10, :, np.array([1, 0])]
200+
170201

171202
def make_row(args, ensemble=False, grid=False):
172203
if grid:

0 commit comments

Comments
 (0)