Skip to content

Commit 295a27b

Browse files
iluiseclessig
andauthored
Iluise/fix empty io 819 plotting (ecmwf#826)
* Fix to io problems. * Fix issues in input * fix plotting * ruffed --------- Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
1 parent 8dee3e5 commit 295a27b

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

packages/common/src/weathergen/common/io.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,11 @@ def _offset_key(self, key: ItemKey):
436436

437437
def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordinates:
438438
_coords = self.targets_coords[offset_key.forecast_step][stream_idx][datapoints].numpy()
439+
439440
# ensure _coords has size (?,2)
440441
if len(_coords) == 0:
441442
_coords = np.zeros((0, 2), dtype=np.float32)
443+
442444
coords = _coords[..., :2] # first two columns are lat,lon
443445
geoinfo = _coords[..., 2:] # the rest is geoinfo => potentially empty
444446
if geoinfo.size > 0: # TODO: set geoinfo to be empty for now

packages/evaluate/src/weathergen/evaluate/plotter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ def create_histograms_per_sample(
212212

213213
for (valid_time, targ_t), (_, prd_t) in groups:
214214
if valid_time is not None:
215-
_logger.debug(f"Plotting map for {var} at valid_time {valid_time}")
215+
_logger.debug(
216+
f"Plotting histogram for {var} at valid_time {valid_time}"
217+
)
216218

217219
name = self.plot_histogram(targ_t, prd_t, hist_output_dir, var, tag=tag)
218220
plot_names.append(name)
@@ -341,7 +343,6 @@ def create_maps_per_sample(
341343
self.stream
342344
)
343345

344-
345346
# Basic map output directory for this stream
346347
map_output_dir = self.get_map_output_dir(tag)
347348

packages/evaluate/src/weathergen/evaluate/utils.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def get_data(
116116
else:
117117
points_per_sample = None
118118

119+
fsteps_final = fsteps
120+
119121
for fstep in fsteps:
120122
_logger.info(f"RUN {run_id} - {stream}: Processing fstep {fstep}...")
121123
da_tars_fs, da_preds_fs = [], []
@@ -135,6 +137,12 @@ def get_data(
135137
pred = bbox.apply_mask(pred)
136138

137139
npoints = len(target.ipoint)
140+
if npoints == 0:
141+
_logger.info(
142+
f"Skipping {stream} sample {sample} forecast step: {fstep}. Dataset is empty."
143+
)
144+
fsteps_final.remove(fstep)
145+
continue
138146

139147
da_tars_fs.append(target.squeeze())
140148
da_preds_fs.append(pred.squeeze())
@@ -143,31 +151,37 @@ def get_data(
143151
_logger.debug(
144152
f"Concatenating targets and predictions for stream {stream}, forecast_step {fstep}..."
145153
)
146-
da_tars_fs = xr.concat(da_tars_fs, dim="ipoint")
147-
da_preds_fs = xr.concat(da_preds_fs, dim="ipoint")
148154

149-
if set(channels) != set(all_channels):
150-
_logger.debug(
151-
f"Restricting targets and predictions to channels {channels} for stream {stream}..."
152-
)
153-
available_channels = da_tars_fs.channel.values
154-
existing_channels = [ch for ch in channels if ch in available_channels]
155-
if len(existing_channels) < len(channels):
156-
_logger.warning(
157-
f"The following channels were not found: {list(set(channels) - set(existing_channels))}. Skipping them."
155+
if da_tars_fs:
156+
da_tars_fs = xr.concat(da_tars_fs, dim="ipoint")
157+
da_preds_fs = xr.concat(da_preds_fs, dim="ipoint")
158+
159+
if set(channels) != set(all_channels):
160+
_logger.debug(
161+
f"Restricting targets and predictions to channels {channels} for stream {stream}..."
158162
)
163+
available_channels = da_tars_fs.channel.values
164+
existing_channels = [
165+
ch for ch in channels if ch in available_channels
166+
]
167+
if len(existing_channels) < len(channels):
168+
_logger.warning(
169+
f"The following channels were not found: {list(set(channels) - set(existing_channels))}. Skipping them."
170+
)
159171

160-
da_tars_fs = da_tars_fs.sel(channel=existing_channels)
161-
da_preds_fs = da_preds_fs.sel(channel=existing_channels)
172+
da_tars_fs = da_tars_fs.sel(channel=existing_channels)
173+
da_preds_fs = da_preds_fs.sel(channel=existing_channels)
162174

163175
da_tars.append(da_tars_fs)
164176
da_preds.append(da_preds_fs)
165177
if return_counts:
166178
points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps)
167179

168180
# Safer than a list
169-
da_tars = {fstep: da for fstep, da in zip(fsteps, da_tars, strict=False)}
170-
da_preds = {fstep: da for fstep, da in zip(fsteps, da_preds, strict=False)}
181+
da_tars = {fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=False)}
182+
da_preds = {
183+
fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=False)
184+
}
171185

172186
return WeatherGeneratorOutput(
173187
target=da_tars, prediction=da_preds, points_per_sample=points_per_sample
@@ -396,6 +410,10 @@ def plot_data(
396410
da_tars = model_output.target
397411
da_preds = model_output.prediction
398412

413+
if not da_tars:
414+
_logger.info(f"Skipping Plot Data for {stream}. Targets are empty.")
415+
return
416+
399417
maps_config = common_ranges(da_tars, da_preds, plot_chs, maps_config)
400418

401419
plot_names = []

0 commit comments

Comments
 (0)