diff --git a/config/interpolators.yaml b/config/interpolators.yaml
index 662f067..0f5a042 100644
--- a/config/interpolators.yaml
+++ b/config/interpolators.yaml
@@ -15,7 +15,7 @@ runs:
     config: resources/inference/configs/interpolator_from_test_data_stretched.yaml
     forecaster: null
     extra_dependencies:
-      - git+https://github.com/ecmwf/anemoi-inference@14189907b4f4e3b204b7994f828831b8aa51e9b6
+      - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors
       - torch-geometric==2.6.1
       - anemoi-graphs==0.5.2
   - interpolator:
@@ -28,7 +28,7 @@ runs:
     config: resources/inference/configs/forecaster_with_global.yaml
     steps: 0/120/6
     extra_dependencies:
-      - git+https://github.com/ecmwf/anemoi-inference@14189907b4f4e3b204b7994f828831b8aa51e9b6
+      - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors
       - torch-geometric==2.6.1
       - anemoi-graphs==0.5.2
   - forecaster:
diff --git a/resources/inference/configs/forecaster.yaml b/resources/inference/configs/forecaster.yaml
index 4f558a3..8b318c8 100644
--- a/resources/inference/configs/forecaster.yaml
+++ b/resources/inference/configs/forecaster.yaml
@@ -17,7 +17,7 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml
   - printer

 write_initial_state: true
diff --git a/resources/inference/configs/forecaster_no_trimedge.yaml b/resources/inference/configs/forecaster_no_trimedge.yaml
index 2e3417d..306c62f 100644
--- a/resources/inference/configs/forecaster_no_trimedge.yaml
+++ b/resources/inference/configs/forecaster_no_trimedge.yaml
@@ -15,7 +15,7 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml
   - printer

 write_initial_state: true
diff --git a/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml b/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml
index 11188b9..b5097f5 100644
--- a/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml
+++ b/resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml
@@ -15,7 +15,7 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml
   - printer

 write_initial_state: true
diff --git a/resources/inference/configs/forecaster_with_global.yaml b/resources/inference/configs/forecaster_with_global.yaml
index ae00c3a..890d3e2 100644
--- a/resources/inference/configs/forecaster_with_global.yaml
+++ b/resources/inference/configs/forecaster_with_global.yaml
@@ -17,18 +17,18 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml
   - grib:
       path: grib/ifs-{dateTime}_{step:03}.grib
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_ifs.yaml
+        samples: resources/templates_index_ifs.yaml
       post_processors:
         - extract_slice: [189699, -1]
         - assign_mask: "global/cutout_mask"

-forcings:
+constant_forcings:
   test:
     use_original_paths: true
diff --git a/resources/inference/configs/interpolator.yaml b/resources/inference/configs/interpolator.yaml
index 41253c0..765b093 100644
--- a/resources/inference/configs/interpolator.yaml
+++ b/resources/inference/configs/interpolator.yaml
@@ -10,7 +10,7 @@ post_processors:

 input:
   grib:
-    path: forecaster_grib/20*.grib # TODO: remove dirty fix to only use local files
+    path: forecaster/20*.grib # TODO: remove dirty fix to only use local files
     namer:
       rules:
         - - shortName: SKT
@@ -54,12 +54,20 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml

 constant_forcings:
   test:
     use_original_paths: true

+dynamic_forcings:
+  test:
+    use_original_paths: true
+
+patch_metadata:
+  dataset:
+    constant_fields: [z, lsm]
+
 verbosity: 1
 allow_nans: true
 output_frequency: "1h"
diff --git a/resources/inference/configs/interpolator_from_test_data.yaml b/resources/inference/configs/interpolator_from_test_data.yaml
index 07aea41..2fdb6cd 100644
--- a/resources/inference/configs/interpolator_from_test_data.yaml
+++ b/resources/inference/configs/interpolator_from_test_data.yaml
@@ -1,5 +1,4 @@
 runner: time_interpolator
-include_forcings: true

 input:
   test:
@@ -15,7 +14,7 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml

 verbosity: 1
 allow_nans: true
diff --git a/resources/inference/configs/interpolator_from_test_data_stretched.yaml b/resources/inference/configs/interpolator_from_test_data_stretched.yaml
index 19cd733..2167489 100644
--- a/resources/inference/configs/interpolator_from_test_data_stretched.yaml
+++ b/resources/inference/configs/interpolator_from_test_data_stretched.yaml
@@ -17,7 +17,7 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml

 verbosity: 1
 allow_nans: true
diff --git a/resources/inference/configs/interpolator_stretched.yaml b/resources/inference/configs/interpolator_stretched.yaml
index 300d6c6..2928e76 100644
--- a/resources/inference/configs/interpolator_stretched.yaml
+++ b/resources/inference/configs/interpolator_stretched.yaml
@@ -4,9 +4,9 @@ input:
   cutout:
     lam_0:
       grib:
-        path: forecaster_grib/20*
         pre_processors:
           - extract_mask: "source0/trimedge_mask"
+        path: forecaster/20*
         namer:
           rules:
             - - shortName: T
@@ -43,7 +43,7 @@ input:
         - tp
     global:
       grib:
-        path: forecaster_grib/ifs*
+        path: forecaster/ifs*
         namer:
           rules:
             - - shortName: T
@@ -100,7 +100,7 @@ output:
       encoding:
         typeOfGeneratingProcess: 2
       templates:
-        samples: _resources/templates_index_cosmo.yaml
+        samples: resources/templates_index_cosmo.yaml

 verbosity: 1
 allow_nans: true
diff --git a/resources/inference/templates/templates_index_cosmo.yaml b/resources/inference/templates/templates_index_cosmo.yaml
index 8f15004..632164a 100644
--- a/resources/inference/templates/templates_index_cosmo.yaml
+++ b/resources/inference/templates/templates_index_cosmo.yaml
@@ -1,26 +1,26 @@
 # COSMO-2 templates
 - - {grid: 0.02, levtype: pl}
-  - _resources/co2-typeOfLevel=isobaricInhPa.grib
+  - resources/co2-typeOfLevel=isobaricInhPa.grib

 - - {grid: 0.02, levtype: sfc, param: [T_2M, TD_2M, U_10M, V_10M]}
-  - _resources/co2-typeOfLevel=heightAboveGround.grib
+  - resources/co2-typeOfLevel=heightAboveGround.grib

 - - {grid: 0.02, levtype: sfc, param: [FR_LAND, TOC_PREC, PMSL, PS, FIS, T_G]}
-  - _resources/co2-typeOfLevel=surface.grib
+  - resources/co2-typeOfLevel=surface.grib

 - - {grid: 0.02, levtype: sfc, param: [TOT_PREC]}
-  - _resources/co2-shortName=TOT_PREC.grib
+  - resources/co2-shortName=TOT_PREC.grib

 # COSMO-1E templates
 - - {grid: 0.01, levtype: pl}
-  - _resources/co1e-typeOfLevel=isobaricInhPa.grib
+  - resources/co1e-typeOfLevel=isobaricInhPa.grib

 - - {grid: 0.01, levtype: sfc, param: [T_2M, TD_2M, U_10M, V_10M]}
-  - _resources/co1e-typeOfLevel=heightAboveGround.grib
+  - resources/co1e-typeOfLevel=heightAboveGround.grib

 - - {grid: 0.01, levtype: sfc, param: [FR_LAND, TOC_PREC, PMSL, PS, FIS, T_G]}
-  - _resources/co1e-typeOfLevel=surface.grib
+  - resources/co1e-typeOfLevel=surface.grib

 - - {grid: 0.01, levtype: sfc, param: [TOT_PREC]}
-  - _resources/co1e-shortName=TOT_PREC.grib
+  - resources/co1e-shortName=TOT_PREC.grib
diff --git a/resources/inference/templates/templates_index_ifs.yaml b/resources/inference/templates/templates_index_ifs.yaml
index a399ed9..c0700cf 100644
--- a/resources/inference/templates/templates_index_ifs.yaml
+++ b/resources/inference/templates/templates_index_ifs.yaml
@@ -1,5 +1,5 @@
 - - {levtype: pl}
-  - _resources/ifs-levtype=pl.grib
+  - resources/ifs-levtype=pl.grib

 - - {levtype: sfc}
-  - _resources/ifs-levtype=sfc.grib
+  - resources/ifs-levtype=sfc.grib
diff --git a/src/evalml/helpers.py b/src/evalml/helpers.py
new file mode 100644
index 0000000..bb3e03c
--- /dev/null
+++ b/src/evalml/helpers.py
@@ -0,0 +1,39 @@
+import logging
+
+
+def setup_logger(logger_name, log_file, level=logging.INFO):
+    """
+    Set up a logger with the specified name and log file path.
+
+    Can be used to set up loggers from the Python scripts and `run`
+    directives used in the Snakemake workflow.
+
+    Parameters
+    ----------
+    logger_name : str
+        The name of the logger.
+    log_file : str
+        The file path where the log messages will be written.
+    level : int, optional
+        The logging level (default is logging.INFO).
+
+    Returns
+    -------
+    logging.Logger
+        Configured logger instance.
+    """
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(level)
+
+    if not logger.handlers:
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setLevel(level)
+
+        formatter = logging.Formatter(
+            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+        )
+        file_handler.setFormatter(formatter)
+
+        logger.addHandler(file_handler)
+
+    return logger
diff --git a/workflow/envs/anemoi_inference.toml b/workflow/envs/anemoi_inference.toml
index 22618d6..982a673 100644
--- a/workflow/envs/anemoi_inference.toml
+++ b/workflow/envs/anemoi_inference.toml
@@ -8,7 +8,7 @@ dependencies = [
     "torchaudio",
     "anemoi-datasets>=0.5.23,<0.7.0",
     "anemoi-graphs>=0.5.0,<0.7.0",
-    "anemoi-inference>=0.7.0,<0.8.0",
+    "anemoi-inference>=0.8.0,<0.9.0",
    "anemoi-models>=0.4.20,<0.6.0",
     "anemoi-training>=0.3.3,<0.5.0",
     "anemoi-transform>=0.1.10,<0.3.0",
diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk
index 98283be..e9314cb 100644
--- a/workflow/rules/common.smk
+++ b/workflow/rules/common.smk
@@ -1,3 +1,4 @@
+import logging
 import copy
 from datetime import datetime, timedelta
 import yaml
@@ -128,9 +129,11 @@ def _inference_routing_fn(wc):
     run_config = RUN_CONFIGS[wc.run_id]

     if run_config["model_type"] == "forecaster":
-        input_path = f"logs/inference_forecaster/{wc.run_id}-{wc.init_time}.ok"
+        input_path = f"logs/prepare_inference_forecaster/{wc.run_id}-{wc.init_time}.ok"
     elif run_config["model_type"] == "interpolator":
-        input_path = f"logs/inference_interpolator/{wc.run_id}-{wc.init_time}.ok"
+        input_path = (
+            f"logs/prepare_inference_interpolator/{wc.run_id}-{wc.init_time}.ok"
+        )
     else:
         raise ValueError(f"Unsupported model type: {run_config['model_type']}")
diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk
index 8ec1d98..b490dc2 100644
--- a/workflow/rules/inference.smk
+++ b/workflow/rules/inference.smk
@@ -8,6 +8,13 @@ from datetime import datetime


 rule create_inference_pyproject:
+    """
+    Generate a pyproject.toml that contains the information needed
+    to set up a virtual environment for inference of a specific checkpoint.
+    The list of dependencies is taken from the checkpoint's MLflow run metadata,
+    and additional dependencies can be specified under a run entry in the main
+    config file.
+    """
     input:
         toml="workflow/envs/anemoi_inference.toml",
     output:
@@ -25,6 +32,11 @@


 rule create_inference_venv:
+    """
+    Create a virtual environment for inference, using the pyproject.toml created above.
+    The environment is managed with uv and is relocatable, so it can be squashed
+    later. Bytecode is pre-compiled to speed up imports.
+    """
     input:
         pyproject=rules.create_inference_pyproject.output.pyproject,
     output:
@@ -56,11 +68,12 @@
         """


-# optionally, precompile to bytecode to reduce the import times
-# find {output.venv} -exec stat --format='%i' {} + | sort -u | wc -l # optionally, how many files did I create?
-
-
 rule make_squashfs_image:
+    """
+    Create a squashfs image for the inference virtual environment of
+    a specific checkpoint. See
+    https://docs.cscs.ch/guides/storage/#python-virtual-environments-with-uenv for details.
+    """
     input:
         venv=rules.create_inference_venv.output.venv,
     output:
@@ -76,7 +89,11 @@


 rule create_inference_sandbox:
-    """Generate a zipped directory that can be used as a sandbox for running inference jobs.
+    """
+    Create a zipped directory that, when extracted, can be used as a sandbox
+    for running inference jobs for a specific checkpoint. Its main purpose is
+    to serve as a development environment for anemoi-inference and to facilitate
+    sharing with external collaborators.

     To use this sandbox, unzip it to a target directory.
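The hunks below read `tool.anemoi.checkpoints_path` out of the generated pyproject.toml via `parse_input(input.pyproject, parse_toml, key=...)`. `parse_toml` itself is not part of this diff, so the sketch below is a guess at how such a dotted-key lookup could work, assuming the stdlib `tomllib` (Python 3.11+); the name and signature are assumptions, not the repo's actual helper.

```python
# Hypothetical sketch of a dotted-key TOML lookup; the real parse_toml
# helper is not shown in this diff, so name and signature are assumptions.
import tomllib
from functools import reduce


def parse_toml(path: str, key: str):
    """Return the value at a dotted key path, e.g. "tool.anemoi.checkpoints_path"."""
    with open(path, "rb") as f:
        data = tomllib.load(f)
    # Walk the nested tables: data["tool"]["anemoi"]["checkpoints_path"]
    return reduce(lambda table, part: table[part], key.split("."), data)
```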
@@ -124,14 +141,18 @@ def get_leadtime(wc):
     return f"{end}h"


-rule inference_forecaster:
+rule prepare_inference_forecaster:
     localrule: True
     input:
         pyproject=rules.create_inference_pyproject.output.pyproject,
-        image=rules.make_squashfs_image.output.image,
         config=lambda wc: Path(RUN_CONFIGS[wc.run_id]["config"]).resolve(),
     output:
-        okfile=touch(OUT_ROOT / "logs/inference_forecaster/{run_id}-{init_time}.ok"),
+        config=Path(OUT_ROOT / "data/runs/{run_id}/{init_time}/config.yaml"),
+        resources=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/resources"),
+        grib_out_dir=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"),
+        okfile=touch(
+            OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.ok"
+        ),
     params:
         checkpoints_path=parse_input(
             input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path"
@@ -142,53 +163,10 @@
         reftime_to_iso=lambda wc: datetime.strptime(
             wc.init_time, "%Y%m%d%H%M"
         ).strftime("%Y-%m-%dT%H:%M"),
-        image_path=lambda wc, input: f"{Path(input.image).resolve()}",
     log:
-        OUT_ROOT / "logs/inference_forecaster/{run_id}-{init_time}.log",
-    resources:
-        slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "short-shared"),
-        cpus_per_task=lambda wc: get_resource(wc, "cpus_per_task", 24),
-        mem_mb_per_cpu=lambda wc: get_resource(wc, "mem_mb_per_cpu", 8000),
-        runtime=lambda wc: get_resource(wc, "runtime", "40m"),
-        gres=lambda wc: f"gpu:{get_resource(wc, 'gpu',1)}",
-        ntasks=lambda wc: get_resource(wc, "tasks", 1),
-        slurm_extra=lambda wc, input: f"--uenv={Path(input.image).resolve()}:/user-environment",
-        gpus=lambda wc: get_resource(wc, "gpu", 1),
-    shell:
-        r"""
-        (
-        set -euo pipefail
-        squashfs-mount {params.image_path}:/user-environment -- bash -c '
-        export TZ=UTC
-        source /user-environment/bin/activate
-        export ECCODES_DEFINITION_PATH=/user-environment/share/eccodes-cosmo-resources/definitions
-
-        # prepare the working directory
-        WORKDIR={params.output_root}/runs/{wildcards.run_id}/{wildcards.init_time}
-        mkdir -p $WORKDIR && cd $WORKDIR && mkdir -p grib raw _resources
-        cp {input.config} config.yaml && cp -r {params.resources_root}/templates/* _resources/
-        CMD_ARGS=(
-            date={params.reftime_to_iso}
-            checkpoint={params.checkpoints_path}/inference-last.ckpt
-            lead_time={params.lead_time}
-        )
-
-        # is GPU > 1, add runner=parallel to CMD_ARGS
-        if [ {resources.gpus} -gt 1 ]; then
-            CMD_ARGS+=(runner=parallel)
-        fi
-
-        srun \
-            --partition={resources.slurm_partition} \
-            --cpus-per-task={resources.cpus_per_task} \
-            --mem-per-cpu={resources.mem_mb_per_cpu} \
-            --time={resources.runtime} \
-            --gres={resources.gres} \
-            --ntasks={resources.ntasks} \
-            anemoi-inference run config.yaml "${{CMD_ARGS[@]}}"
-        '
-        ) > {log} 2>&1
-        """
+        OUT_ROOT / "logs/prepare_inference_forecaster/{run_id}-{init_time}.log",
    script:
        "../scripts/inference_prepare.py"
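With the shell block gone, the per-run overrides that used to travel as CLI arguments (`date=...`, `checkpoint=...`, `lead_time=...`) are baked into `config.yaml` by `inference_prepare.py`, so the later execute step can call `anemoi-inference run config.yaml` with no extra arguments. A minimal sketch of that baking step, with illustrative paths and values (the real script merges recursively; see `_override_recursive` at the end of this diff):

```python
# Minimal sketch of the override baking done by inference_prepare.py;
# paths and values here are illustrative, not taken from a real run.
import yaml

with open("resources/inference/configs/forecaster.yaml") as f:
    config = yaml.safe_load(f)

config.update(
    {
        "date": "2024-01-01T00:00",                       # params.reftime_to_iso
        "checkpoint": "checkpoints/inference-last.ckpt",  # params.checkpoints_path
        "lead_time": "120h",                              # params.lead_time
    }
)

with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f, sort_keys=False)
```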
@@ -196,23 +174,28 @@ def _get_forecaster_run_id(run_id):
     return RUN_CONFIGS[run_id]["forecaster"]["mlflow_id"][0:9]


-rule inference_interpolator:
-    """Run the interpolator for a specific run ID."""
+rule prepare_inference_interpolator:
+    """Prepare the interpolator inference run for a specific run ID."""
     localrule: True
     input:
         pyproject=rules.create_inference_pyproject.output.pyproject,
-        image=rules.make_squashfs_image.output.image,
         config=lambda wc: Path(RUN_CONFIGS[wc.run_id]["config"]).resolve(),
         forecasts=lambda wc: (
             [
                 OUT_ROOT
-                / f"logs/inference_forecaster/{_get_forecaster_run_id(wc.run_id)}-{wc.init_time}.ok"
+                / f"logs/execute_inference/{_get_forecaster_run_id(wc.run_id)}-{wc.init_time}.ok"
             ]
             if RUN_CONFIGS[wc.run_id].get("forecaster") is not None
             else []
         ),
     output:
-        okfile=touch(OUT_ROOT / "logs/inference_interpolator/{run_id}-{init_time}.ok"),
+        config=Path(OUT_ROOT / "data/runs/{run_id}/{init_time}/config.yaml"),
+        resources=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/resources"),
+        grib_out_dir=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"),
+        forecaster=directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/forecaster"),
+        okfile=touch(
+            OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.ok"
+        ),
     params:
         checkpoints_path=parse_input(
             input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path"
@@ -228,9 +211,26 @@
             if RUN_CONFIGS[wc.run_id].get("forecaster") is None
             else _get_forecaster_run_id(wc.run_id)
         ),
-        image_path=lambda wc, input: f"{Path(input.image).resolve()}",
     log:
-        OUT_ROOT / "logs/inference_interpolator/{run_id}-{init_time}.log",
+        OUT_ROOT / "logs/prepare_inference_interpolator/{run_id}-{init_time}.log",
+    script:
+        "../scripts/inference_prepare.py"
+
+
+rule execute_inference:
+    localrule: True
+    input:
+        okfile=_inference_routing_fn,
+        image=rules.make_squashfs_image.output.image,
+    output:
+        okfile=touch(OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.ok"),
+    log:
+        OUT_ROOT / "logs/execute_inference/{run_id}-{init_time}.log",
+    params:
+        image_path=lambda wc, input: f"{Path(input.image).resolve()}",
+        workdir=lambda wc: (
+            OUT_ROOT / f"data/runs/{wc.run_id}/{wc.init_time}"
+        ).resolve(),
     resources:
         slurm_partition=lambda wc: get_resource(wc, "slurm_partition", "short-shared"),
         cpus_per_task=lambda wc: get_resource(wc, "cpus_per_task", 24),
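The hunk that follows strips the remaining environment bootstrapping out of the interpolator shell block, including the ad-hoc `ln -fns` link to the forecaster grib directory; that linking reappears in `inference_prepare.py` at the end of this diff as a `pathlib` symlink. A rough Python equivalent of the removed shell lines, with illustrative paths:

```python
# Rough pathlib equivalent of `ln -fns $FORECASTER_WORKDIR/grib forecaster`;
# both paths below are illustrative.
from pathlib import Path

link = Path("forecaster")
target = Path("/runs/abc123def/202401010000/grib")
if link.is_symlink():
    link.unlink()  # -f: replace an existing link
link.symlink_to(target, target_is_directory=True)
```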
@@ -238,35 +238,19 @@
         runtime=lambda wc: get_resource(wc, "runtime", "40m"),
         gres=lambda wc: f"gpu:{get_resource(wc, 'gpu',1)}",
         ntasks=lambda wc: get_resource(wc, "tasks", 1),
-        slurm_extra=lambda wc, input: f"--uenv={Path(input.image).resolve()}:/user-environment",
         gpus=lambda wc: get_resource(wc, "gpu", 1),
     shell:
-        r"""
+        """
         (
         set -euo pipefail
+
+        cd {params.workdir}
+
         squashfs-mount {params.image_path}:/user-environment -- bash -c '
-        export TZ=UTC
         source /user-environment/bin/activate
         export ECCODES_DEFINITION_PATH=/user-environment/share/eccodes-cosmo-resources/definitions

-        # prepare the working directory
-        WORKDIR={params.output_root}/runs/{wildcards.run_id}/{wildcards.init_time}
-        mkdir -p $WORKDIR && cd $WORKDIR && mkdir -p grib raw _resources
-        cp {input.config} config.yaml && cp -r {params.resources_root}/templates/* _resources/
-
-        # if forecaster_run_id is not "null", link the forecaster grib directory; else, run from files.
-        if [ "{params.forecaster_run_id}" != "null" ]; then
-            FORECASTER_WORKDIR={params.output_root}/runs/{params.forecaster_run_id}/{wildcards.init_time}
-            ln -fns $FORECASTER_WORKDIR/grib forecaster_grib
-        else
-            echo "Forecaster configuration is null; proceeding with file-based inputs."
-        fi
-
-        CMD_ARGS=(
-            date={params.reftime_to_iso}
-            checkpoint={params.checkpoints_path}/inference-last.ckpt
-            lead_time={params.lead_time}
-        )
+        CMD_ARGS=()

-        # is GPU > 1, add runner=parallel to CMD_ARGS
+        # if GPU > 1, add runner=parallel to CMD_ARGS
         if [ {resources.gpus} -gt 1 ]; then
@@ -274,6 +258,7 @@ rule inference_interpolator:
             CMD_ARGS+=(runner=parallel)
         fi

         srun \
+            --unbuffered \
             --partition={resources.slurm_partition} \
             --cpus-per-task={resources.cpus_per_task} \
             --mem-per-cpu={resources.mem_mb_per_cpu} \
@@ -284,12 +269,3 @@ rule inference_interpolator:
         '
         ) > {log} 2>&1
         """
-
-
-rule inference_routing:
-    localrule: True
-    input:
-        _inference_routing_fn,
-    output:
-        directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/grib"),
-        directory(OUT_ROOT / "data/runs/{run_id}/{init_time}/raw"),
diff --git a/workflow/rules/verif.smk b/workflow/rules/verif.smk
index bef4522..6c732db 100644
--- a/workflow/rules/verif.smk
+++ b/workflow/rules/verif.smk
@@ -55,9 +55,7 @@ def _get_no_none(dict, key, replacement):

 rule verif_metrics:
     input:
         script="workflow/scripts/verif_from_grib.py",
-        module="src/verification/__init__.py",
-        inference_okfile=_inference_routing_fn,
-        grib_output=rules.inference_routing.output[0],
+        inference_okfile=rules.execute_inference.output.okfile,
         analysis_zarr=config["analysis"].get("analysis_zarr"),
     output:
         OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc",
@@ -68,6 +66,9 @@
         fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"),
         fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"],
         analysis_label=config["analysis"].get("label"),
+        grib_out_dir=lambda wc: (
+            Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib"
+        ).resolve(),
     log:
         OUT_ROOT / "logs/verif_metrics/{run_id}-{init_time}.log",
     resources:
@@ -77,7 +78,7 @@
     shell:
         """
         uv run {input.script} \
-            --grib_output_dir {input.grib_output} \
+            --grib_output_dir {params.grib_out_dir} \
             --analysis_zarr {input.analysis_zarr} \
             --steps "{params.fcst_steps}" \
             --fcst_label "{params.fcst_label}" \
diff --git a/workflow/scripts/inference_prepare.py b/workflow/scripts/inference_prepare.py
new file mode 100644
index 0000000..e317877
--- /dev/null
+++ b/workflow/scripts/inference_prepare.py
@@ -0,0 +1,173 @@
+"""Script to prepare configuration and working directory for inference runs."""
+
+import logging
+import yaml
+import shutil
+from pathlib import Path
+
+from evalml.helpers import setup_logger
+
+
+def prepare_config(default_config_path: str, output_config_path: str, params: dict):
+    """Prepare the configuration file for the inference run.
+
+    Overrides default configuration parameters with those provided in params
+    and writes the updated configuration to output_config_path.
+
+    Parameters
+    ----------
+    default_config_path : str
+        Path to the default configuration file.
+    output_config_path : str
+        Path where the updated configuration file will be written.
+    params : dict
+        Dictionary of parameters to override in the default configuration.
+    """
+
+    with open(default_config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+    config = _override_recursive(config, params)
+
+    with open(output_config_path, "w") as f:
+        yaml.safe_dump(config, f, sort_keys=False)
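`prepare_config` delegates the merge to `_override_recursive`, defined further down in this file. Its semantics in a toy, self-contained example (the dictionaries are made up; the merge below is a condensed but equivalent version of the real helper): nested dicts are merged key-by-key, while scalars and lists are replaced wholesale.

```python
# Toy illustration of the merge semantics of _override_recursive;
# the dictionaries are made-up examples.
def merge(original: dict, updates: dict) -> dict:
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(original.get(key), dict):
            merge(original[key], value)  # recurse into nested tables
        else:
            original[key] = value  # scalars and lists replaced wholesale
    return original


base = {"output": {"grib": {"path": "grib/"}}, "lead_time": "120h"}
overrides = {"output": {"grib": {"path": "custom/"}}, "date": "2024-01-01T00:00"}
assert merge(base, overrides) == {
    "output": {"grib": {"path": "custom/"}},
    "lead_time": "120h",
    "date": "2024-01-01T00:00",
}
```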
+ """ + workdir.mkdir(parents=True, exist_ok=True) + (workdir / "grib").mkdir(parents=True, exist_ok=True) + shutil.copytree(resources_root / "templates", workdir / "resources") + + +def prepare_interpolator(smk): + """Prepare the interpolator for the inference run. + + Required steps: + - prepare working directory + - prepare forecaster directory + - prepare config + """ + LOG = _setup_logger(smk) + + # prepare working directory + workdir = _get_workdir(smk) + prepare_workdir(workdir, smk.params.resources_root) + LOG.info("Prepared working directory at %s", workdir) + res_list = "\n".join([str(fn) for fn in Path(workdir / "resources").rglob("*")]) + LOG.info("Resources: \n%s", res_list) + + # prepare forecaster directory + fct_run_id = smk.params.forecaster_run_id + if fct_run_id != "null": + fct_workdir = ( + smk.params.output_root / "runs" / fct_run_id / smk.wildcards.init_time + ) + (workdir / "forecaster").symlink_to(fct_workdir / "grib") + LOG.info( + "Created symlink to forecaster grib directory at %s", workdir / "forecaster" + ) + else: + (workdir / "forecaster").mkdir(parents=True, exist_ok=True) + (workdir / "forecaster/.dataset").touch() + LOG.info( + "No forecaster run ID provided; using dataset placeholder at %s", + workdir / "forecaster/.dataset", + ) + + # prepare config + overrides = _overrides_from_params(smk) + prepare_config(smk.input.config, smk.output.config, overrides) + LOG.info("Wrote config file at %s", smk.output.config) + with open(smk.output.config, "r") as f: + config_content = f.read() + LOG.info("Config: \n%s", config_content) + + LOG.info("Interpolator preparation complete.") + + +def prepare_forecaster(smk): + """Prepare the forecaster for the inference run. + + Required steps: + - prepare working directory + - prepare config + """ + LOG = _setup_logger(smk) + + workdir = _get_workdir(smk) + prepare_workdir(workdir, smk.params.resources_root) + LOG.info("Prepared working directory at %s", workdir) + res_list = "\n".join([str(fn) for fn in Path(workdir / "resources").rglob("*")]) + LOG.info("Resources: \n%s", res_list) + + overrides = _overrides_from_params(smk) + prepare_config(smk.input.config, smk.output.config, overrides) + LOG.info("Wrote config file at %s", smk.output.config) + with open(smk.output.config, "r") as f: + config_content = f.read() + LOG.info("Config: \n%s", config_content) + + LOG.info("Forecaster preparation complete.") + + +# TODO: just pass a dictionary of config overrides to the rule's params +def _overrides_from_params(smk) -> dict: + return { + "checkpoint": f"{smk.params.checkpoints_path}/inference-last.ckpt", + "date": smk.params.reftime_to_iso, + "lead_time": smk.params.lead_time, + } + + +def _get_workdir(smk) -> Path: + run_id = smk.wildcards.run_id + init_time = smk.wildcards.init_time + return smk.params.output_root / "runs" / run_id / init_time + + +def _setup_logger(smk) -> logging.Logger: + run_id = smk.wildcards.run_id + init_time = smk.wildcards.init_time + logger_name = f"{smk.rule}_{run_id}_{init_time}" + LOG = setup_logger(logger_name, log_file=smk.log[0]) + return LOG + + +def _override_recursive(original: dict, updates: dict) -> dict: + """Recursively override values in the original dictionary with those from the updates dictionary.""" + for key, value in updates.items(): + if ( + isinstance(value, dict) + and key in original + and isinstance(original[key], dict) + ): + original[key] = _override_recursive(original[key], value) + else: + original[key] = value + return original + + +def main(smk): + """Main function 
to run the Snakemake workflow.""" + if smk.rule == "prepare_inference_forecaster": + prepare_forecaster(smk) + elif smk.rule == "prepare_inference_interpolator": + prepare_interpolator(smk) + else: + raise ValueError(f"Unknown rule: {smk.rule}") + + +if __name__ == "__main__": + snakemake = snakemake # type: ignore # noqa: F821 + raise SystemExit(main(snakemake))
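For completeness, a standalone usage sketch of the new `evalml.helpers.setup_logger`, matching how `_setup_logger` wires it up above; the log path is illustrative, and `src/` is assumed to be importable:

```python
# Standalone check of evalml.helpers.setup_logger; the path is illustrative.
from evalml.helpers import setup_logger

log = setup_logger(
    "prepare_inference_forecaster_abc123def_202401010000",
    log_file="/tmp/prepare.log",
)
log.info("Prepared working directory at %s", "/tmp/workdir")
# /tmp/prepare.log then contains a line like:
# 2024-01-01 00:00:00,000 - prepare_inference_forecaster_abc123def_202401010000 - INFO - Prepared working directory at /tmp/workdir
```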