Skip to content

Commit e2cc583

Browse files
authored
[595] Changes for running a notebook script (ecmwf#598)
* Changes * Changes * work * change * changes * changes * changes * changes * changes * changes * changes * changes * reverse old changes * linter
1 parent b460521 commit e2cc583

File tree

9 files changed

+538
-41
lines changed

9 files changed

+538
-41
lines changed

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ instance/
7979
# Scrapy stuff:
8080
.scrapy
8181

82+
83+
# Jupyter Notebook
84+
*.ipynb_checkpoints
85+
# Use the jupytext extension instead.
86+
*.ipynb
87+
88+
*.zip
89+
8290
# Sphinx documentation
8391
docs/_build/
8492

packages/evaluate/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"panel",
1212
"omegaconf",
1313
"weathergen-common",
14+
"plotly>=6.2.0",
1415
]
1516

1617
[dependency-groups]

packages/evaluate/src/weathergen/evaluate/plot_inference.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from collections import defaultdict
1414
from pathlib import Path
1515

16-
from omegaconf import OmegaConf
17-
from plotter import Plotter
18-
from utils import (
16+
from omegaconf import DictConfig, OmegaConf
17+
18+
from weathergen.evaluate.utils import (
1919
calc_scores_per_stream,
2020
metric_list_to_json,
2121
plot_data,
@@ -25,23 +25,8 @@
2525

2626
_logger = logging.getLogger(__name__)
2727

28-
if __name__ == "__main__":
29-
parser = argparse.ArgumentParser(
30-
description="Fast evaluation of WeatherGenerator runs."
31-
)
32-
parser.add_argument(
33-
"--config",
34-
type=str,
35-
help="Path to the configuration yaml file for plotting. e.g. config/plottig_config.yaml",
36-
)
37-
38-
args = parser.parse_args()
39-
40-
# configure logging
41-
logging.basicConfig(level=logging.INFO)
4228

43-
# load configuration
44-
cfg = OmegaConf.load(args.config)
29+
def run_main(cfg: DictConfig) -> None:
4530
runs = cfg.run_ids
4631

4732
_logger.info(f"Detected {len(runs)} runs")
@@ -52,14 +37,12 @@
5237
out_scores_dir = Path(cfg.output_scores_dir)
5338
out_scores_dir.mkdir(parents=True, exist_ok=True)
5439

55-
results_dir = Path(cfg.results_dir)
5640
metrics = cfg.evaluation.metrics
5741

5842
# to get a structure like: scores_dict[metric][stream][run_id] = plot
5943
scores_dict = defaultdict(lambda: defaultdict(dict))
6044

6145
for run_id, run in runs.items():
62-
plotter = Plotter(cfg, run_id)
6346
_logger.info(f"RUN {run_id}: Getting data...")
6447

6548
streams = run["streams"].keys()
@@ -71,7 +54,7 @@
7154

7255
if stream_dict.get("plotting"):
7356
_logger.info(f"RUN {run_id}: Plotting stream {stream}...")
74-
plots = plot_data(cfg, run_id, stream, stream_dict)
57+
plot_data(cfg, run_id, stream, stream_dict)
7558

7659
if stream_dict.get("evaluation"):
7760
_logger.info(f"Retrieve or compute scores for {run_id} - {stream}...")
@@ -109,9 +92,27 @@
10992
scores_dict[metric][stream][run_id] = all_metrics.sel(
11093
{"metric": metric}
11194
)
95+
# plot summary
96+
if scores_dict and cfg.summary_plots:
97+
_logger.info("Started creating summary plots..")
98+
plot_summary(cfg, scores_dict, print_summary=cfg.print_summary)
11299

113100

114-
# plot summary
115-
if scores_dict and cfg.summary_plots:
116-
_logger.info("Started creating summary plots..")
117-
plot_summary(cfg, scores_dict, print_summary=cfg.print_summary)
101+
if __name__ == "__main__":
102+
parser = argparse.ArgumentParser(
103+
description="Fast evaluation of WeatherGenerator runs."
104+
)
105+
parser.add_argument(
106+
"--config",
107+
type=str,
108+
help="Path to the configuration yaml file for plotting. e.g. config/plotting_config.yaml",
109+
)
110+
111+
args = parser.parse_args()
112+
113+
# configure logging
114+
logging.basicConfig(level=logging.INFO)
115+
116+
# load configuration
117+
cfg = OmegaConf.load(args.config)
118+
run_main(cfg)

packages/evaluate/src/weathergen/evaluate/score.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
import numpy as np
1515
import pandas as pd
1616
import xarray as xr
17-
from score_utils import to_list
17+
18+
from weathergen.evaluate.score_utils import to_list
1819

1920
# from common.io import MockIO
2021

packages/evaluate/src/weathergen/evaluate/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
import numpy as np
1616
import omegaconf as oc
1717
import xarray as xr
18-
from plotter import DefaultMarkerSize, LinePlots, Plotter
19-
from score import VerifiedData, get_score
2018
from tqdm import tqdm
2119

2220
from weathergen.common.io import ZarrIO
21+
from weathergen.evaluate.plotter import DefaultMarkerSize, LinePlots, Plotter
22+
from weathergen.evaluate.score import VerifiedData, get_score
2323
from weathergen.evaluate.score_utils import to_list
2424

2525
_logger = logging.getLogger(__name__)
@@ -310,9 +310,7 @@ def plot_data(cfg: str, run_id: str, stream: str, stream_dict: dict) -> list[str
310310
# Check if histograms should be plotted
311311
plot_histograms = plot_settings.get("plot_histograms", False)
312312
if not isinstance(plot_settings.plot_histograms, bool):
313-
raise TypeError(
314-
"plot_histograms must be a boolean."
315-
)
313+
raise TypeError("plot_histograms must be a boolean.")
316314

317315
if plot_fsteps == "all":
318316
plot_fsteps = None

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies = [
2828
"dask~=2025.5.1",
2929
"hatchling",
3030
"weathergen-common",
31+
"numexpr>=2.11.0",
3132
]
3233

3334
[project.urls]
@@ -53,9 +54,12 @@ packages = ["src/weathergen"]
5354
[dependency-groups]
5455
# The development dependencies
5556
dev = [
57+
"ipykernel>=6.30.0",
58+
"jupytext>=1.17.2",
5659
"pytest~=8.3.5",
5760
"pytest-mock>=3.14.1",
5861
"ruff==0.9.7",
62+
"tensorboard>=2.20.0",
5963
]
6064

6165

scripts/actions.sh

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#!/bin/bash
22

3-
# Get the directory where the script is located
3+
# TODO: SCRIPT_DIR actually holds the root WeatherGenerator directory (the parent of this script's directory); rename the variable.
44
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && cd .. && pwd)"
55

66
case "$1" in
77
sync)
88
(
99
cd "$SCRIPT_DIR" || exit 1
10-
uv sync
10+
uv sync --all-packages
1111
)
1212
;;
1313
lint)
@@ -63,8 +63,27 @@ case "$1" in
6363
done
6464
)
6565
;;
66+
create-jupyter-kernel)
67+
(
68+
cd "$SCRIPT_DIR" || exit 1
69+
uv sync --all-packages
70+
uv run ipython kernel install --user --env VIRTUAL_ENV $(pwd)/.venv --name=weathergen_kernel --display-name "Python (WeatherGenerator)"
71+
echo "Jupyter kernel created. You can now use it in Jupyter Notebook or JupyterLab."
72+
echo "To use this kernel, select 'Python (WeatherGenerator)' from the kernel options in Jupyter Notebook or JupyterLab."
73+
echo "If you want to remove the kernel later, you can run:"
74+
echo "jupyter kernelspec uninstall weathergen_kernel"
75+
)
76+
;;
77+
jupytext-sync)
78+
(
79+
cd "$SCRIPT_DIR" || exit 1
80+
# Run on all Python and Jupyter notebook files in the ../WeatherGenerator-private/notebooks directory (a sibling of the repository root)
81+
uv run jupytext --set-formats ipynb,py:percent --sync ../WeatherGenerator-private/notebooks/*.ipynb ../WeatherGenerator-private/notebooks/*.py
82+
echo "Jupytext sync completed."
83+
)
84+
;;
6685
*)
67-
echo "Usage: $0 {sync|lint|unit-test|integration-test|create-links}"
86+
echo "Usage: $0 {sync|lint|unit-test|integration-test|create-links|create-jupyter-kernel|jupytext-sync}"
6887
exit 1
6988
;;
7089
esac

src/weathergen/utils/config.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323
_REPO_ROOT = Path(__file__).parent.parent.parent.parent # TODO use importlib for resources
2424
_DEFAULT_CONFIG_PTH = _REPO_ROOT / "config" / "default_config.yml"
25-
_DEFAULT_MODEL_PATH = "./models"
26-
_DEFAULT_RESULT_PATH = "./results"
25+
_DEFAULT_MODEL_PATH = _REPO_ROOT / "models"
26+
_DEFAULT_RESULT_PATH = _REPO_ROOT / "results"
2727

2828
_logger = logging.getLogger(__name__)
2929

@@ -133,7 +133,7 @@ def load_config(
133133
if from_run_id is None:
134134
base_config = _load_default_conf()
135135
else:
136-
base_config = load_model_config(from_run_id, epoch, private_config["model_path"])
136+
base_config = load_model_config(from_run_id, epoch, private_config.get("model_path", None))
137137

138138
# use OmegaConf.unsafe_merge if too slow
139139
return OmegaConf.merge(base_config, private_config, *overwrite_configs)
@@ -283,7 +283,7 @@ def _load_private_conf(private_home: Path | None) -> DictConfig:
283283
)
284284
private_cf = OmegaConf.load(private_home)
285285
private_cf["model_path"] = (
286-
private_cf["model_path"] if "model_path" in private_cf.keys() else "./models"
286+
private_cf["model_path"] if "model_path" in private_cf.keys() else None
287287
)
288288

289289
if "secrets" in private_cf:
@@ -345,8 +345,9 @@ def load_streams(streams_directory: Path) -> list[Config]:
345345
def set_paths(config: Config) -> Config:
346346
"""Set the config's run_path and model_path attributes to default values if not present."""
347347
config = config.copy()
348-
config.run_path = config.get("run_path", None) or _DEFAULT_RESULT_PATH
349-
config.model_path = config.get("model_path", None) or _DEFAULT_MODEL_PATH
348+
# pathlib.Path objects are not JSON serializable, so we convert them to str
349+
config.run_path = config.get("run_path", None) or str(_DEFAULT_RESULT_PATH)
350+
config.model_path = config.get("model_path", None) or str(_DEFAULT_MODEL_PATH)
350351

351352
return config
352353

0 commit comments

Comments
 (0)