Skip to content

Commit ecc8e15

Browse files
Add plots for tp
1 parent a5851c3 commit ecc8e15

File tree

6 files changed

+136
-32
lines changed

6 files changed

+136
-32
lines changed

src/plotting/__init__.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import contextmanager
12
from functools import cached_property
23
from pathlib import Path
34

@@ -19,7 +20,7 @@
1920

2021
# Mapping of region names to their geographic extent and projection
2122
# extent [lon_min, lon_max, lat_min, lat_max] in PlateCarree coordinates
22-
REGIONS = {
23+
DOMAINS = {
2324
"globe": {
2425
"extent": None, # full globe view
2526
"projection": _PROJECTIONS["orthographic"],
@@ -40,7 +41,7 @@
4041

4142

4243
class StatePlotter:
43-
"""A class to plot state fields on various REGIONS."""
44+
"""A class to plot state fields on various DOMAINS."""
4445

4546
def __init__(
4647
self,
@@ -168,7 +169,6 @@ def plot_field(
168169
field = field[mask]
169170
field = field[-1] if field.ndim == 2 else field.squeeze()
170171
finite = np.isfinite(field)
171-
172172
# TODO: clip data to domain would make plotting faster (especially tripcolor)
173173
# tried using Map.domain.extract() but too memory heavy (probably uses
174174
# meshgrid in the background), implement clipping with e.g.
@@ -183,24 +183,72 @@ def plot_field(
183183
# subplot.tripcolor( # also works but is slower
184184
# have to overwrite _plot_kwargs to avoid earthkit-plots trying to pass transform
185185
# PlateCarree based on NumpySource
186-
subplot._plot_kwargs = lambda source: {}
187-
subplot.tricontourf(
188-
x=x[finite],
189-
y=y[finite],
190-
z=field[finite],
191-
style=style,
192-
transform=proj,
193-
**kwargs, # for earthkit.plots to work properly cmap and norm are needed here
194-
)
186+
187+
# Normalize style and color-related kwargs
188+
style_to_use, plot_kwargs = self._prepare_plot_kwargs(style, kwargs)
189+
190+
# Temporarily suppress earthkit-plots internal source-based kwargs
191+
with self._temporary_plot_kwargs_override(subplot):
192+
subplot.tricontourf(
193+
x=x[finite],
194+
y=y[finite],
195+
z=field[finite],
196+
style=style_to_use,
197+
transform=proj,
198+
**plot_kwargs,
199+
) # for earthkit.plots to work properly cmap and norm are needed here
195200
# TODO: gridlines etc would be nicer to have in the init, but I didn't get
196201
# them to overlay the plot layer
202+
197203
subplot.standard_layers()
198204

199205
if colorbar:
200206
subplot.legend()
201207
if title:
202208
subplot.title(title)
203209

210+
def _prepare_plot_kwargs(
211+
self,
212+
style: ekp.styles.Style | None,
213+
kwargs: dict,
214+
) -> tuple[ekp.styles.Style | None, dict]:
215+
"""Return a cleaned style and plot kwargs without mutating the input."""
216+
plot_kwargs = dict(kwargs)
217+
218+
# Discrete colors mode: if explicit 'colors' provided, drop cmap
219+
colors = plot_kwargs.get("colors", None)
220+
if colors is not None:
221+
plot_kwargs.pop("cmap", None)
222+
plot_kwargs.setdefault(
223+
"no_style", True
224+
) # avoid interpolation being performed by earthkit-plots resulting in an error
225+
return style, plot_kwargs
226+
227+
# Continuous mode: remove None entries to avoid matplotlib errors
228+
if plot_kwargs.get("colors", None) is None:
229+
plot_kwargs.pop("colors", None)
230+
if plot_kwargs.get("levels", None) is None:
231+
plot_kwargs.pop("levels", None)
232+
233+
return style, plot_kwargs
234+
235+
@contextmanager
236+
def _temporary_plot_kwargs_override(self, subplot: ekp.Map):
237+
"""Temporarily override internal _plot_kwargs to avoid transform issues."""
238+
has_attr = hasattr(subplot, "_plot_kwargs")
239+
old = getattr(subplot, "_plot_kwargs", None)
240+
subplot._plot_kwargs = lambda source: {}
241+
try:
242+
yield
243+
finally:
244+
if has_attr:
245+
subplot._plot_kwargs = old
246+
else:
247+
try:
248+
delattr(subplot, "_plot_kwargs")
249+
except Exception:
250+
pass
251+
204252
@cached_property
205253
def _orthographic_tri(self) -> Triangulation:
206254
"""Compute the triangulation for the orthographic projection."""

src/plotting/colormap_defaults.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,28 @@ def _fallback():
2222
"t_850": {"cmap": plt.get_cmap("inferno", 11), "vmin": 220, "vmax": 310},
2323
"z_850": {"cmap": plt.get_cmap("coolwarm", 11), "vmin": 8000, "vmax": 17000},
2424
"q_925": load_ncl_colormap("RH_6lev.ct"),
25+
"tp": {
26+
"colors": [
27+
"#ffffff",
28+
"#04e9e7",
29+
"#019ff4",
30+
"#0300f4",
31+
"#02fd02",
32+
"#01c501",
33+
"#008e00",
34+
"#fdf802",
35+
"#e5bc00",
36+
"#fd9500",
37+
"#fd0000",
38+
"#d40000",
39+
"#bc0000",
40+
"#f800fd",
41+
],
42+
"vmin": 0,
43+
"vmax": 100,
44+
"units": "mm",
45+
"levels": [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100],
46+
},
2547
}
2648

2749
CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS)

src/plotting/compat.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ def load_state_from_grib(
2828
lam_hull = MultiPoint(list(zip(lons.tolist(), lats.tolist()))).convex_hull
2929
state["lam_envelope"] = gpd.GeoSeries([lam_hull], crs="EPSG:4326")
3030
state["fields"] = {}
31-
for param in paramlist:
31+
for param in paramlist or []:
3232
if param in ds:
3333
state["fields"][param] = ds[param].values.flatten()
34+
else:
35+
# initialize with NaNs to keep consistent length
36+
state["fields"][param] = np.full(lats.size, np.nan, dtype=float)
3437
global_file = str(file.parent / f"ifs-{file.stem}.grib")
3538
if Path(global_file).exists():
3639
global_file = str(file.parent / f"ifs-{file.stem}.grib")
@@ -40,18 +43,30 @@ def load_state_from_grib(
4043
for u in fds_global
4144
if u.metadata("param") in paramlist
4245
}
43-
global_lats = fds_global.metadata("latitudes")[0]
44-
global_lons = fds_global.metadata("longitudes")[0]
45-
if max(global_lons) > 180:
46-
global_lons = ((global_lons + 180) % 360) - 180
47-
mask = np.where(~np.isnan(ds_global[paramlist[0]]))[0]
48-
state["longitudes"] = np.concatenate([state["longitudes"], global_lons[mask]])
49-
state["latitudes"] = np.concatenate([state["latitudes"], global_lats[mask]])
50-
for param in paramlist:
51-
if param in ds and param in ds_global:
52-
state["fields"][param] = np.concatenate(
53-
[state["fields"][param], ds_global[param][mask]]
46+
# Use first key from ds_global instead of paramlist[0]
47+
ref_key = next(iter(ds_global), None)
48+
if ref_key is not None:
49+
global_lats = fds_global.metadata("latitudes")[0]
50+
global_lons = fds_global.metadata("longitudes")[0]
51+
if max(global_lons) > 180:
52+
global_lons = ((global_lons + 180) % 360) - 180
53+
mask = np.where(~np.isnan(ds_global[ref_key]))[0]
54+
n_add = int(mask.size)
55+
state["longitudes"] = np.concatenate(
56+
[state["longitudes"], global_lons[mask]]
57+
)
58+
state["latitudes"] = np.concatenate([state["latitudes"], global_lats[mask]])
59+
for param in paramlist or state["fields"].keys():
60+
add = (
61+
ds_global[param][mask]
62+
if param in ds_global
63+
else np.full(n_add, np.nan, dtype=float)
5464
)
65+
# ensure base array exists (in case param wasn't in local ds)
66+
base = state["fields"].get(
67+
param, np.full(lats.size, np.nan, dtype=float)
68+
)
69+
state["fields"][param] = np.concatenate([base, add])
5570
return state
5671

5772

workflow/Snakefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ rule showcase_all:
5252
/ "showcases/{run_id}/{init_time}/{init_time}_{param}_{region}.gif",
5353
init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES],
5454
run_id=collect_all_candidates(),
55-
param=["2t", "10sp"],
55+
param=["tp", "2t", "10sp"],
5656
region=["globe", "europe", "switzerland"],
5757
),
5858

workflow/rules/plot.smk

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ rule plot_forecast_frame:
4242
def get_leadtimes(wc):
4343
"""Get all lead times from the run config."""
4444
start, end, step = map(int, RUN_CONFIGS[wc.run_id]["steps"].split("/"))
45+
if wc.param == "tp":
46+
start += step # skip lead time 0 for precipitation
4547
return [f"{i:03}" for i in range(start, end + 1, step)]
4648

4749

workflow/scripts/plot_forecast_frame.mo.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def _():
1414
import earthkit.plots as ekp
1515
import numpy as np
1616

17-
from plotting import REGIONS
17+
from plotting import DOMAINS
1818
from plotting import StatePlotter
1919
from plotting.colormap_defaults import CMAP_DEFAULTS
2020
from plotting.compat import load_state_from_grib
@@ -28,7 +28,7 @@ def _():
2828
load_state_from_grib,
2929
logging,
3030
np,
31-
REGIONS,
31+
DOMAINS,
3232
ccrs,
3333
)
3434

@@ -97,12 +97,17 @@ def get_style(param, units_override=None):
9797
units = units_override if units_override is not None else cfg.get("units", "")
9898
return {
9999
"style": ekp.styles.Style(
100-
levels=cfg.get("bounds", None),
100+
levels=cfg.get("bounds", cfg.get("levels", None)),
101101
extend="both",
102102
units=units,
103+
colors=cfg.get("colors", None),
103104
),
104-
"cmap": cfg["cmap"],
105105
"norm": cfg.get("norm", None),
106+
"cmap": cfg.get("cmap", None),
107+
"levels": cfg.get("levels", None),
108+
"vmin": cfg.get("vmin", None),
109+
"vmax": cfg.get("vmax", None),
110+
"colors": cfg.get("colors", None),
106111
}
107112

108113
return (get_style,)
@@ -132,6 +137,13 @@ def _ms_to_knots(arr):
132137
except Exception:
133138
return arr * 1.943844
134139

140+
def _m_to_mm(arr):
141+
# robust conversion with pint, fallback if dtype unsupported
142+
try:
143+
return (_ureg.Quantity(arr, _ureg.meter).to(_ureg.millimeter)).magnitude
144+
except Exception:
145+
return arr * 1000
146+
135147
except Exception:
136148
LOG.warning("pint not available; falling back hardcoded conversions")
137149

@@ -141,6 +153,9 @@ def _k_to_c(arr):
141153
def _ms_to_knots(arr):
142154
return arr * 1.943844
143155

156+
def _m_to_mm(arr):
157+
return arr * 1000
158+
144159
def preprocess_field(param: str, state: dict):
145160
"""
146161
- Temperatures (2t, 2d, t, d): K -> °C
@@ -162,6 +177,8 @@ def preprocess_field(param: str, state: dict):
162177
u = fields["u"]
163178
v = fields["v"]
164179
return np.sqrt(u**2 + v**2), "m/s"
180+
if param == "tp":
181+
return _m_to_mm(fields[param]), "mm" # convert m to mm
165182
# default: passthrough
166183
return fields[param], None
167184

@@ -179,7 +196,7 @@ def _(
179196
preprocess_field,
180197
region,
181198
state,
182-
REGIONS,
199+
DOMAINS,
183200
ccrs,
184201
):
185202
# plot individual fields
@@ -191,8 +208,8 @@ def _(
191208
fig = plotter.init_geoaxes(
192209
nrows=1,
193210
ncols=1,
194-
projection=REGIONS[region]["projection"],
195-
bbox=REGIONS[region]["extent"],
211+
projection=DOMAINS[region]["projection"],
212+
bbox=DOMAINS[region]["extent"],
196213
name=region,
197214
size=(6, 6),
198215
)

0 commit comments

Comments
 (0)