Skip to content

Commit 3860c1c

Browse files
authored
Mlangguth/develop/issue 586 (ecmwf#625)
* Add options to configure the marker size, the marker type and enable marker-scaling with latitude for map-plots * Update doc-strings to follow standard format. * Ruffed code. * Changes due to review comments. * Less verbose logging and improved handling of setting to plot histograms. * Corrected error-message in plot_data.
1 parent 4c8e246 commit 3860c1c

File tree

2 files changed

+258
-57
lines changed

2 files changed

+258
-57
lines changed

packages/evaluate/src/weathergen/evaluate/plotter.py

Lines changed: 160 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ class Plotter:
2323
def __init__(self, cfg: dict, model_id: str = ""):
2424
"""
2525
Initialize the Plotter class.
26-
:param cfg: config from the yaml file
27-
:param model_id: if a model_id is given, the output will be saved in a folder called as the model_id
26+
27+
Parameters
28+
----------
29+
cfg:
30+
Configuration dictionary containing all information for the plotting.
31+
model_id:
32+
If a model_id is given, the output will be saved in a folder called as the model_id.
2833
"""
2934

3035
self.cfg = cfg
@@ -49,7 +54,14 @@ def __init__(self, cfg: dict, model_id: str = ""):
4954
def update_data_selection(self, select: dict):
5055
"""
5156
Set the selection for the plots. This will be used to filter the data for plotting.
52-
:param select: dictionary containing the selection parameters
57+
58+
Parameters
59+
----------
60+
select:
61+
Dictionary containing the selection criteria. Expected keys are:
62+
- "sample": Sample identifier
63+
- "stream": Stream identifier
64+
- "forecast_step": Forecast step identifier
5365
"""
5466
self.select = select
5567

@@ -78,9 +90,7 @@ def update_data_selection(self, select: dict):
7890

7991
def clean_data_selection(self):
8092
"""
81-
:param sample: sample name
82-
:param stream: stream name
83-
:param fstep: forecasting step
93+
Clean the data selection by resetting all selected values.
8494
"""
8595
self.sample = None
8696
self.stream = None
@@ -91,9 +101,17 @@ def clean_data_selection(self):
91101
def select_from_da(self, da: xr.DataArray, selection: dict) -> xr.DataArray:
92102
"""
93103
Select data from an xarray DataArray based on given selectors.
94-
:param da: xarray DataArray to select data from.
95-
:param selection: Dictionary of selectors where keys are coordinate names and values are the values to select.
96-
:return: xarray DataArray with selected data.
104+
105+
Parameters
106+
----------
107+
da:
108+
xarray DataArray to select data from.
109+
selection:
110+
Dictionary of selectors where keys are coordinate names and values are the values to select.
111+
112+
Returns
113+
-------
114+
xarray DataArray with selected data.
97115
"""
98116
for key, value in selection.items():
99117
if key in da.coords and key not in da.dims:
@@ -111,15 +129,26 @@ def histogram(
111129
variables: list,
112130
select: dict,
113131
tag: str = "",
114-
number: str = "",
115132
) -> list[str]:
116133
"""
117134
Plot histogram of target vs predictions for a set of variables.
118135
119-
:param target: target sample for a specific (stream, sample, fstep)
120-
:param preds: predictions sample for a specific (stream, sample, fstep)
121-
:param variables: list of variables to be plotted
122-
:param label: any tag you want to add to the plot
136+
Parameters
137+
----------
138+
target: xr.DataArray
139+
Target sample for a specific (stream, sample, fstep)
140+
preds: xr.DataArray
141+
Predictions sample for a specific (stream, sample, fstep)
142+
variables: list
143+
List of variables to be plotted
144+
select: dict
145+
Selection to be applied to the DataArray
146+
tag: str
147+
Any tag you want to add to the plot
148+
149+
Returns
150+
-------
151+
List of plot names for the saved histograms.
123152
"""
124153
plot_names = []
125154

@@ -166,16 +195,45 @@ def histogram(
166195
return plot_names
167196

168197
def map(
169-
self, data: xr.DataArray, variables: list, select: dict, tag: str = ""
198+
self,
199+
data: xr.DataArray,
200+
variables: list,
201+
select: dict,
202+
tag: str = "",
203+
map_kwargs: dict | None = None,
170204
) -> list[str]:
171205
"""
172206
Plot 2D map for a dataset
173207
174-
:param data: DataArray for a specific (stream, sample, fstep)
175-
:param variables: list of variables to be plotted
176-
:param label: any tag you want to add to the plot
177-
:param select: selection to be applied to the DataArray
208+
Parameters
209+
----------
210+
data: xr.DataArray
211+
DataArray for a specific (stream, sample, fstep)
212+
variables: list
213+
List of variables to be plotted
214+
label: str
215+
Any tag you want to add to the plot
216+
select: dict
217+
Selection to be applied to the DataArray
218+
tag: str
219+
Any tag you want to add to the plot
220+
map_kwargs: dict
221+
Additional keyword arguments for the map.
222+
Known keys are:
223+
- marker_size: base size of the marker (default is 1)
224+
- scale_marker_size: if True, the marker size will be scaled based on latitude (default is False)
225+
- marker: marker style (default is 'o')
226+
Unknown keys will be passed to the scatter plot function.
227+
228+
Returns
229+
-------
230+
List of plot names for the saved maps.
178231
"""
232+
map_kwargs_save = map_kwargs.copy() if map_kwargs is not None else {}
233+
# check for known keys in map_kwargs
234+
marker_size_base = map_kwargs_save.pop("marker_size", 1)
235+
scale_marker_size = map_kwargs_save.pop("scale_marker_size", False)
236+
marker = map_kwargs_save.pop("marker", "o")
179237

180238
self.update_data_selection(select)
181239

@@ -187,13 +245,20 @@ def map(
187245
ax.coastlines()
188246
da = self.select_from_da(data, select_var).compute()
189247

248+
marker_size = marker_size_base
249+
if scale_marker_size:
250+
marker_size = (marker_size + 1.0) * np.cos(np.radians(da["lat"]))
251+
190252
scatter_plt = ax.scatter(
191253
da["lon"],
192254
da["lat"],
193255
c=da,
194256
cmap="coolwarm",
195-
s=1,
257+
s=marker_size,
258+
marker=marker,
196259
transform=ccrs.PlateCarree(),
260+
linewidths=0.0, # only markers, avoids aliasing for very small markers
261+
**map_kwargs_save,
197262
)
198263
plt.colorbar(
199264
scatter_plt, ax=ax, orientation="horizontal", label=f"Variable: {var}"
@@ -215,7 +280,9 @@ def map(
215280
str(self.fstep).zfill(3),
216281
]
217282
name = "_".join(filter(None, parts))
218-
plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}")
283+
fname = f"{self.out_plot_dir.joinpath(name)}.{self.image_format}"
284+
_logger.debug(f"Saving map to {fname}")
285+
plt.savefig(fname)
219286
plt.close()
220287
plot_names.append(name)
221288

@@ -246,9 +313,17 @@ def _check_lengths(
246313
) -> tuple[list, list]:
247314
"""
248315
Check if the lengths of data and labels match.
249-
:param data: DataArray or list of DataArrays to be plotted
250-
:param labels: Label or list of labels for each dataset
251-
:return: data_list, label_list - lists of data and labels
316+
317+
Parameters
318+
----------
319+
data:
320+
DataArray or list of DataArrays to be plotted
321+
labels:
322+
Label or list of labels for each dataset
323+
324+
Returns
325+
-------
326+
data_list, label_list - lists of data and labels
252327
"""
253328
assert type(data) == xr.DataArray or type(data) == list, (
254329
"Compare::plot - Data should be of type xr.DataArray or list"
@@ -291,12 +366,21 @@ def plot(
291366
) -> None:
292367
"""
293368
Plot a line graph comparing multiple datasets.
294-
:param data: DataArray or list of DataArrays to be plotted
295-
:param labels: Label or list of labels for each dataset
296-
:param tag: Tag to be added to the plot title and filename
297-
:param x_dim: Dimension to be used for the x-axis. The code will average over all other dimensions. (default is "forecast_step")
298-
:param y_dim: Name of the dimension to be used for the y-axis (default is "value")
299-
:return: None
369+
370+
Parameters
371+
----------
372+
data:
373+
DataArray or list of DataArrays to be plotted
374+
labels:
375+
Label or list of labels for each dataset
376+
tag:
377+
Tag to be added to the plot title and filename
378+
x_dim:
379+
Dimension to be used for the x-axis. The code will average over all other dimensions.
380+
y_dim:
381+
Name of the dimension to be used for the y-axis.
382+
print_summary:
383+
If True, print a summary of the values from the graph.
300384
"""
301385

302386
data_list, label_list = self._check_lengths(data, labels)
@@ -345,3 +429,49 @@ def plot(
345429
name = "_".join(filter(None, parts))
346430
plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}")
347431
plt.close()
432+
433+
434+
class DefaultMarkerSize:
435+
"""
436+
Utility class for managing default configuration values, such as marker sizes
437+
for various data streams.
438+
"""
439+
440+
_marker_size_stream = {
441+
"era5": 2.5,
442+
"imerg": 0.25,
443+
"cerra": 0.1,
444+
}
445+
446+
_default_marker_size = 0.5
447+
448+
@classmethod
449+
def get_marker_size(cls, stream_name: str) -> float:
450+
"""
451+
Get the default marker size for a given stream name.
452+
453+
Parameters
454+
----------
455+
stream_name : str
456+
The name of the stream.
457+
458+
Returns
459+
-------
460+
float
461+
The default marker size for the stream.
462+
"""
463+
return cls._marker_size_stream.get(
464+
stream_name.lower(), cls._default_marker_size
465+
)
466+
467+
@classmethod
468+
def list_streams(cls):
469+
"""
470+
List all streams with defined marker sizes.
471+
472+
Returns
473+
-------
474+
list[str]
475+
List of stream names.
476+
"""
477+
return list(cls._marker_size_stream.keys())

0 commit comments

Comments
 (0)