From 0da3e3625c920c20329f24f2403e0aaf1d7dc50c Mon Sep 17 00:00:00 2001
From: Francesco Zanetta
Date: Fri, 17 Oct 2025 10:05:03 +0200
Subject: [PATCH 1/7] draft changes

---
 workflow/rules/common.smk    |   6 +-
 workflow/rules/inference.smk | 204 +++++++++++++++++++----------------
 workflow/rules/verif.smk     |   6 +-
 3 files changed, 120 insertions(+), 96 deletions(-)

diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk
index cea00b2..2a04595 100644
--- a/workflow/rules/common.smk
+++ b/workflow/rules/common.smk
@@ -115,9 +115,11 @@ def _inference_routing_fn(wc):
     run_config = RUN_CONFIGS[wc.run_id]
 
     if run_config["model_type"] == "forecaster":
-        input_path = f"logs/inference_forecaster/{wc.run_id}-{wc.init_time}.ok"
+        input_path = f"logs/prepare_inference_forecaster/{wc.run_id}-{wc.init_time}.ok"
     elif run_config["model_type"] == "interpolator":
-        input_path = f"logs/inference_interpolator/{wc.run_id}-{wc.init_time}.ok"
+        input_path = (
+            f"logs/prepare_inference_interpolator/{wc.run_id}-{wc.init_time}.ok"
+        )
     else:
         raise ValueError(f"Unsupported model type: {run_config['model_type']}")
 
diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk
index 38ebf28..da99b82 100644
--- a/workflow/rules/inference.smk
+++ b/workflow/rules/inference.smk
@@ -8,6 +8,13 @@ from datetime import datetime
 
 
 rule create_inference_pyproject:
+    """
+    Generate a pyproject.toml that contains the information needed
+    to set up a virtual environment for inference of a specific checkpoint.
+    The list of dependencies is taken from the checkpoint's MLflow run metadata,
+    and additional dependencies can be specified under a run entry in the main
+    config file.
+    """
     input:
         toml="workflow/envs/anemoi_inference.toml",
     output:
@@ -25,6 +32,11 @@ rule create_inference_pyproject:
 
 
 rule create_inference_venv:
+    """
+    Create a virtual environment for inference from the pyproject.toml created
+    above. The environment is managed with uv and is relocatable, so it can be
+    squashed later. Pre-compiling to bytecode speeds up imports.
+    """
     input:
         pyproject=rules.create_inference_pyproject.output.pyproject,
     output:
@@ -56,11 +68,12 @@ rule create_inference_venv:
         """
 
 
-# optionally, precompile to bytecode to reduce the import times
-# find {output.venv} -exec stat --format='%i' {} + | sort -u | wc -l # optionally, how many files did I create?
-
-
 rule make_squashfs_image:
+    """
+    Create a squashfs image for the inference virtual environment of
+    a specific checkpoint. For more on this approach, see
+    https://docs.cscs.ch/guides/storage/#python-virtual-environments-with-uenv.
+    """
     input:
         venv=rules.create_inference_venv.output.venv,
     output:
@@ -76,7 +89,11 @@ rule make_squashfs_image:
 
 
 rule create_inference_sandbox:
-    """Generate a zipped directory that can be used as a sandbox for running inference jobs.
+    """
+    Create a zipped directory that, when extracted, can be used as a sandbox
+    for running inference jobs for a specific checkpoint. Its main purpose is
+    to serve as a development environment for anemoi-inference and to facilitate
+    sharing with external collaborators.
 
     To use this sandbox, unzip it to a target directory.
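The rules below consume the generated pyproject not only as a uv input but also as metadata: they read the checkpoint location back out of it via parse_input(input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path"). The parse_toml helper itself is defined elsewhere in the workflow and is not part of this diff; a minimal sketch of the dotted-key lookup it evidently performs (assuming Python 3.11+ for tomllib) could look like this:

    import tomllib
    from functools import reduce

    def parse_toml(path, key):
        # Parse the TOML file, then walk the dotted key one segment at a time,
        # e.g. "tool.anemoi.checkpoints_path" ->
        #      data["tool"]["anemoi"]["checkpoints_path"]
        with open(path, "rb") as f:
            data = tomllib.load(f)
        return reduce(lambda node, segment: node[segment], key.split("."), data)
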
@@ -118,14 +135,16 @@ def get_resource(wc, field: str, default): return getattr(rc["inference_resources"], field) or default -rule inference_forecaster: +rule prepare_inference_forecaster: localrule: True input: pyproject=rules.create_inference_pyproject.output.pyproject, - image=rules.make_squashfs_image.output.image, config=lambda wc: Path(RUN_CONFIGS[wc.run_id]["config"]).resolve(), output: - okfile=touch(OUT_ROOT / "logs/inference_forecaster/{run_id}-{init_time}.ok"), + config=Path(OUT_ROOT / "data/runs/{run_id}/{init_time}/config.yaml"), + resources=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/resources"), + grib_out_dir=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"), + okfile=OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.ok", params: checkpoints_path=parse_input( input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path" @@ -136,53 +155,30 @@ rule inference_forecaster: reftime_to_iso=lambda wc: datetime.strptime( wc.init_time, "%Y%m%d%H%M" ).strftime("%Y-%m-%dT%H:%M"), - image_path=lambda wc, input: f"{Path(input.image).resolve()}", log: - OUT_ROOT / "logs/inference_forecaster/{run_id}-{init_time}.log", - resources: - slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "short-shared"), - cpus_per_task=lambda wc: get_resource(wc, "cpus_per_task", 24), - mem_mb_per_cpu=lambda wc: get_resource(wc, "mem_mb_per_cpu", 8000), - runtime=lambda wc: get_resource(wc, "runtime", "20m"), - gres=lambda wc: f"gpu:{get_resource(wc, 'gpu',1)}", - ntasks=lambda wc: get_resource(wc, "tasks", 1), - slurm_extra=lambda wc, input: f"--uenv={Path(input.image).resolve()}:/user-environment", - gpus=lambda wc: get_resource(wc, "gpu", 1), - shell: - r""" - ( - set -euo pipefail - squashfs-mount {params.image_path}:/user-environment -- bash -c ' - export TZ=UTC - source /user-environment/bin/activate - export ECCODES_DEFINITION_PATH=/user-environment/share/eccodes-cosmo-resources/definitions + OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.log", + run: + import yaml + import shutil - # prepare the working directory - WORKDIR={params.output_root}/runs/{wildcards.run_id}/{wildcards.init_time} - mkdir -p $WORKDIR && cd $WORKDIR && mkdir -p grib raw _resources - cp {input.config} config.yaml && cp -r {params.resources_root}/templates/* _resources/ - CMD_ARGS=( - date={params.reftime_to_iso} - checkpoint={params.checkpoints_path}/inference-last.ckpt - lead_time={params.lead_time} + # prepare working directory + workdir = ( + Path(params.output_root) / "runs" / wildcards.run_id / wildcards.init_time ) + workdir.mkdir(parents=True, exist_ok=True) + (workdir / "grib").mkdir(parents=True, exist_ok=True) + (workdir / "resources").mkdir(parents=True, exist_ok=True) - # is GPU > 1, add runner=parallel to CMD_ARGS - if [ {resources.gpus} -gt 1 ]; then - CMD_ARGS+=(runner=parallel) - fi + # prepare and write config file + config = yaml.safe_load(open(input.config)) + config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" + config["date"] = params.reftime_to_iso + config["lead_time"] = params.lead_time + with open(output.config, "w") as f: + yaml.safe_dump(config, f) - srun \ - --partition={resources.slurm_partition} \ - --cpus-per-task={resources.cpus_per_task} \ - --mem-per-cpu={resources.mem_mb_per_cpu} \ - --time={resources.runtime} \ - --gres={resources.gres} \ - --ntasks={resources.ntasks} \ - anemoi-inference run config.yaml "${{CMD_ARGS[@]}}" - ' - ) > {log} 2>&1 - """ + # copy resources + 
shutil.copytree(params.resources_root / "templates", output.resources) def _get_forecaster_run_id(run_id): @@ -190,11 +186,10 @@ def _get_forecaster_run_id(run_id): return RUN_CONFIGS[run_id]["forecaster"]["mlflow_id"][0:9] -rule inference_interpolator: +rule prepare_inference_interpolator: """Run the interpolator for a specific run ID.""" input: pyproject=rules.create_inference_pyproject.output.pyproject, - image=rules.make_squashfs_image.output.image, config=lambda wc: Path(RUN_CONFIGS[wc.run_id]["config"]).resolve(), forecasts=lambda wc: ( [ @@ -205,7 +200,13 @@ rule inference_interpolator: else [] ), output: - okfile=touch(OUT_ROOT / "logs/inference_interpolator/{run_id}-{init_time}.ok"), + config=Path(OUT_ROOT / "data/runs/{run_id}/{init_time}/config.yaml"), + resources=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/resources"), + grib_out_dir=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"), + forecaster_grib_dir=directory( + OUT_ROOT / "data/runs/{run_id}/{init_time}/forecaster_grib" + ), + okfile=OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.ok", params: checkpoints_path=parse_input( input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path" @@ -222,43 +223,72 @@ rule inference_interpolator: else _get_forecaster_run_id(wc.run_id) ), log: - OUT_ROOT / "logs/inference_interpolator/{run_id}-{init_time}.log", + OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.log", + run: + import yaml + import shutil + + fct_run_id = params.forecaster_run_id + init_time = wildcards.init_time + + # prepare working directory + workdir = ( + Path(params.output_root) / "runs" / wildcards.run_id / wildcards.init_time + ) + workdir.mkdir(parents=True, exist_ok=True) + (workdir / "grib").mkdir(parents=True, exist_ok=True) + (workdir / "resources").mkdir(parents=True, exist_ok=True) + + # if forecaster_run_id is not "null", create symbolic link to forecaster grib directory + if fct_run_id != "null": + forecaster_workdir = ( + Path(params.output_root) / "runs" / fct_run_id / init_time + ) + (workdir / "forecaster_grib").symlink_to(forecaster_workdir / "grib") + + # prepare and write config file + config = yaml.safe_load(open(input.config)) + config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" + config["date"] = params.reftime_to_iso + config["lead_time"] = params.lead_time + with open(output.config, "w") as f: + yaml.safe_dump(config, f) + + # copy resources + shutil.copytree(params.resources_root / "templates", output.resources) + + +rule execute_inference: + localrule: True + input: + _inference_routing_fn, + output: + okfile=OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.ok", + log: + OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.log", + params: + image_path=(OUT_ROOT / "data/runs/{run_id}/venv.squashfs").resolve(), + workdir=(OUT_ROOT / "data/runs/{run_id}/{init_time}").resolve(), resources: - slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "short-shared"), - cpus_per_task=lambda wc: get_resource(wc, "cpus_per_task", 24), - mem_mb_per_cpu=lambda wc: get_resource(wc, "mem_mb_per_cpu", 8000), - runtime=lambda wc: get_resource(wc, "runtime", "20m"), - gres=lambda wc: f"gpu:{get_resource(wc, 'gpu',1)}", - ntasks=lambda wc: get_resource(wc, "tasks", 1), - slurm_extra=lambda wc, input: f"--uenv={Path(input.image).resolve()}:/user-environment", - gpus=lambda wc: get_resource(wc, "gpu", 1), + slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "standard"), + cpus_per_task=lambda wc: 
get_resource(wc, "cpus_per_task", 4), + mem_mb_per_cpu=lambda wc: get_resource(wc, "mem_mb_per_cpu", 4000), + runtime=lambda wc: get_resource(wc, "runtime", "40m"), + gres=lambda wc: get_resource(wc, "gres", "gpu:1"), + ntasks=lambda wc: get_resource(wc, "ntasks", 1), + gpus=lambda wc: get_resource(wc, "gpus", 1), shell: - r""" + """ ( set -euo pipefail + + cd {params.workdir} + squashfs-mount {params.image_path}:/user-environment -- bash -c ' - export TZ=UTC source /user-environment/bin/activate export ECCODES_DEFINITION_PATH=/user-environment/share/eccodes-cosmo-resources/definitions - # prepare the working directory - WORKDIR={params.output_root}/runs/{wildcards.run_id}/{wildcards.init_time} - mkdir -p $WORKDIR && cd $WORKDIR && mkdir -p grib raw _resources - cp {input.config} config.yaml && cp -r {params.resources_root}/templates/* _resources/ - - # if forecaster_run_id is not "null", link the forecaster grib directory; else, run from files. - if [ "{params.forecaster_run_id}" != "null" ]; then - FORECASTER_WORKDIR={params.output_root}/runs/{params.forecaster_run_id}/{wildcards.init_time} - ln -fns $FORECASTER_WORKDIR/grib forecaster_grib - else - echo "Forecaster configuration is null; proceeding with file-based inputs." - fi - - CMD_ARGS=( - date={params.reftime_to_iso} - checkpoint={params.checkpoints_path}/inference-last.ckpt - lead_time={params.lead_time} - ) + CMD_ARGS=() # is GPU > 1, add runner=parallel to CMD_ARGS if [ {resources.gpus} -gt 1 ]; then @@ -266,6 +296,7 @@ rule inference_interpolator: fi srun \ + --unbuffered \ --partition={resources.slurm_partition} \ --cpus-per-task={resources.cpus_per_task} \ --mem-per-cpu={resources.mem_mb_per_cpu} \ @@ -276,12 +307,3 @@ rule inference_interpolator: ' ) > {log} 2>&1 """ - - -rule inference_routing: - localrule: True - input: - _inference_routing_fn, - output: - directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"), - directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/raw"), diff --git a/workflow/rules/verif.smk b/workflow/rules/verif.smk index 677c8d7..e33e8be 100644 --- a/workflow/rules/verif.smk +++ b/workflow/rules/verif.smk @@ -55,8 +55,7 @@ def _get_no_none(dict, key, replacement): rule verif_metrics: input: script="workflow/scripts/verif_from_grib.py", - inference_okfile=_inference_routing_fn, - grib_output=rules.inference_routing.output[0], + inference_okfile=rules.execute_inference.output.okfile, analysis_zarr=config["analysis"].get("analysis_zarr"), output: OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc", @@ -67,6 +66,7 @@ rule verif_metrics: fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), fcst_steps=lambda wc: _get_no_none(RUN_CONFIGS[wc.run_id], "steps", "0/126/6"), analysis_label=config["analysis"].get("label"), + grib_out_dir=(Path(OUT_ROOT) / "data/runs/{run_id}/{init_time}/grib").resolve(), log: OUT_ROOT / "logs/verif_metrics/{run_id}-{init_time}.log", resources: @@ -76,7 +76,7 @@ rule verif_metrics: shell: """ uv run {input.script} \ - --grib_output_dir {input.grib_output} \ + --grib_output_dir {params.grib_out_dir} \ --analysis_zarr {input.analysis_zarr} \ --lead_time "{params.fcst_steps}" \ --fcst_label "{params.fcst_label}" \ From a707cbb02ba6bac5f27f6f0eecedcbbea8ca0ea4 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 17 Oct 2025 10:55:15 +0200 Subject: [PATCH 2/7] rename workspace resources dir --- resources/inference/configs/forecaster.yaml | 2 +- .../configs/forecaster_no_trimedge.yaml | 2 +- .../forecaster_no_trimedge_fromtraining.yaml | 2 +- 
.../configs/forecaster_with_global.yaml | 4 ++-- resources/inference/configs/interpolator.yaml | 2 +- .../configs/interpolator_from_test_data.yaml | 2 +- .../configs/interpolator_stretched.yaml | 2 +- .../templates/templates_index_cosmo.yaml | 16 ++++++++-------- .../inference/templates/templates_index_ifs.yaml | 4 ++-- 9 files changed, 18 insertions(+), 18 deletions(-) diff --git a/resources/inference/configs/forecaster.yaml b/resources/inference/configs/forecaster.yaml index dac3e49..437ee31 100644 --- a/resources/inference/configs/forecaster.yaml +++ b/resources/inference/configs/forecaster.yaml @@ -18,7 +18,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml - printer write_initial_state: true diff --git a/resources/inference/configs/forecaster_no_trimedge.yaml b/resources/inference/configs/forecaster_no_trimedge.yaml index 2e3417d..306c62f 100644 --- a/resources/inference/configs/forecaster_no_trimedge.yaml +++ b/resources/inference/configs/forecaster_no_trimedge.yaml @@ -15,7 +15,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml - printer write_initial_state: true diff --git a/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml b/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml index 11188b9..b5097f5 100644 --- a/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml +++ b/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml @@ -15,7 +15,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml - printer write_initial_state: true diff --git a/resources/inference/configs/forecaster_with_global.yaml b/resources/inference/configs/forecaster_with_global.yaml index ae00c3a..167274e 100644 --- a/resources/inference/configs/forecaster_with_global.yaml +++ b/resources/inference/configs/forecaster_with_global.yaml @@ -17,13 +17,13 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml - grib: path: grib/ifs-{dateTime}_{step:03}.grib encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_ifs.yaml + samples: resources/templates_index_ifs.yaml post_processors: - extract_slice: [189699, -1] - assign_mask: "global/cutout_mask" diff --git a/resources/inference/configs/interpolator.yaml b/resources/inference/configs/interpolator.yaml index 8cbb98f..aefc1e3 100644 --- a/resources/inference/configs/interpolator.yaml +++ b/resources/inference/configs/interpolator.yaml @@ -54,7 +54,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml forcings: test: diff --git a/resources/inference/configs/interpolator_from_test_data.yaml b/resources/inference/configs/interpolator_from_test_data.yaml index aaa938f..97d0e7d 100644 --- a/resources/inference/configs/interpolator_from_test_data.yaml +++ b/resources/inference/configs/interpolator_from_test_data.yaml @@ -15,7 +15,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml forcings: test: diff --git a/resources/inference/configs/interpolator_stretched.yaml 
b/resources/inference/configs/interpolator_stretched.yaml index 0010ffe..250d4ca 100644 --- a/resources/inference/configs/interpolator_stretched.yaml +++ b/resources/inference/configs/interpolator_stretched.yaml @@ -100,7 +100,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml verbosity: 1 allow_nans: true diff --git a/resources/inference/templates/templates_index_cosmo.yaml b/resources/inference/templates/templates_index_cosmo.yaml index 8f15004..632164a 100644 --- a/resources/inference/templates/templates_index_cosmo.yaml +++ b/resources/inference/templates/templates_index_cosmo.yaml @@ -1,26 +1,26 @@ # COSMO-2 templates - - {grid: 0.02, levtype: pl} - - _resources/co2-typeOfLevel=isobaricInhPa.grib + - resources/co2-typeOfLevel=isobaricInhPa.grib - - {grid: 0.02, levtype: sfc, param: [T_2M, TD_2M, U_10M, V_10M]} - - _resources/co2-typeOfLevel=heightAboveGround.grib + - resources/co2-typeOfLevel=heightAboveGround.grib - - {grid: 0.02, levtype: sfc, param: [FR_LAND, TOC_PREC, PMSL, PS, FIS, T_G]} - - _resources/co2-typeOfLevel=surface.grib + - resources/co2-typeOfLevel=surface.grib - - {grid: 0.02, levtype: sfc, param: [TOT_PREC]} - - _resources/co2-shortName=TOT_PREC.grib + - resources/co2-shortName=TOT_PREC.grib # COSMO-1E templates - - {grid: 0.01, levtype: pl} - - _resources/co1e-typeOfLevel=isobaricInhPa.grib + - resources/co1e-typeOfLevel=isobaricInhPa.grib - - {grid: 0.01, levtype: sfc, param: [T_2M, TD_2M, U_10M, V_10M]} - - _resources/co1e-typeOfLevel=heightAboveGround.grib + - resources/co1e-typeOfLevel=heightAboveGround.grib - - {grid: 0.01, levtype: sfc, param: [FR_LAND, TOC_PREC, PMSL, PS, FIS, T_G]} - - _resources/co1e-typeOfLevel=surface.grib + - resources/co1e-typeOfLevel=surface.grib - - {grid: 0.01, levtype: sfc, param: [TOT_PREC]} - - _resources/co1e-shortName=TOT_PREC.grib + - resources/co1e-shortName=TOT_PREC.grib diff --git a/resources/inference/templates/templates_index_ifs.yaml b/resources/inference/templates/templates_index_ifs.yaml index a399ed9..c0700cf 100644 --- a/resources/inference/templates/templates_index_ifs.yaml +++ b/resources/inference/templates/templates_index_ifs.yaml @@ -1,5 +1,5 @@ - - {levtype: pl} - - _resources/ifs-levtype=pl.grib + - resources/ifs-levtype=pl.grib - - {levtype: sfc} - - _resources/ifs-levtype=sfc.grib + - resources/ifs-levtype=sfc.grib From 68421aab8deeb87849a4507c1a69d63ebc52e31a Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 17 Oct 2025 11:52:33 +0200 Subject: [PATCH 3/7] working for config/forecasters.yaml --- workflow/rules/inference.smk | 41 ++++++++++++++++++------------------ workflow/rules/verif.smk | 4 +++- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index da99b82..ce4d428 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -144,7 +144,9 @@ rule prepare_inference_forecaster: config=Path(OUT_ROOT / "data/runs/{run_id}/{init_time}/config.yaml"), resources=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/resources"), grib_out_dir=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"), - okfile=OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.ok", + okfile=touch( + OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.ok" + ), params: checkpoints_path=parse_input( input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path" @@ -167,7 +169,7 @@ rule 
prepare_inference_forecaster: ) workdir.mkdir(parents=True, exist_ok=True) (workdir / "grib").mkdir(parents=True, exist_ok=True) - (workdir / "resources").mkdir(parents=True, exist_ok=True) + shutil.copytree(params.resources_root / "templates", output.resources) # prepare and write config file config = yaml.safe_load(open(input.config)) @@ -177,9 +179,6 @@ rule prepare_inference_forecaster: with open(output.config, "w") as f: yaml.safe_dump(config, f) - # copy resources - shutil.copytree(params.resources_root / "templates", output.resources) - def _get_forecaster_run_id(run_id): """Get the forecaster run ID from the RUN_CONFIGS.""" @@ -206,7 +205,9 @@ rule prepare_inference_interpolator: forecaster_grib_dir=directory( OUT_ROOT / "data/runs/{run_id}/{init_time}/forecaster_grib" ), - okfile=OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.ok", + okfile=touch( + OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.ok" + ), params: checkpoints_path=parse_input( input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path" @@ -237,7 +238,7 @@ rule prepare_inference_interpolator: ) workdir.mkdir(parents=True, exist_ok=True) (workdir / "grib").mkdir(parents=True, exist_ok=True) - (workdir / "resources").mkdir(parents=True, exist_ok=True) + shutil.copytree(params.resources_root / "templates", output.resources) # if forecaster_run_id is not "null", create symbolic link to forecaster grib directory if fct_run_id != "null": @@ -254,29 +255,29 @@ rule prepare_inference_interpolator: with open(output.config, "w") as f: yaml.safe_dump(config, f) - # copy resources - shutil.copytree(params.resources_root / "templates", output.resources) - rule execute_inference: localrule: True input: - _inference_routing_fn, + okfile=_inference_routing_fn, + image=rules.make_squashfs_image.output.image, output: - okfile=OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.ok", + okfile=touch(OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.ok"), log: OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.log", params: - image_path=(OUT_ROOT / "data/runs/{run_id}/venv.squashfs").resolve(), - workdir=(OUT_ROOT / "data/runs/{run_id}/{init_time}").resolve(), + image_path=lambda wc, input: f"{Path(input.image).resolve()}", + workdir=lambda wc: ( + OUT_ROOT / f"data/runs/{wc.run_id}/{wc.init_time}" + ).resolve(), resources: - slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "standard"), - cpus_per_task=lambda wc: get_resource(wc, "cpus_per_task", 4), - mem_mb_per_cpu=lambda wc: get_resource(wc, "mem_mb_per_cpu", 4000), + slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "short-shared"), + cpus_per_task=lambda wc: get_resource(wc, "cpus_per_task", 24), + mem_mb_per_cpu=lambda wc: get_resource(wc, "mem_mb_per_cpu", 8000), runtime=lambda wc: get_resource(wc, "runtime", "40m"), - gres=lambda wc: get_resource(wc, "gres", "gpu:1"), - ntasks=lambda wc: get_resource(wc, "ntasks", 1), - gpus=lambda wc: get_resource(wc, "gpus", 1), + gres=lambda wc: f"gpu:{get_resource(wc, 'gpu',1)}", + ntasks=lambda wc: get_resource(wc, "tasks", 1), + gpus=lambda wc: get_resource(wc, "gpu", 1), shell: """ ( diff --git a/workflow/rules/verif.smk b/workflow/rules/verif.smk index e33e8be..f22a433 100644 --- a/workflow/rules/verif.smk +++ b/workflow/rules/verif.smk @@ -66,7 +66,9 @@ rule verif_metrics: fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), fcst_steps=lambda wc: _get_no_none(RUN_CONFIGS[wc.run_id], "steps", "0/126/6"), 
analysis_label=config["analysis"].get("label"), - grib_out_dir=(Path(OUT_ROOT) / "data/runs/{run_id}/{init_time}/grib").resolve(), + grib_out_dir=lambda wc: ( + Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" + ).resolve(), log: OUT_ROOT / "logs/verif_metrics/{run_id}-{init_time}.log", resources: From 63545b451ea34a90a235448e49539ae67e745836 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 17 Oct 2025 17:11:06 +0200 Subject: [PATCH 4/7] improve logging --- workflow/rules/common.smk | 31 ++++++++ workflow/rules/inference.smk | 143 ++++++++++++++++++++++++----------- 2 files changed, 128 insertions(+), 46 deletions(-) diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 2a04595..33ac882 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -1,3 +1,4 @@ +import logging import copy from datetime import datetime, timedelta import yaml @@ -126,6 +127,36 @@ def _inference_routing_fn(wc): return OUT_ROOT / input_path +def setup_logger(logger_name, log_file, level=logging.INFO): + """ + Set up a logger with a file handler. + + Args: + logger_name (str): Name of the logger. + log_file (str): Path to the log file. + level (int): Logging level (e.g., logging.INFO, logging.DEBUG). + + Returns: + logging.Logger: Configured logger. + """ + logger = logging.getLogger(logger_name) + logger.setLevel(level) + + # Prevent adding multiple handlers if the logger already exists + if not logger.handlers: + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(level) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + + return logger + + RUN_CONFIGS = collect_all_runs() BASELINE_CONFIGS = collect_all_baselines() EXPERIMENT_PARTICIPANTS = collect_experiment_participants() diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index ce4d428..876fdac 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -160,24 +160,45 @@ rule prepare_inference_forecaster: log: OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.log", run: - import yaml - import shutil + LOG = setup_logger("prepare_inference_forecaster", log_file=log[0]) + try: + import yaml + import shutil + + L( + "Preparing inference forecaster for run_id=%s, init_time=%s", + wildcards.run_id, + wildcards.init_time, + ) - # prepare working directory - workdir = ( - Path(params.output_root) / "runs" / wildcards.run_id / wildcards.init_time - ) - workdir.mkdir(parents=True, exist_ok=True) - (workdir / "grib").mkdir(parents=True, exist_ok=True) - shutil.copytree(params.resources_root / "templates", output.resources) + # prepare working directory + workdir = ( + Path(params.output_root) + / "runs" + / wildcards.run_id + / wildcards.init_time + ) + workdir.mkdir(parents=True, exist_ok=True) + LOG.info("Created working directory at %s", workdir) + (workdir / "grib").mkdir(parents=True, exist_ok=True) + LOG.info("Created GRIB output directory at %s", workdir / "grib") + shutil.copytree(params.resources_root / "templates", output.resources) + LOG.info("Copied resources to %s", output.resources) + LOG.info("Resources: \n%s", list(Path(output.resources).rglob("*"))) - # prepare and write config file - config = yaml.safe_load(open(input.config)) - config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" - config["date"] = params.reftime_to_iso - config["lead_time"] = params.lead_time - with 
open(output.config, "w") as f: - yaml.safe_dump(config, f) + # prepare and write config file + with open(input.config, "r") as f: + config = yaml.safe_load(f) + config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" + config["date"] = params.reftime_to_iso + config["lead_time"] = params.lead_time + with open(output.config, "w") as f: + yaml.safe_dump(config, f) + LOG.info("Config: \n%s", config) + LOG.info("Wrote config file at %s", output.config) + except Exception as e: + LOG.error("An error occurred: %s", str(e)) + raise e def _get_forecaster_run_id(run_id): @@ -187,13 +208,14 @@ def _get_forecaster_run_id(run_id): rule prepare_inference_interpolator: """Run the interpolator for a specific run ID.""" + localrule: True input: pyproject=rules.create_inference_pyproject.output.pyproject, config=lambda wc: Path(RUN_CONFIGS[wc.run_id]["config"]).resolve(), forecasts=lambda wc: ( [ OUT_ROOT - / f"logs/inference_forecaster/{_get_forecaster_run_id(wc.run_id)}-{wc.init_time}.ok" + / f"logs/execute_inference/{_get_forecaster_run_id(wc.run_id)}-{wc.init_time}.ok" ] if RUN_CONFIGS[wc.run_id].get("forecaster") is not None else [] @@ -202,9 +224,7 @@ rule prepare_inference_interpolator: config=Path(OUT_ROOT / "data/runs/{run_id}/{init_time}/config.yaml"), resources=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/resources"), grib_out_dir=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"), - forecaster_grib_dir=directory( - OUT_ROOT / "data/runs/{run_id}/{init_time}/forecaster_grib" - ), + forecaster=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/forecaster"), okfile=touch( OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.ok" ), @@ -226,34 +246,65 @@ rule prepare_inference_interpolator: log: OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.log", run: - import yaml - import shutil - - fct_run_id = params.forecaster_run_id - init_time = wildcards.init_time - - # prepare working directory - workdir = ( - Path(params.output_root) / "runs" / wildcards.run_id / wildcards.init_time + LOG = setup_logger( + f"prepare_inference_interpolator_{OUT_ROOT.stem}", log_file=log[0] ) - workdir.mkdir(parents=True, exist_ok=True) - (workdir / "grib").mkdir(parents=True, exist_ok=True) - shutil.copytree(params.resources_root / "templates", output.resources) - - # if forecaster_run_id is not "null", create symbolic link to forecaster grib directory - if fct_run_id != "null": - forecaster_workdir = ( - Path(params.output_root) / "runs" / fct_run_id / init_time + try: + import yaml + import shutil + + fct_run_id = params.forecaster_run_id + init_time = wildcards.init_time + + LOG.info( + "Preparing inference interpolator for run_id=%s, init_time=%s, forecaster_run_id=%s", + wildcards.run_id, + wildcards.init_time, + fct_run_id, ) - (workdir / "forecaster_grib").symlink_to(forecaster_workdir / "grib") - - # prepare and write config file - config = yaml.safe_load(open(input.config)) - config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" - config["date"] = params.reftime_to_iso - config["lead_time"] = params.lead_time - with open(output.config, "w") as f: - yaml.safe_dump(config, f) + # prepare working directory + workdir = ( + Path(params.output_root) + / "runs" + / wildcards.run_id + / wildcards.init_time + ) + workdir.mkdir(parents=True, exist_ok=True) + LOG.info("Created working directory at %s", workdir) + (workdir / "grib").mkdir(parents=True, exist_ok=True) + LOG.info("Created GRIB output directory at %s", workdir / "grib") + 
shutil.copytree(params.resources_root / "templates", output.resources) + LOG.info("Copied resources to %s", output.resources) + LOG.info("Resources: \n%s", list(Path(output.resources).rglob("*"))) + + # if forecaster_run_id is not "null", create symbolic link to forecaster grib directory + if fct_run_id != "null": + forecaster_workdir = ( + Path(params.output_root) / "runs" / fct_run_id / init_time + ) + (workdir / "forecaster").symlink_to(forecaster_workdir / "grib") + LOG.info( + "Created symlink to forecaster GRIB directory at %s", + workdir / "forecaster", + ) + else: + (workdir / "forecaster").mkdir(parents=True, exist_ok=True) + (workdir / "forecaster/.dataset").touch() + LOG.info("No forecaster run ID provided, skipping symlink creation.") + + # prepare and write config file + with open(input.config, "r") as f: + config = yaml.safe_load(f) + config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" + config["date"] = params.reftime_to_iso + config["lead_time"] = params.lead_time + with open(output.config, "w") as f: + yaml.safe_dump(config, f) + LOG.info("Config: \n%s", config) + LOG.info("Wrote config file at %s", output.config) + except Exception as e: + LOG.error("An error occurred: %s", str(e)) + raise e rule execute_inference: From 7b2826ddbb2d7a305b40d678c8718b7de312f0d1 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Sat, 18 Oct 2025 16:05:26 +0200 Subject: [PATCH 5/7] works for interpolators.yaml --- config/interpolators.yaml | 6 +++--- .../configs/forecaster_with_global.yaml | 2 +- resources/inference/configs/interpolator.yaml | 10 +++++++++- .../configs/interpolator_from_test_data.yaml | 1 - .../interpolator_from_test_data_stretched.yaml | 2 +- .../configs/interpolator_stretched.yaml | 4 ++-- workflow/envs/anemoi_inference.toml | 2 +- workflow/rules/inference.smk | 17 ++++++++++------- 8 files changed, 27 insertions(+), 17 deletions(-) diff --git a/config/interpolators.yaml b/config/interpolators.yaml index 893b579..d90323b 100644 --- a/config/interpolators.yaml +++ b/config/interpolators.yaml @@ -17,7 +17,7 @@ runs: config: resources/inference/configs/interpolator_from_test_data.yaml forecaster: null extra_dependencies: - - git+https://github.com/ecmwf/anemoi-inference@fix/interp_files + - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors - torch-geometric==2.6.1 - anemoi-graphs==0.5.2 - interpolator: @@ -28,7 +28,7 @@ runs: mlflow_id: d0846032fc7248a58b089cbe8fa4c511 config: resources/inference/configs/forecaster_with_global.yaml extra_dependencies: - - git+https://github.com/ecmwf/anemoi-inference@fix/interp_files + - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors - torch-geometric==2.6.1 - anemoi-graphs==0.5.2 - interpolator: @@ -40,7 +40,7 @@ runs: mlflow_id: d0846032fc7248a58b089cbe8fa4c511 config: resources/inference/configs/forecaster_with_global.yaml extra_dependencies: - - git+https://github.com/ecmwf/anemoi-inference@fix/interp_files + - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors - torch-geometric==2.6.1 - anemoi-graphs==0.5.2 diff --git a/resources/inference/configs/forecaster_with_global.yaml b/resources/inference/configs/forecaster_with_global.yaml index 167274e..890d3e2 100644 --- a/resources/inference/configs/forecaster_with_global.yaml +++ b/resources/inference/configs/forecaster_with_global.yaml @@ -28,7 +28,7 @@ output: - extract_slice: [189699, -1] - assign_mask: "global/cutout_mask" -forcings: +constant_forcings: test: use_original_paths: true diff --git 
a/resources/inference/configs/interpolator.yaml b/resources/inference/configs/interpolator.yaml index 5818604..765b093 100644 --- a/resources/inference/configs/interpolator.yaml +++ b/resources/inference/configs/interpolator.yaml @@ -10,7 +10,7 @@ post_processors: input: grib: - path: forecaster_grib/20*.grib # TODO: remove dirty fix to only use local files + path: forecaster/20*.grib # TODO: remove dirty fix to only use local files namer: rules: - - shortName: SKT @@ -60,6 +60,14 @@ constant_forcings: test: use_original_paths: true +dynamic_forcings: + test: + use_original_paths: true + +patch_metadata: + dataset: + constant_fields: [z, lsm] + verbosity: 1 allow_nans: true output_frequency: "1h" diff --git a/resources/inference/configs/interpolator_from_test_data.yaml b/resources/inference/configs/interpolator_from_test_data.yaml index c243b0d..2fdb6cd 100644 --- a/resources/inference/configs/interpolator_from_test_data.yaml +++ b/resources/inference/configs/interpolator_from_test_data.yaml @@ -1,5 +1,4 @@ runner: time_interpolator -include_forcings: true input: test: diff --git a/resources/inference/configs/interpolator_from_test_data_stretched.yaml b/resources/inference/configs/interpolator_from_test_data_stretched.yaml index 19cd733..2167489 100644 --- a/resources/inference/configs/interpolator_from_test_data_stretched.yaml +++ b/resources/inference/configs/interpolator_from_test_data_stretched.yaml @@ -17,7 +17,7 @@ output: encoding: typeOfGeneratingProcess: 2 templates: - samples: _resources/templates_index_cosmo.yaml + samples: resources/templates_index_cosmo.yaml verbosity: 1 allow_nans: true diff --git a/resources/inference/configs/interpolator_stretched.yaml b/resources/inference/configs/interpolator_stretched.yaml index 82ed3b1..2928e76 100644 --- a/resources/inference/configs/interpolator_stretched.yaml +++ b/resources/inference/configs/interpolator_stretched.yaml @@ -4,9 +4,9 @@ input: cutout: lam_0: grib: - path: forecaster_grib/20* pre_processors: - extract_mask: "source0/trimedge_mask" + path: forecaster/20* namer: rules: - - shortName: T @@ -43,7 +43,7 @@ input: - tp global: grib: - path: forecaster_grib/ifs* + path: forecaster/ifs* namer: rules: - - shortName: T diff --git a/workflow/envs/anemoi_inference.toml b/workflow/envs/anemoi_inference.toml index 22618d6..982a673 100644 --- a/workflow/envs/anemoi_inference.toml +++ b/workflow/envs/anemoi_inference.toml @@ -8,7 +8,7 @@ dependencies = [ "torchaudio", "anemoi-datasets>=0.5.23,<0.7.0", "anemoi-graphs>=0.5.0,<0.7.0", - "anemoi-inference>=0.7.0,<0.8.0", + "anemoi-inference>=0.8.0,<0.9.0", "anemoi-models>=0.4.20,<0.6.0", "anemoi-training>=0.3.3,<0.5.0", "anemoi-transform>=0.1.10,<0.3.0", diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index cfdb06f..3c4618d 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -160,12 +160,15 @@ rule prepare_inference_forecaster: log: OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.log", run: - LOG = setup_logger("prepare_inference_forecaster", log_file=log[0]) + logger_name = ( + f"prepare_inference_forecaster_{wildcards.run_id}_{wildcards.init_time}" + ) + LOG = setup_logger(logger_name, log_file=log[0]) try: import yaml import shutil - L( + LOG.info( "Preparing inference forecaster for run_id=%s, init_time=%s", wildcards.run_id, wildcards.init_time, @@ -193,7 +196,7 @@ rule prepare_inference_forecaster: config["date"] = params.reftime_to_iso config["lead_time"] = params.lead_time with open(output.config, "w") as f: - 
yaml.safe_dump(config, f)
+                yaml.safe_dump(config, f, sort_keys=False)
             LOG.info("Config: \n%s", config)
             LOG.info("Wrote config file at %s", output.config)
         except Exception as e:
@@ -243,13 +246,13 @@ rule prepare_inference_interpolator:
             if RUN_CONFIGS[wc.run_id].get("forecaster") is None
             else _get_forecaster_run_id(wc.run_id)
         ),
-        image_path=lambda wc, input: f"{Path(input.image).resolve()}",
     log:
         OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.log",
     run:
-        LOG = setup_logger(
-            f"prepare_inference_interpolator_{OUT_ROOT.stem}", log_file=log[0]
+        logger_name = (
+            f"prepare_inference_interpolator_{wildcards.run_id}_{wildcards.init_time}"
         )
+        LOG = setup_logger(logger_name, log_file=log[0])
         try:
             import yaml
             import shutil
@@ -300,7 +303,7 @@ rule prepare_inference_interpolator:
             config["date"] = params.reftime_to_iso
             config["lead_time"] = params.lead_time
             with open(output.config, "w") as f:
-                yaml.safe_dump(config, f)
+                yaml.safe_dump(config, f, sort_keys=False)
             LOG.info("Config: \n%s", config)
             LOG.info("Wrote config file at %s", output.config)
         except Exception as e:

From 67bb59d645b209592881c787c068efe5ae3ae8c9 Mon Sep 17 00:00:00 2001
From: Francesco Zanetta
Date: Tue, 21 Oct 2025 15:19:38 +0200
Subject: [PATCH 6/7] re-add get_leadtime function

---
 workflow/rules/inference.smk | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk
index 3747579..440b042 100644
--- a/workflow/rules/inference.smk
+++ b/workflow/rules/inference.smk
@@ -135,6 +135,12 @@ def get_resource(wc, field: str, default):
     return getattr(rc["inference_resources"], field) or default
 
 
+def get_leadtime(wc):
+    """Get the lead time from the run config."""
+    start, end, step = RUN_CONFIGS[wc.run_id]["steps"].split("/")
+    return f"{end}h"
+
+
 rule prepare_inference_forecaster:
     localrule: True
     input:

From f1c6b13e96a9050bf8b6ff9c937f4678bd327833 Mon Sep 17 00:00:00 2001
From: Francesco Zanetta
Date: Tue, 21 Oct 2025 16:38:30 +0200
Subject: [PATCH 7/7] refactor run directives into script

---
 src/evalml/helpers.py                 |  39 ++++++
 workflow/rules/common.smk             |  30 -----
 workflow/rules/inference.smk          | 108 +---------------
 workflow/scripts/inference_prepare.py | 173 ++++++++++++++++++++++++++
 4 files changed, 216 insertions(+), 134 deletions(-)
 create mode 100644 src/evalml/helpers.py
 create mode 100644 workflow/scripts/inference_prepare.py

diff --git a/src/evalml/helpers.py b/src/evalml/helpers.py
new file mode 100644
index 0000000..bb3e03c
--- /dev/null
+++ b/src/evalml/helpers.py
@@ -0,0 +1,39 @@
+import logging
+
+
+def setup_logger(logger_name, log_file, level=logging.INFO):
+    """
+    Set up a logger with the specified name and log file path.
+
+    Can be used to set up loggers from Python scripts and `run` directives
+    used in the Snakemake workflow.
+
+    Parameters
+    ----------
+    logger_name : str
+        The name of the logger.
+    log_file : str
+        The file path where the log messages will be written.
+    level : int, optional
+        The logging level (default is logging.INFO).
+
+    Returns
+    -------
+    logging.Logger
+        Configured logger instance.
+ """ + logger = logging.getLogger(logger_name) + logger.setLevel(level) + + if not logger.handlers: + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(level) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + + return logger diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index c28e3b4..e9314cb 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -140,36 +140,6 @@ def _inference_routing_fn(wc): return OUT_ROOT / input_path -def setup_logger(logger_name, log_file, level=logging.INFO): - """ - Set up a logger with a file handler. - - Args: - logger_name (str): Name of the logger. - log_file (str): Path to the log file. - level (int): Logging level (e.g., logging.INFO, logging.DEBUG). - - Returns: - logging.Logger: Configured logger. - """ - logger = logging.getLogger(logger_name) - logger.setLevel(level) - - # Prevent adding multiple handlers if the logger already exists - if not logger.handlers: - file_handler = logging.FileHandler(log_file) - file_handler.setLevel(level) - - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - file_handler.setFormatter(formatter) - - logger.addHandler(file_handler) - - return logger - - RUN_CONFIGS = collect_all_runs() BASELINE_CONFIGS = collect_all_baselines() EXPERIMENT_PARTICIPANTS = collect_experiment_participants() diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index 440b042..b490dc2 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -165,49 +165,8 @@ rule prepare_inference_forecaster: ).strftime("%Y-%m-%dT%H:%M"), log: OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.log", - run: - logger_name = ( - f"prepare_inference_forecaster_{wildcards.run_id}_{wildcards.init_time}" - ) - LOG = setup_logger(logger_name, log_file=log[0]) - try: - import yaml - import shutil - - LOG.info( - "Preparing inference forecaster for run_id=%s, init_time=%s", - wildcards.run_id, - wildcards.init_time, - ) - - # prepare working directory - workdir = ( - Path(params.output_root) - / "runs" - / wildcards.run_id - / wildcards.init_time - ) - workdir.mkdir(parents=True, exist_ok=True) - LOG.info("Created working directory at %s", workdir) - (workdir / "grib").mkdir(parents=True, exist_ok=True) - LOG.info("Created GRIB output directory at %s", workdir / "grib") - shutil.copytree(params.resources_root / "templates", output.resources) - LOG.info("Copied resources to %s", output.resources) - LOG.info("Resources: \n%s", list(Path(output.resources).rglob("*"))) - - # prepare and write config file - with open(input.config, "r") as f: - config = yaml.safe_load(f) - config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" - config["date"] = params.reftime_to_iso - config["lead_time"] = params.lead_time - with open(output.config, "w") as f: - yaml.safe_dump(config, f, sort_keys=False) - LOG.info("Config: \n%s", config) - LOG.info("Wrote config file at %s", output.config) - except Exception as e: - LOG.error("An error occurred: %s", str(e)) - raise e + script: + "../scripts/inference_prepare.py" def _get_forecaster_run_id(run_id): @@ -254,67 +213,8 @@ rule prepare_inference_interpolator: ), log: OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.log", - run: - logger_name = ( - f"prepare_inference_interpolator_{wildcards.run_id}_{wildcards.init_time}" - ) - LOG = 
setup_logger(logger_name, log_file=log[0]) - try: - import yaml - import shutil - - fct_run_id = params.forecaster_run_id - init_time = wildcards.init_time - - LOG.info( - "Preparing inference interpolator for run_id=%s, init_time=%s, forecaster_run_id=%s", - wildcards.run_id, - wildcards.init_time, - fct_run_id, - ) - # prepare working directory - workdir = ( - Path(params.output_root) - / "runs" - / wildcards.run_id - / wildcards.init_time - ) - workdir.mkdir(parents=True, exist_ok=True) - LOG.info("Created working directory at %s", workdir) - (workdir / "grib").mkdir(parents=True, exist_ok=True) - LOG.info("Created GRIB output directory at %s", workdir / "grib") - shutil.copytree(params.resources_root / "templates", output.resources) - LOG.info("Copied resources to %s", output.resources) - LOG.info("Resources: \n%s", list(Path(output.resources).rglob("*"))) - - # if forecaster_run_id is not "null", create symbolic link to forecaster grib directory - if fct_run_id != "null": - forecaster_workdir = ( - Path(params.output_root) / "runs" / fct_run_id / init_time - ) - (workdir / "forecaster").symlink_to(forecaster_workdir / "grib") - LOG.info( - "Created symlink to forecaster GRIB directory at %s", - workdir / "forecaster", - ) - else: - (workdir / "forecaster").mkdir(parents=True, exist_ok=True) - (workdir / "forecaster/.dataset").touch() - LOG.info("No forecaster run ID provided, skipping symlink creation.") - - # prepare and write config file - with open(input.config, "r") as f: - config = yaml.safe_load(f) - config["checkpoint"] = f"{params.checkpoints_path}/inference-last.ckpt" - config["date"] = params.reftime_to_iso - config["lead_time"] = params.lead_time - with open(output.config, "w") as f: - yaml.safe_dump(config, f, sort_keys=False) - LOG.info("Config: \n%s", config) - LOG.info("Wrote config file at %s", output.config) - except Exception as e: - LOG.error("An error occurred: %s", str(e)) - raise e + script: + "../scripts/inference_prepare.py" rule execute_inference: diff --git a/workflow/scripts/inference_prepare.py b/workflow/scripts/inference_prepare.py new file mode 100644 index 0000000..e317877 --- /dev/null +++ b/workflow/scripts/inference_prepare.py @@ -0,0 +1,173 @@ +"""Script to prepare configuration and working directory for inference runs.""" + +import logging +import yaml +import shutil +from pathlib import Path + +from evalml.helpers import setup_logger + + +def prepare_config(default_config_path: str, output_config_path: str, params: dict): + """Prepare the configuration file for the inference run. + + Overrides default configuration parameters with those provided in params + and writes the updated configuration to output_config_path. + + Parameters + ---------- + default_config_path : str + Path to the default configuration file. + output_config_path : str + Path where the updated configuration file will be written. + params : dict + Dictionary of parameters to override in the default configuration. + """ + + with open(default_config_path, "r") as f: + config = yaml.safe_load(f) + + config = _override_recursive(config, params) + + with open(output_config_path, "w") as f: + yaml.safe_dump(config, f, sort_keys=False) + + +def prepare_workdir(workdir: Path, resources_root: Path): + """Prepare the working directory for the inference run. + + Creates necessary subdirectories and copies resource files. + + Parameters + ---------- + workdir : Path + Path to the working directory. + resources_root : Path + Path to the root directory containing resource files. 
+ """ + workdir.mkdir(parents=True, exist_ok=True) + (workdir / "grib").mkdir(parents=True, exist_ok=True) + shutil.copytree(resources_root / "templates", workdir / "resources") + + +def prepare_interpolator(smk): + """Prepare the interpolator for the inference run. + + Required steps: + - prepare working directory + - prepare forecaster directory + - prepare config + """ + LOG = _setup_logger(smk) + + # prepare working directory + workdir = _get_workdir(smk) + prepare_workdir(workdir, smk.params.resources_root) + LOG.info("Prepared working directory at %s", workdir) + res_list = "\n".join([str(fn) for fn in Path(workdir / "resources").rglob("*")]) + LOG.info("Resources: \n%s", res_list) + + # prepare forecaster directory + fct_run_id = smk.params.forecaster_run_id + if fct_run_id != "null": + fct_workdir = ( + smk.params.output_root / "runs" / fct_run_id / smk.wildcards.init_time + ) + (workdir / "forecaster").symlink_to(fct_workdir / "grib") + LOG.info( + "Created symlink to forecaster grib directory at %s", workdir / "forecaster" + ) + else: + (workdir / "forecaster").mkdir(parents=True, exist_ok=True) + (workdir / "forecaster/.dataset").touch() + LOG.info( + "No forecaster run ID provided; using dataset placeholder at %s", + workdir / "forecaster/.dataset", + ) + + # prepare config + overrides = _overrides_from_params(smk) + prepare_config(smk.input.config, smk.output.config, overrides) + LOG.info("Wrote config file at %s", smk.output.config) + with open(smk.output.config, "r") as f: + config_content = f.read() + LOG.info("Config: \n%s", config_content) + + LOG.info("Interpolator preparation complete.") + + +def prepare_forecaster(smk): + """Prepare the forecaster for the inference run. + + Required steps: + - prepare working directory + - prepare config + """ + LOG = _setup_logger(smk) + + workdir = _get_workdir(smk) + prepare_workdir(workdir, smk.params.resources_root) + LOG.info("Prepared working directory at %s", workdir) + res_list = "\n".join([str(fn) for fn in Path(workdir / "resources").rglob("*")]) + LOG.info("Resources: \n%s", res_list) + + overrides = _overrides_from_params(smk) + prepare_config(smk.input.config, smk.output.config, overrides) + LOG.info("Wrote config file at %s", smk.output.config) + with open(smk.output.config, "r") as f: + config_content = f.read() + LOG.info("Config: \n%s", config_content) + + LOG.info("Forecaster preparation complete.") + + +# TODO: just pass a dictionary of config overrides to the rule's params +def _overrides_from_params(smk) -> dict: + return { + "checkpoint": f"{smk.params.checkpoints_path}/inference-last.ckpt", + "date": smk.params.reftime_to_iso, + "lead_time": smk.params.lead_time, + } + + +def _get_workdir(smk) -> Path: + run_id = smk.wildcards.run_id + init_time = smk.wildcards.init_time + return smk.params.output_root / "runs" / run_id / init_time + + +def _setup_logger(smk) -> logging.Logger: + run_id = smk.wildcards.run_id + init_time = smk.wildcards.init_time + logger_name = f"{smk.rule}_{run_id}_{init_time}" + LOG = setup_logger(logger_name, log_file=smk.log[0]) + return LOG + + +def _override_recursive(original: dict, updates: dict) -> dict: + """Recursively override values in the original dictionary with those from the updates dictionary.""" + for key, value in updates.items(): + if ( + isinstance(value, dict) + and key in original + and isinstance(original[key], dict) + ): + original[key] = _override_recursive(original[key], value) + else: + original[key] = value + return original + + +def main(smk): + """Main function 
dispatching to the preparation routine for the calling Snakemake rule."""
+    if smk.rule == "prepare_inference_forecaster":
+        prepare_forecaster(smk)
+    elif smk.rule == "prepare_inference_interpolator":
+        prepare_interpolator(smk)
+    else:
+        raise ValueError(f"Unknown rule: {smk.rule}")
+
+
+if __name__ == "__main__":
+    snakemake = snakemake # type: ignore # noqa: F821
+    raise SystemExit(main(snakemake))
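
The merge semantics implemented by _override_recursive above (nested dicts are merged key by key, scalars and lists are replaced wholesale, and the original mapping is updated in place) can be sanity-checked in isolation; the config values below are illustrative only, not taken from a real run config:

    base = {"output": {"grib": {"path": "grib/"}}, "lead_time": "126h"}
    updates = {
        "lead_time": "48h",
        "output": {"grib": {"encoding": {"typeOfGeneratingProcess": 2}}},
    }
    merged = _override_recursive(base, updates)
    assert merged["lead_time"] == "48h"
    # the nested "grib" dict keeps its existing key and gains the new one
    assert merged["output"]["grib"] == {
        "path": "grib/",
        "encoding": {"typeOfGeneratingProcess": 2},
    }
    assert merged is base  # merged in place, not a copy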