@@ -23,8 +23,13 @@ class Plotter:
2323 def __init__ (self , cfg : dict , model_id : str = "" ):
2424 """
2525 Initialize the Plotter class.
26- :param cfg: config from the yaml file
27- :param model_id: if a model_id is given, the output will be saved in a folder called as the model_id
26+
27+ Parameters
28+ ----------
29+ cfg:
30+ Configuration dictionary containing all information for the plotting.
31+ model_id:
32+ If a model_id is given, the output will be saved in a folder called as the model_id.
2833 """
2934
3035 self .cfg = cfg
@@ -49,7 +54,14 @@ def __init__(self, cfg: dict, model_id: str = ""):
4954 def update_data_selection (self , select : dict ):
5055 """
5156 Set the selection for the plots. This will be used to filter the data for plotting.
52- :param select: dictionary containing the selection parameters
57+
58+ Parameters
59+ ----------
60+ select:
61+ Dictionary containing the selection criteria. Expected keys are:
62+ - "sample": Sample identifier
63+ - "stream": Stream identifier
64+ - "forecast_step": Forecast step identifier
5365 """
5466 self .select = select
5567
@@ -78,9 +90,7 @@ def update_data_selection(self, select: dict):
7890
7991 def clean_data_selection (self ):
8092 """
81- :param sample: sample name
82- :param stream: stream name
83- :param fstep: forecasting step
93+ Clean the data selection by resetting all selected values.
8494 """
8595 self .sample = None
8696 self .stream = None
@@ -91,9 +101,17 @@ def clean_data_selection(self):
91101 def select_from_da (self , da : xr .DataArray , selection : dict ) -> xr .DataArray :
92102 """
93103 Select data from an xarray DataArray based on given selectors.
94- :param da: xarray DataArray to select data from.
95- :param selection: Dictionary of selectors where keys are coordinate names and values are the values to select.
96- :return: xarray DataArray with selected data.
104+
105+ Parameters
106+ ----------
107+ da:
108+ xarray DataArray to select data from.
109+ selection:
110+ Dictionary of selectors where keys are coordinate names and values are the values to select.
111+
112+ Returns
113+ -------
114+ xarray DataArray with selected data.
97115 """
98116 for key , value in selection .items ():
99117 if key in da .coords and key not in da .dims :
@@ -111,15 +129,26 @@ def histogram(
111129 variables : list ,
112130 select : dict ,
113131 tag : str = "" ,
114- number : str = "" ,
115132 ) -> list [str ]:
116133 """
117134 Plot histogram of target vs predictions for a set of variables.
118135
119- :param target: target sample for a specific (stream, sample, fstep)
120- :param preds: predictions sample for a specific (stream, sample, fstep)
121- :param variables: list of variables to be plotted
122- :param label: any tag you want to add to the plot
136+ Parameters
137+ ----------
138+ target: xr.DataArray
139+ Target sample for a specific (stream, sample, fstep)
140+ preds: xr.DataArray
141+ Predictions sample for a specific (stream, sample, fstep)
142+ variables: list
143+ List of variables to be plotted
144+ select: dict
145+ Selection to be applied to the DataArray
146+ tag: str
147+ Any tag you want to add to the plot
148+
149+ Returns
150+ -------
151+ List of plot names for the saved histograms.
123152 """
124153 plot_names = []
125154
@@ -166,16 +195,45 @@ def histogram(
166195 return plot_names
167196
168197 def map (
169- self , data : xr .DataArray , variables : list , select : dict , tag : str = ""
198+ self ,
199+ data : xr .DataArray ,
200+ variables : list ,
201+ select : dict ,
202+ tag : str = "" ,
203+ map_kwargs : dict | None = None ,
170204 ) -> list [str ]:
171205 """
172206 Plot 2D map for a dataset
173207
174- :param data: DataArray for a specific (stream, sample, fstep)
175- :param variables: list of variables to be plotted
176- :param label: any tag you want to add to the plot
177- :param select: selection to be applied to the DataArray
208+ Parameters
209+ ----------
210+ data: xr.DataArray
211+ DataArray for a specific (stream, sample, fstep)
212+ variables: list
213+ List of variables to be plotted
214+ label: str
215+ Any tag you want to add to the plot
216+ select: dict
217+ Selection to be applied to the DataArray
218+ tag: str
219+ Any tag you want to add to the plot
220+ map_kwargs: dict
221+ Additional keyword arguments for the map.
222+ Known keys are:
223+ - marker_size: base size of the marker (default is 1)
224+ - scale_marker_size: if True, the marker size will be scaled based on latitude (default is False)
225+ - marker: marker style (default is 'o')
226+ Unknown keys will be passed to the scatter plot function.
227+
228+ Returns
229+ -------
230+ List of plot names for the saved maps.
178231 """
232+ map_kwargs_save = map_kwargs .copy () if map_kwargs is not None else {}
233+ # check for known keys in map_kwargs
234+ marker_size_base = map_kwargs_save .pop ("marker_size" , 1 )
235+ scale_marker_size = map_kwargs_save .pop ("scale_marker_size" , False )
236+ marker = map_kwargs_save .pop ("marker" , "o" )
179237
180238 self .update_data_selection (select )
181239
@@ -187,13 +245,20 @@ def map(
187245 ax .coastlines ()
188246 da = self .select_from_da (data , select_var ).compute ()
189247
248+ marker_size = marker_size_base
249+ if scale_marker_size :
250+ marker_size = (marker_size + 1.0 ) * np .cos (np .radians (da ["lat" ]))
251+
190252 scatter_plt = ax .scatter (
191253 da ["lon" ],
192254 da ["lat" ],
193255 c = da ,
194256 cmap = "coolwarm" ,
195- s = 1 ,
257+ s = marker_size ,
258+ marker = marker ,
196259 transform = ccrs .PlateCarree (),
260+ linewidths = 0.0 , # only markers, avoids aliasing for very small markers
261+ ** map_kwargs_save ,
197262 )
198263 plt .colorbar (
199264 scatter_plt , ax = ax , orientation = "horizontal" , label = f"Variable: { var } "
@@ -215,7 +280,9 @@ def map(
215280 str (self .fstep ).zfill (3 ),
216281 ]
217282 name = "_" .join (filter (None , parts ))
218- plt .savefig (f"{ self .out_plot_dir .joinpath (name )} .{ self .image_format } " )
283+ fname = f"{ self .out_plot_dir .joinpath (name )} .{ self .image_format } "
284+ _logger .debug (f"Saving map to { fname } " )
285+ plt .savefig (fname )
219286 plt .close ()
220287 plot_names .append (name )
221288
@@ -246,9 +313,17 @@ def _check_lengths(
246313 ) -> tuple [list , list ]:
247314 """
248315 Check if the lengths of data and labels match.
249- :param data: DataArray or list of DataArrays to be plotted
250- :param labels: Label or list of labels for each dataset
251- :return: data_list, label_list - lists of data and labels
316+
317+ Parameters
318+ ----------
319+ data:
320+ DataArray or list of DataArrays to be plotted
321+ labels:
322+ Label or list of labels for each dataset
323+
324+ Returns
325+ -------
326+ data_list, label_list - lists of data and labels
252327 """
253328 assert type (data ) == xr .DataArray or type (data ) == list , (
254329 "Compare::plot - Data should be of type xr.DataArray or list"
@@ -291,12 +366,21 @@ def plot(
291366 ) -> None :
292367 """
293368 Plot a line graph comparing multiple datasets.
294- :param data: DataArray or list of DataArrays to be plotted
295- :param labels: Label or list of labels for each dataset
296- :param tag: Tag to be added to the plot title and filename
297- :param x_dim: Dimension to be used for the x-axis. The code will average over all other dimensions. (default is "forecast_step")
298- :param y_dim: Name of the dimension to be used for the y-axis (default is "value")
299- :return: None
369+
370+ Parameters
371+ ----------
372+ data:
373+ DataArray or list of DataArrays to be plotted
374+ labels:
375+ Label or list of labels for each dataset
376+ tag:
377+ Tag to be added to the plot title and filename
378+ x_dim:
379+ Dimension to be used for the x-axis. The code will average over all other dimensions.
380+ y_dim:
381+ Name of the dimension to be used for the y-axis.
382+ print_summary:
383+ If True, print a summary of the values from the graph.
300384 """
301385
302386 data_list , label_list = self ._check_lengths (data , labels )
@@ -345,3 +429,49 @@ def plot(
345429 name = "_" .join (filter (None , parts ))
346430 plt .savefig (f"{ self .out_plot_dir .joinpath (name )} .{ self .image_format } " )
347431 plt .close ()
432+
433+
434+ class DefaultMarkerSize :
435+ """
436+ Utility class for managing default configuration values, such as marker sizes
437+ for various data streams.
438+ """
439+
440+ _marker_size_stream = {
441+ "era5" : 2.5 ,
442+ "imerg" : 0.25 ,
443+ "cerra" : 0.1 ,
444+ }
445+
446+ _default_marker_size = 0.5
447+
448+ @classmethod
449+ def get_marker_size (cls , stream_name : str ) -> float :
450+ """
451+ Get the default marker size for a given stream name.
452+
453+ Parameters
454+ ----------
455+ stream_name : str
456+ The name of the stream.
457+
458+ Returns
459+ -------
460+ float
461+ The default marker size for the stream.
462+ """
463+ return cls ._marker_size_stream .get (
464+ stream_name .lower (), cls ._default_marker_size
465+ )
466+
467+ @classmethod
468+ def list_streams (cls ):
469+ """
470+ List all streams with defined marker sizes.
471+
472+ Returns
473+ -------
474+ list[str]
475+ List of stream names.
476+ """
477+ return list (cls ._marker_size_stream .keys ())
0 commit comments