1111import os
1212
1313import numpy as np
14+ from anemoi .utils .grib import units
1415
1516from anemoi .inference .context import Context
17+ from anemoi .inference .decorators import main_argument
1618from anemoi .inference .types import FloatArray
1719from anemoi .inference .types import ProcessorConfig
1820from anemoi .inference .types import State
@@ -40,22 +42,26 @@ def fix(lons: FloatArray) -> FloatArray:
4042
4143
4244@output_registry .register ("plot" )
45+ @main_argument ("path" )
4346class PlotOutput (Output ):
44- """Plot output class ."""
47+ """Use `earthkit-plots` to plot the outputs ."""
4548
4649 def __init__ (
4750 self ,
4851 context : Context ,
4952 path : str ,
53+ * ,
54+ variables : list [str ] | None = None ,
55+ mode : str = "subplots" ,
56+ domain : str | list [str ] | None = None ,
5057 strftime : str = "%Y%m%d%H%M%S" ,
51- template : str = "plot_{variable}_{date}.{format}" ,
52- dpi : int = 300 ,
58+ template : str = "plot_{date}.{format}" ,
5359 format : str = "png" ,
54- variables : list [str ] | None = None ,
5560 missing_value : float | None = None ,
5661 post_processors : list [ProcessorConfig ] | None = None ,
5762 output_frequency : int | None = None ,
5863 write_initial_state : bool | None = None ,
64+ ** kwargs ,
5965 ) -> None :
6066 """Initialize the PlotOutput.
6167
@@ -67,12 +73,14 @@ def __init__(
6773 The path to save the plots.
6874 variables : list, optional
6975 The list of variables to plot, by default all.
76+ mode : str, optional
77+ The plotting mode, can be "subplots" or "overlay", by default "subplots".
78+ domain : str | list[str] | None, optional
79+ The domain/s to plot, by default None.
7080 strftime : str, optional
7181 The date format string, by default "%Y%m%d%H%M%S".
7282 template : str, optional
73- The template for plot filenames, by default "plot_{variable}_{date}.{format}".
74- dpi : int, optional
75- The resolution of the plot, by default 300.
83+ The template for plot filenames, by default "plot_{date}.{format}".
7684 format : str, optional
7785 The format of the plot, by default "png".
7886 missing_value : float, optional
@@ -97,8 +105,10 @@ def __init__(
97105 self .variables = variables
98106 self .strftime = strftime
99107 self .template = template
100- self .dpi = dpi
101108 self .missing_value = missing_value
109+ self .domain = domain
110+ self .mode = mode
111+ self .kwargs = kwargs
102112
103113 def write_step (self , state : State ) -> None :
104114 """Write a step of the state.
@@ -108,50 +118,45 @@ def write_step(self, state: State) -> None:
108118 state : State
109119 The state dictionary.
110120 """
111- import cartopy .crs as ccrs
112- import cartopy .feature as cfeature
113- import matplotlib .pyplot as plt
114- import matplotlib .tri as tri
121+ import earthkit .data as ekd
122+ import earthkit .plots as ekp
115123
116124 os .makedirs (self .path , exist_ok = True )
117125
118- longitudes = state ["longitudes" ]
126+ longitudes = fix ( state ["longitudes" ])
119127 latitudes = state ["latitudes" ]
120- triangulation = tri .Triangulation (fix (longitudes ), latitudes )
128+ date = state ["date" ]
129+ basetime = date - state ["step" ]
130+
131+ plotting_fields = []
121132
122133 for name , values in state ["fields" ].items ():
123134 if self .skip_variable (name ):
124135 continue
125136
126- _ , ax = plt .subplots (subplot_kw = {"projection" : ccrs .PlateCarree ()})
127- ax .coastlines ()
128- ax .add_feature (cfeature .BORDERS , linestyle = ":" )
129-
130- missing_values = np .isnan (values )
131- missing_value = self .missing_value
132- if missing_value is None :
133- min = np .nanmin (values )
134- missing_value = min - np .abs (min ) * 0.001
135-
136- values = np .where (missing_values , self .missing_value , values ).astype (np .float32 )
137-
138- _ = ax .tricontourf (triangulation , values , levels = 10 , transform = ccrs .PlateCarree ())
139-
140- ax .tricontour (
141- triangulation ,
142- values ,
143- levels = 10 ,
144- colors = "black" ,
145- linewidths = 0.5 ,
146- transform = ccrs .PlateCarree (),
137+ variable = self .context .checkpoint .typed_variables [name ]
138+ param = variable .param
139+
140+ plotting_fields .append (
141+ ekd .ArrayField (
142+ values ,
143+ {
144+ "shortName" : param ,
145+ "variable_name" : param ,
146+ "step" : state ["step" ],
147+ "base_datetime" : basetime ,
148+ "latitudes" : latitudes ,
149+ "longitudes" : longitudes ,
150+ "units" : units (param ),
151+ },
152+ )
147153 )
148154
149- date = state ["date" ].strftime ("%Y-%m-%d %H:%M:%S" )
150- ax .set_title (f"{ name } at { date } " )
151-
152- date = state ["date" ].strftime (self .strftime )
153- fname = self .template .format (date = date , variable = name , format = self .format )
154- fname = os .path .join (self .path , fname )
155+ fig = ekp .quickplot (
156+ ekd .FieldList .from_fields ((plotting_fields )), mode = self .mode , domain = self .domain , ** self .kwargs
157+ )
158+ fname = self .template .format (date = date , format = self .format )
159+ fname = os .path .join (self .path , fname )
155160
156- plt . savefig (fname , dpi = self . dpi , bbox_inches = "tight" )
157- plt . close ()
161+ fig . save (fname )
162+ del fig
0 commit comments