Skip to content

Commit 0f0975d

Browse files
authored
Fixed bug in obs data reading (ecmwf#698)
* Restored old prediction had functionally. Other adjustments/reverts, in particular in attention. * Ruff'ed * Fixed bug in obs data reading so that data violated window * Fix * Update data_reader_obs.py * Restoring to develop * Fix * Ruffed
1 parent 97176bd commit 0f0975d

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

src/weathergen/datasets/data_reader_base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,9 @@ def check_reader_data(rdata: ReaderData, dtr: DTRange) -> None:
254254
f"{rdata.datetimes.shape[0]}"
255255
)
256256

257-
assert np.logical_and(
258-
rdata.datetimes >= dtr.start,
259-
# rdata.datetimes < dtr.end # TODO: enforce monotonicty also for obs
260-
rdata.datetimes <= dtr.end,
261-
).all(), f"datetimes for data points violate window {dtr}."
257+
assert np.logical_and(rdata.datetimes >= dtr.start, rdata.datetimes < dtr.end).all(), (
258+
f"datetimes for data points violate window {dtr}."
259+
)
262260

263261

264262
class DataReaderBase(metaclass=ABCMeta):

src/weathergen/datasets/data_reader_obs.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _get(self, idx: int, channels_idx: list[int]) -> ReaderData:
225225
num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
226226
)
227227

228-
start_row = self.indices_start[idx]
228+
start_row = self.indices_start[idx - 1]
229229
end_row = self.indices_end[idx]
230230

231231
coords = self.data.oindex[start_row:end_row, self.coords_idx]
@@ -238,11 +238,17 @@ def _get(self, idx: int, channels_idx: list[int]) -> ReaderData:
238238
data = self.data.oindex[start_row:end_row, channels_idx]
239239
datetimes = self.dt[start_row:end_row][:, 0]
240240

241+
# indices_start, indices_end above work with [t_start, t_end] and violate
242+
# our convention [t_start, t_end) where endpoint is excluded
243+
# compute mask to enforce it
244+
t_win = self.time_window_handler.window(idx)
245+
t_mask = np.logical_and(datetimes >= t_win.start, datetimes < t_win.end)
246+
241247
rdata = ReaderData(
242-
coords=coords,
243-
geoinfos=geoinfos,
244-
data=data,
245-
datetimes=datetimes,
248+
coords=coords[t_mask],
249+
geoinfos=geoinfos[t_mask],
250+
data=data[t_mask],
251+
datetimes=datetimes[t_mask],
246252
)
247253

248254
dtr = self.time_window_handler.window(idx)

0 commit comments

Comments
 (0)