Skip to content

Commit 0eea172

Browse files
authored
feat(output, plots): Use earthkit plots (ecmwf#293)
## Description Use `earthkit-plots` for the plotting output > [!WARNING] > Requires: ecmwf/earthkit-plots#81 ## What problem does this change solve? Reuse parts of the ECMWF stack for plotting ## Example Plots go from <img width="724" height="1231" alt="image" src="https://github.com/user-attachments/assets/4bac81df-1ad0-4fb5-b209-46bbeb45d69a" /> to <img width="653" height="715" alt="image" src="https://github.com/user-attachments/assets/c8320d23-ccef-466f-ada6-04fb6457bed1" /> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 9ee8b24 commit 0eea172

File tree

2 files changed

+49
-44
lines changed

2 files changed

+49
-44
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ optional-dependencies.docs = [
7676
]
7777

7878
optional-dependencies.huggingface = [ "huggingface-hub" ]
79-
optional-dependencies.plot = [ "cartopy", "matplotlib" ]
79+
optional-dependencies.plot = [ "earthkit-plots" ]
8080

8181
optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
8282
optional-dependencies.tests = [ "anemoi-datasets[all]", "anemoi-inference[all]", "hypothesis", "pytest", "pytest-mock" ]

src/anemoi/inference/outputs/plot.py

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import os
1212

1313
import numpy as np
14+
from anemoi.utils.grib import units
1415

1516
from anemoi.inference.context import Context
17+
from anemoi.inference.decorators import main_argument
1618
from anemoi.inference.types import FloatArray
1719
from anemoi.inference.types import ProcessorConfig
1820
from anemoi.inference.types import State
@@ -40,22 +42,26 @@ def fix(lons: FloatArray) -> FloatArray:
4042

4143

4244
@output_registry.register("plot")
45+
@main_argument("path")
4346
class PlotOutput(Output):
44-
"""Plot output class."""
47+
"""Use `earthkit-plots` to plot the outputs."""
4548

4649
def __init__(
4750
self,
4851
context: Context,
4952
path: str,
53+
*,
54+
variables: list[str] | None = None,
55+
mode: str = "subplots",
56+
domain: str | list[str] | None = None,
5057
strftime: str = "%Y%m%d%H%M%S",
51-
template: str = "plot_{variable}_{date}.{format}",
52-
dpi: int = 300,
58+
template: str = "plot_{date}.{format}",
5359
format: str = "png",
54-
variables: list[str] | None = None,
5560
missing_value: float | None = None,
5661
post_processors: list[ProcessorConfig] | None = None,
5762
output_frequency: int | None = None,
5863
write_initial_state: bool | None = None,
64+
**kwargs,
5965
) -> None:
6066
"""Initialize the PlotOutput.
6167
@@ -67,12 +73,14 @@ def __init__(
6773
The path to save the plots.
6874
variables : list, optional
6975
The list of variables to plot, by default all.
76+
mode : str, optional
77+
The plotting mode, can be "subplots" or "overlay", by default "subplots".
78+
domain : str | list[str] | None, optional
79+
The domain/s to plot, by default None.
7080
strftime : str, optional
7181
The date format string, by default "%Y%m%d%H%M%S".
7282
template : str, optional
73-
The template for plot filenames, by default "plot_{variable}_{date}.{format}".
74-
dpi : int, optional
75-
The resolution of the plot, by default 300.
83+
The template for plot filenames, by default "plot_{date}.{format}".
7684
format : str, optional
7785
The format of the plot, by default "png".
7886
missing_value : float, optional
@@ -97,8 +105,10 @@ def __init__(
97105
self.variables = variables
98106
self.strftime = strftime
99107
self.template = template
100-
self.dpi = dpi
101108
self.missing_value = missing_value
109+
self.domain = domain
110+
self.mode = mode
111+
self.kwargs = kwargs
102112

103113
def write_step(self, state: State) -> None:
104114
"""Write a step of the state.
@@ -108,50 +118,45 @@ def write_step(self, state: State) -> None:
108118
state : State
109119
The state dictionary.
110120
"""
111-
import cartopy.crs as ccrs
112-
import cartopy.feature as cfeature
113-
import matplotlib.pyplot as plt
114-
import matplotlib.tri as tri
121+
import earthkit.data as ekd
122+
import earthkit.plots as ekp
115123

116124
os.makedirs(self.path, exist_ok=True)
117125

118-
longitudes = state["longitudes"]
126+
longitudes = fix(state["longitudes"])
119127
latitudes = state["latitudes"]
120-
triangulation = tri.Triangulation(fix(longitudes), latitudes)
128+
date = state["date"]
129+
basetime = date - state["step"]
130+
131+
plotting_fields = []
121132

122133
for name, values in state["fields"].items():
123134
if self.skip_variable(name):
124135
continue
125136

126-
_, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
127-
ax.coastlines()
128-
ax.add_feature(cfeature.BORDERS, linestyle=":")
129-
130-
missing_values = np.isnan(values)
131-
missing_value = self.missing_value
132-
if missing_value is None:
133-
min = np.nanmin(values)
134-
missing_value = min - np.abs(min) * 0.001
135-
136-
values = np.where(missing_values, self.missing_value, values).astype(np.float32)
137-
138-
_ = ax.tricontourf(triangulation, values, levels=10, transform=ccrs.PlateCarree())
139-
140-
ax.tricontour(
141-
triangulation,
142-
values,
143-
levels=10,
144-
colors="black",
145-
linewidths=0.5,
146-
transform=ccrs.PlateCarree(),
137+
variable = self.context.checkpoint.typed_variables[name]
138+
param = variable.param
139+
140+
plotting_fields.append(
141+
ekd.ArrayField(
142+
values,
143+
{
144+
"shortName": param,
145+
"variable_name": param,
146+
"step": state["step"],
147+
"base_datetime": basetime,
148+
"latitudes": latitudes,
149+
"longitudes": longitudes,
150+
"units": units(param),
151+
},
152+
)
147153
)
148154

149-
date = state["date"].strftime("%Y-%m-%d %H:%M:%S")
150-
ax.set_title(f"{name} at {date}")
151-
152-
date = state["date"].strftime(self.strftime)
153-
fname = self.template.format(date=date, variable=name, format=self.format)
154-
fname = os.path.join(self.path, fname)
155+
fig = ekp.quickplot(
156+
ekd.FieldList.from_fields((plotting_fields)), mode=self.mode, domain=self.domain, **self.kwargs
157+
)
158+
fname = self.template.format(date=date, format=self.format)
159+
fname = os.path.join(self.path, fname)
155160

156-
plt.savefig(fname, dpi=self.dpi, bbox_inches="tight")
157-
plt.close()
161+
fig.save(fname)
162+
del fig

0 commit comments

Comments
 (0)