diff --git a/.gitignore b/.gitignore
index 8732a47..94740d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,6 +175,7 @@ pyrightconfig.json
 *.torch
 plots/*
 *.npz
+outputs/*
 
 # conda
 .conda/*
@@ -194,7 +195,9 @@ temp.zarr.sync*
 src/hirad/eval/__pycache__/*
 interpolate_basic.log
 interpolated.torch
+mlruns/
+.secrets.env
 out
 core
 *.png
-*.nc
\ No newline at end of file
+*.nc
diff --git a/README.md b/README.md
index b0dbd2e..a6cbe61 100644
--- a/README.md
+++ b/README.md
@@ -2,37 +2,20 @@
 HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model.
-## Installation (Alps)
+[Setup clariden/santis](#setup-claridensantis)
+[Regression training - clariden/santis](#run-regression-model-training-alps)
+[Diffusion training - clariden/santis](#run-diffusion-model-training-alps)
+[Inference - clariden/santis](#running-inference-on-alps)
+[Installation - uenv/venv - deprecated](#installation-alps-uenvvenv---deprecated)
-To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps:
-
-1. **Start the PyTorch user environment**:
-   ```bash
-   uenv start pytorch/v2.6.0:v1 --view=default
-   ```
-
-2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name):
-   ```bash
-   python -m venv ./{env_name}
-   ```
-
-3. **Activate the virtual environment**:
-   ```bash
-   source ./{env_name}/bin/activate
-   ```
-
-4. **Install project dependencies**:
-   ```bash
-   pip install -e .
-   ```
-
-This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure.
+## Setup clariden/santis container environment
+The container environment needed to run training and inference experiments on clariden/santis is defined in this repository under `ci/edf/modulus_env.toml`. The squashed image is available on clariden/alps at `/capstor/scratch/cscs/pstamenk/hirad.sqsh`. All jobs can be run in this environment without additional installation or setup.
 ## Training
 ### Run regression model training (Alps)
-1. Script for running the training of regression model is in `src/hirad/train_regression.sh`.
+1. The script for running regression model training is `src/hirad/train_regression.sh`. Here you can change the sbatch settings.
 Inside this script set the following:
 ```bash
 ### OUTPUT ###
 #SBATCH --output=your_path_to_output_log
@@ -55,14 +38,9 @@ srun bash -c "
 ```
 hydra:
   run:
-    dir: your_path_to_save_training_output
+    dir: your_path_to_save_training_outputs
 ```
-- In `training/era_cosmo_regression.yaml` set:
-```
-hp:
-  training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4)
-```
-- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+- All other regression training parameters can be changed in the main config file `training_era_cosmo_regression.yaml` and in the config files it references (the default values work for debugging purposes).
 3. Submit the job with:
 ```bash
 sbatch src/hirad/train_regression.sh
 ```
@@ -72,8 +50,7 @@
 ### Run diffusion model training (Alps)
 Before training diffusion model, checkpoint for regression model has to exist.
-1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`.
-Inside this script set the following:
+1. The script for running diffusion model training is `src/hirad/train_diffusion.sh`. Here you can change the sbatch settings.
 Inside this script set the following:
 ```bash
 ### OUTPUT ###
 #SBATCH --output=your_path_to_output_log
@@ -82,12 +59,6 @@
 ```bash
 #SBATCH -A your_compute_group
 ```
-```bash
-srun bash -c "
-  . ./{your_env_name}/bin/activate
-  python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml
-"
-```
 2. Set up the following config files in `src/hirad/conf`:
@@ -97,14 +68,12 @@ hydra:
   run:
     dir: your_path_to_save_training_output
 ```
-- In `training/era_cosmo_regression.yaml` set:
+- In `training/era_cosmo_diffusion.yaml` set:
 ```
-hp:
-  training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4)
 io:
   regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints
 ```
-- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+- All other diffusion training parameters can be changed in the main config file `training_era_cosmo_diffusion.yaml` and in the config files it references (the default values work for debugging purposes).
 3. Submit the job with:
 ```bash
@@ -155,13 +124,42 @@ Finally, from the dataset, subset of time steps can be chosen to do inference fo
 One way is to list steps under `times:` in format `%Y%m%d-%H%M` for era5_cosmo dataset.
-The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset (6 for era_cosmo).
-
-By default, inference is done for one time step `20160101-0000`
+The other way is to specify `times_range:` with three items: the first time step (`%Y%m%d-%H%M`), the last time step (`%Y%m%d-%H%M`), and an hour shift (int). The hour shift specifies the distance in hours between consecutive time steps for the specific dataset.
-- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+- In `dataset/era_cosmo_inference.yaml` set the `dataset_path` if different from the default. Make sure that the specified `times` or `times_range` is contained in the dataset at `dataset_path`.
 3. Submit the job with:
 ```bash
 sbatch src/hirad/generate.sh
-```
\ No newline at end of file
+```
+
+## MLflow logging
+
+During training, MLflow can be used to log metrics.
+Logging config files for regression and diffusion are located in `src/hirad/conf/logging/`. Set `method` to `mlflow`, and specify `uri` to log to a remote server; otherwise the run is logged locally in the output directory. Other options can also be modified there.
+
+## Installation (Alps uenv/venv) - deprecated
+
+To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps:
+
+1. **Start the PyTorch user environment**:
+   ```bash
+   uenv start pytorch/v2.6.0:v1 --view=default
+   ```
+
+2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name):
+   ```bash
+   python -m venv ./{env_name}
+   ```
+
+3. **Activate the virtual environment**:
+   ```bash
+   source ./{env_name}/bin/activate
+   ```
+
+4. **Install project dependencies**:
+   ```bash
+   pip install -e .
+   ```
+
+This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure.
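Editor's note on the `times_range` option described in the README changes above: it is given as three items (first time step, last time step, hour shift). Below is a minimal sketch of how such a range expands into individual time steps; `expand_times_range` is an illustrative name, and the repository's own helper (`get_time_from_range` in `hirad.utils.function_utils`, used with `time_format="%Y%m%d-%H%M"`) is the authoritative implementation and may differ in details.

```python
from datetime import datetime, timedelta

def expand_times_range(times_range, time_format="%Y%m%d-%H%M"):
    """Expand [first, last, hour_shift] into a list of time-step strings (inclusive of both ends)."""
    first, last, hour_shift = times_range
    current = datetime.strptime(first, time_format)
    end = datetime.strptime(last, time_format)
    steps = []
    while current <= end:
        steps.append(current.strftime(time_format))
        current += timedelta(hours=hour_shift)
    return steps

# For an hourly dataset: ['20200101-0000', '20200101-0100', ..., '20200102-0000']
print(expand_times_range(["20200101-0000", "20200102-0000", 1]))
```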
diff --git a/ci/cscs.yml b/ci/cscs.yml index fc92645..c3ac15c 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -12,14 +12,18 @@ build_job: stage: build extends: .container-builder-cscs-gh200 variables: - DOCKERFILE: ci/docker/Dockerfile + DOCKERFILE: ci/docker/Dockerfile.ci -#test_job: -# stage: test -# extends: .container-runner-clariden-gh200 -# image: $PERSIST_IMAGE_NAME -# script: -# - /opt/helloworld/bin/hello -# variables: -# SLURM_JOB_NUM_NODES: 2 -# SLURM_NTASKS: 2 +test_job: + stage: test + extends: .container-runner-clariden-gh200 + image: $PERSIST_IMAGE_NAME + script: + - echo 'hello world' + # - pip install -e . --no-dependencies + # - python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml + variables: + SLURM_JOB_NUM_NODES: 2 + SLURM_NTASKS: 2 + #SLURM_ACCOUNT: a161 + #SBATCH_ACCOUNT: a161 diff --git a/ci/docker/Dockerfile.ci b/ci/docker/Dockerfile.ci new file mode 100644 index 0000000..80925d3 --- /dev/null +++ b/ci/docker/Dockerfile.ci @@ -0,0 +1,15 @@ +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06 + +# setup +RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade pip + +# Install the rest of dependencies. +RUN pip install \ + Cartopy==0.22.0 \ + xskillscore + + + + + diff --git a/ci/docker/Dockerfile.corrdiff b/ci/docker/Dockerfile.corrdiff new file mode 100644 index 0000000..f4be071 --- /dev/null +++ b/ci/docker/Dockerfile.corrdiff @@ -0,0 +1,14 @@ +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06 + +# setup +RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade pip + +# Install the rest of dependencies. +RUN pip install \ + anemoi.datasets \ + Cartopy==0.22.0 \ + xskillscore \ + scoringrules \ + mlflow \ + meteodata-lab diff --git a/ci/docker/Dockerfile.python b/ci/docker/Dockerfile.python new file mode 100644 index 0000000..80ea1eb --- /dev/null +++ b/ci/docker/Dockerfile.python @@ -0,0 +1,29 @@ +# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup + +FROM ubuntu:22.04 as builder +#FROM nvcr.io/nvidia/pytorch:25.01-py3 + + +# setup +RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade \ + pip + #ninja + #wheel + #packaging + #setuptools + +# install the rest of dependencies +# TODO: Factor pydeps into a separate file(s) +# TODO: Add versions for things +RUN pip install \ + anemoi-datasets \ + cartopy \ + matplotlib \ + numpy \ + pandas \ + scipy \ + torch + + + diff --git a/ci/edf/modulus_env.toml b/ci/edf/modulus_env.toml new file mode 100644 index 0000000..55f43d8 --- /dev/null +++ b/ci/edf/modulus_env.toml @@ -0,0 +1,10 @@ +image = "/capstor/scratch/cscs/pstamenk/hirad.sqsh" + +mounts = ["/capstor", "/iopsstor", "/users"] + +# The initial directory in the container. 
+workdir = "${PWD}" + +[annotations] +com.hooks.aws_ofi_nccl.enabled = "true" +com.hooks.aws_ofi_nccl.variant = "cuda12" \ No newline at end of file diff --git a/interpolate.sh b/interpolate.sh index 607ad66..73e947d 100755 --- a/interpolate.sh +++ b/interpolate.sh @@ -3,4 +3,4 @@ #SBATCH --partition=postproc #SBATCH --time=23:59:00 -python src/input_data/interpolate_basic.py src/input_data/era-all.yaml src/input_data/cosmo-all.yaml /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/ +python src/hirad/input_data/interpolate_basic.py src/hirad/input_data/era-all.yaml src/hirad/input_data/cosmo-all.yaml /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/ diff --git a/pyproject.toml b/pyproject.toml index 1477899..7ffe3f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,19 +10,10 @@ authors = [ { name="Petar Stamenkovic", email="petar.stamenkovic@meteoswiss.ch" } ] readme = "README.md" -requires-python = ">=3.12" +#requires-python = ">=3.12" license = {file = "LICENSE"} dependencies = [ - "cartopy>=0.24.1", - "cftime>=1.6.4", - "hydra-core>=1.3.2", - "matplotlib>=3.10.1", - "omegaconf>=2.3.0", - "tensorboard>=2.19.0", - "termcolor>=3.1.0", - "torchinfo>=1.8.0", - "treelib>=1.7.1" ] [tool.setuptools] diff --git a/src/hirad/conf/compute_eval.yaml b/src/hirad/conf/compute_eval.yaml new file mode 100644 index 0000000..069ad1e --- /dev/null +++ b/src/hirad/conf/compute_eval.yaml @@ -0,0 +1,12 @@ +hydra: + job: + chdir: true + name: diffusion_era5_cosmo_7500000_test + run: + dir: ./outputs/generation/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + # Dataset + - dataset/era_cosmo_inference diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index b1e21e6..fc81714 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,3 +1,9 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full -validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full/validation \ No newline at end of file +dataset_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/train/ +# dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-all-channels +validation_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/validation/ +input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp] +output_channel_names: [2t, 10u, 10v, tp] +static_channel_names: ['hsurf'] +transform_channels: ['tp-box_cox_025'] +n_month_hour_channels: 4 \ No newline at end of file diff --git a/src/hirad/conf/dataset/era_cosmo_inference.yaml b/src/hirad/conf/dataset/era_cosmo_inference.yaml new file mode 100644 index 0000000..71b0db1 --- /dev/null +++ b/src/hirad/conf/dataset/era_cosmo_inference.yaml @@ -0,0 +1,8 @@ +type: era5_cosmo +# dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +dataset_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/validation +input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp] +output_channel_names: [2t, 10u, 10v, tp] +static_channel_names: ['hsurf'] +transform_channels: ['tp-box_cox_025'] +n_month_hour_channels: 4 \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index 5d7649d..d448b58 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -1,15 +1,15 @@ 
hydra: job: chdir: true - name: generation_full + name: diffusion_era5_cosmo_7500000_test run: - dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: ./outputs/generation/${hydra:job.name} # Get defaults defaults: - _self_ # Dataset - - dataset/era_cosmo + - dataset/era_cosmo_inference # Sampler - sampler/stochastic diff --git a/src/hirad/conf/generate_era_cosmo_test.yaml b/src/hirad/conf/generate_era_cosmo_test.yaml new file mode 100644 index 0000000..bb83890 --- /dev/null +++ b/src/hirad/conf/generate_era_cosmo_test.yaml @@ -0,0 +1,20 @@ +hydra: + job: + chdir: true + name: generation_era5_cosmo_test + run: + dir: ./outputs/generation/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + # Dataset + - dataset/era_cosmo_inference + + # Sampler + - sampler/stochastic + #- sampler/deterministic + + # Generation + - generation/era_cosmo_test + #- generation/patched_based \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index be4219d..a5302ff 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -1,4 +1,4 @@ -num_ensembles: 8 +num_ensembles: 16 # Number of ensembles to generate per input seed_batch_size: 4 # Size of the batched inference @@ -6,7 +6,13 @@ inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y -# overlap_pixels: 0 +randomize: True + # Whether to randomize the random seeds for each generation. If false, fixed seeds + # from 0 to num_ensembles-1 will be used for each time step in times/times_range. +random_seed: 129 + # Base random seed. This is only used when randomize is True. + # random seed will be set for numpy random module to have reproducible randomized generative process. + # Number of overlapping pixels between adjacent patches # boundary_pixels: 0 # Number of boundary pixels to be cropped out. 
2 is recommanded to address the boundary @@ -15,12 +21,12 @@ patching: False hr_mean_conditioning: True # sample_res: full # Sampling resolution -times_range: null -times: - - 20160101-0000 - # - 20160101-0600 +times_range: ['20200101-0000','20200102-0000',1] +times: null +# - 20200926-1800 + #- 20160101-0600 # - 20160101-1200 -has_laed_time: False +has_lead_time: False perf: force_fp16: False @@ -30,15 +36,15 @@ perf: # whether to use torch.compile on the diffusion model # this will make the first time stamp generation very slow due to compilation overheads # but will significantly speed up subsequent inference runs - num_writer_workers: 1 + num_writer_workers: 8 # number of workers to use for writing file # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_refactoring/checkpoints_diffusion - # res_ckpt_path: null + # res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/diffusion_full/checkpoints_diffusion + res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_era5_cosmo/checkpoints_diffusion + # res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_era5_cosmo_cmp_fix_2/checkpoints_diffusion # Checkpoint filename for the diffusion model - reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_refactoring/checkpoints_regression - # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression + reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression # Checkpoint filename for the mean predictor model - output_path: ./images \ No newline at end of file + output_path: . \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo_test.yaml b/src/hirad/conf/generation/era_cosmo_test.yaml new file mode 100644 index 0000000..9e10eb4 --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo_test.yaml @@ -0,0 +1,42 @@ +# TODO: See if there's a way to inherit from era_cosmo.yaml +num_ensembles: 8 + # Number of ensembles to generate per input +seed_batch_size: 4 + # Size of the batched inference +inference_mode: all + # Choose between "all" (regression + diffusion), "regression" or "diffusion" + # Patch size. Patch-based sampling will be utilized if these dimensions differ from + # img_shape_x and img_shape_y +# overlap_pixels: 0 + # Number of overlapping pixels between adjacent patches +# boundary_pixels: 0 + # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary + # artifact. +patching: False +hr_mean_conditioning: True +# sample_res: full + # Sampling resolution +times_range: ['20200131-0000','20200131-2300',1] +times: null +has_lead_time: False + +perf: + force_fp16: False + # Whether to force fp16 precision for the model. If false, it'll use the precision + # specified upon training. 
+ use_torch_compile: False + # whether to use torch.compile on the diffusion model + # this will make the first time stamp generation very slow due to compilation overheads + # but will significantly speed up subsequent inference runs + num_writer_workers: 8 + # number of workers to use for writing file + # To support multiple workers a threadsafe version of the netCDF library must be used + +io: + res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_test/checkpoints_diffusion + # Checkpoint filename for the diffusion model + reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_test/checkpoints_regression + # Checkpoint filename for the mean predictor model + output_path: ./outputs/evaluation + + diff --git a/src/hirad/conf/generation/era_cosmo_training.yaml b/src/hirad/conf/generation/era_cosmo_training.yaml new file mode 100644 index 0000000..54dc840 --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo_training.yaml @@ -0,0 +1,18 @@ +defaults: + - ../sampler@sampler: stochastic + - ../dataset@dataset: era_cosmo_inference + +num_ensembles: 16 + # Number of ensembles to generate per input +# overlap_pixels: 0 + # Number of overlapping pixels between adjacent patches +# boundary_pixels: 0 + # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary + # artifact. +times_range: null +times: + - 20200721-1900 + - 20200722-1900 + +perf: + num_writer_workers: 10 \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo_training_patched.yaml b/src/hirad/conf/generation/era_cosmo_training_patched.yaml new file mode 100644 index 0000000..e6e2d7c --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo_training_patched.yaml @@ -0,0 +1,24 @@ +defaults: + - ../sampler@sampler: stochastic + - ../dataset@dataset: era_cosmo_inference + +num_ensembles: 16 + # Number of ensembles to generate per input + +patching: True +# Use patch-based sampling +overlap_pix: 4 +# Number of overlapping pixels between adjacent patches +boundary_pix: 2 +# Number of boundary pixels to be cropped out. 2 is recommended to address the boundary +# artifact. 
+patch_shape_x: 128 +patch_shape_y: 128 + +times_range: null +times: + - 20200926-1800 + # - 20200927-0000 + +perf: + num_writer_workers: 10 \ No newline at end of file diff --git a/src/hirad/conf/logging/era_cosmo_diffusion.yaml b/src/hirad/conf/logging/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..86ec7fe --- /dev/null +++ b/src/hirad/conf/logging/era_cosmo_diffusion.yaml @@ -0,0 +1,8 @@ +# set method to mlflow to log with mlflow +method: mlflow +experiment_name: hirad-corrdiff-diffusion +run_name: era-cosmo-1h +# change uri to remote mlflow server; if null, it is stored locally +# if uri is remote make sure to have credentials set in ~/.mlflow/credentials +uri: null +log_images: false \ No newline at end of file diff --git a/src/hirad/conf/logging/era_cosmo_regression.yaml b/src/hirad/conf/logging/era_cosmo_regression.yaml new file mode 100644 index 0000000..e7a6287 --- /dev/null +++ b/src/hirad/conf/logging/era_cosmo_regression.yaml @@ -0,0 +1,8 @@ +# set method to mlflow to log with mlflow +method: mlflow +experiment_name: hirad-corrdiff-regression +run_name: era-cosmo-1h +# change uri to remote mlflow server; if null, it is stored locally +# if uri is remote make sure to have credentials set in ~/.mlflow/credentials +uri: null +log_images: false \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml index 441239e..7a060e8 100644 --- a/src/hirad/conf/model/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -10,6 +10,6 @@ model_args: # Controls how positional information is encoded. N_grid_channels: 4 # Number of channels for positional grid embeddings - embedding_type: "zero" + embedding_type: "positional" # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml b/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml new file mode 100644 index 0000000..9362932 --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml @@ -0,0 +1,13 @@ +name: patched_diffusion + # Name of the preconditioner +hr_mean_conditioning: True + # High-res mean (regression's output) as additional condition + +# Standard model parameters. +# Standard model parameters. +model_args: + gridtype: "learnable" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 100 + # Number of channels for positional grid embeddings \ No newline at end of file diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml index 96c29fb..e746a2c 100644 --- a/src/hirad/conf/model_size/normal.yaml +++ b/src/hirad/conf/model_size/normal.yaml @@ -24,4 +24,4 @@ model_args: # Per-resolution multipliers for the number of channels. channel_mult: [1, 2, 2, 2, 2] # Resolutions at which self-attention layers are applied. - attn_resolutions: [28] \ No newline at end of file + attn_resolutions: [22] \ No newline at end of file diff --git a/src/hirad/conf/plot_maps.yaml b/src/hirad/conf/plot_maps.yaml new file mode 100644 index 0000000..07ecb85 --- /dev/null +++ b/src/hirad/conf/plot_maps.yaml @@ -0,0 +1,21 @@ +hydra: + job: + chdir: true + name: plot_maps + run: + dir: . 
+ +# Get defaults +defaults: + - _self_ + # Dataset + - dataset/era_cosmo_inference + +results_dir: /capstor/scratch/cscs/pstamenk/outputs/generation/test_period_advanced_stats_2 + +time_steps: null # If null, will use all time steps in the results directory + +output_dir: ${results_dir}/plots_boxcox_masked_limits # Directory to save the plots + +plot_box_precipitation: False # Whether to plot box precipitation maps, otherwise plot with box-cox transform +tp_threshold: 0.0002 # Threshold in m/h for masking precipitation values \ No newline at end of file diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml index 856906b..f65e738 100644 --- a/src/hirad/conf/sampler/deterministic.yaml +++ b/src/hirad/conf/sampler/deterministic.yaml @@ -2,7 +2,8 @@ # Deterministic sampler is not implemented correctly in this codebase and shouldn't be used. type: deterministic -num_steps: 9 +params: + num_steps: 9 # Number of denoising steps -solver: euler + solver: euler # ODE solver type: euler is the simplest solver \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index 07cbb03..f673e57 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,10 +1,10 @@ # Hyperparameters hp: - training_duration: 5000000 + training_duration: 3500000 # Training duration based on the number of processed samples - total_batch_size: 128 + total_batch_size: "auto" # Total batch size - batch_size_per_gpu: "auto" + batch_size_per_gpu: 20 # Batch size per GPU lr: 0.0002 # Learning rate @@ -26,18 +26,22 @@ perf: # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint - + use_apex_gn: True + torch_compile: True + profile_mode: False # I/O io: - regression_checkpoint_path: /capstor/scratch/cscs/boeschf/HiRAD-Gen/outputs_full/regression/checkpoints_regression/ + # regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression + regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo_perf_apex_gn/checkpoints_regression # Where to load the regression checkpoint - print_progress_freq: 5000 + print_progress_freq: 2000 # How often to print progress save_checkpoint_freq: 250000 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 25000 + visualization_freq: 250000 + # how often to visualize network outputs + validation_freq: 50000 # how often to record the validation loss, measured in number of processed samples - validation_steps: 4 - # how many loss evaluations are used to compute the validation loss per checkpoint + validation_steps: 30 # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . 
\ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml b/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml new file mode 100644 index 0000000..679afd6 --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml @@ -0,0 +1,53 @@ +# Hyperparameters +hp: + training_duration: 1000000 + # Training duration based on the number of processed samples + total_batch_size: "auto" + # Total batch size + batch_size_per_gpu: 10 + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: 1e6 + # no gradient clipping for defualt non-patch-based training + lr_decay: 0.7 + # LR decay rate + lr_rampup: 1000000 + # Rampup for learning rate, in number of samples + lr_decay_rate: 5e5 + # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples. + patch_shape_x: 128 + patch_shape_y: 128 + # Patch size. Patch training is used if these dimensions differ from + # img_shape_x and img_shape_y. + patch_num: 7 + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. + max_patch_per_gpu: 100 + # Maximum number of pataches a gpu can hold + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 10 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression + # Where to load the regression checkpoint + print_progress_freq: 1000 + # How often to print progress + save_checkpoint_freq: 500000 + # How often to save the checkpoints, measured in number of processed samples + visualization_freq: 250000 + # how often to visualize network outputs + validation_freq: 50000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 100 + # how many loss evaluations are used to compute the validation loss per checkpoint + checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 98c6c24..7d35d3d 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,10 +1,10 @@ # Hyperparameters hp: - training_duration: 500000 + training_duration: 1000000 # Training duration based on the number of processed samples - total_batch_size: 64 - # Total batch size - batch_size_per_gpu: "auto" + total_batch_size: "auto" + # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression. 
+ batch_size_per_gpu: 20 # Batch size per GPU lr: 0.0002 # Learning rate @@ -26,17 +26,20 @@ perf: # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint - # torch_compile: True - # use_apex_gn: True + use_apex_gn: True + torch_compile: True + profile_mode: False # I/O io: - print_progress_freq: 1024 + print_progress_freq: 500 # How often to print progress - save_checkpoint_freq: 25000 + save_checkpoint_freq: 100000 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 5000 + visualization_freq: 50000 + # how often to visualize network output + validation_freq: 20000 # how often to record the validation loss, measured in number of processed samples - validation_steps: 10 + validation_steps: 30 # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 0a069e9..8362946 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: diffusion + name: diffusion_era5_cosmo_test run: - dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} # Get defaults defaults: @@ -18,4 +18,10 @@ defaults: - model_size/normal # Training - - training/era_cosmo_diffusion \ No newline at end of file + - training/era_cosmo_diffusion + + # Inference visualization + # - generation/era_cosmo_training + + # Logging + - logging/era_cosmo_diffusion diff --git a/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml b/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml new file mode 100644 index 0000000..38baba0 --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml @@ -0,0 +1,27 @@ +hydra: + job: + chdir: true + name: diffusion_era5_cosmo_patched + run: + dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_diffusion_patched + + - model_size/normal + + # Training + - training/era_cosmo_diffusion_patched + + # Inference visualization + - generation/era_cosmo_training_patched + + # Logging + - logging/era_cosmo_diffusion \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion_test.yaml b/src/hirad/conf/training_era_cosmo_diffusion_test.yaml new file mode 100644 index 0000000..2723050 --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_diffusion_test.yaml @@ -0,0 +1,24 @@ +hydra: + job: + chdir: true + name: diffusion_era5_cosmo_test + run: + dir: ./outputs/training/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_diffusion + + - model_size/mini + + # Training + - training/era_cosmo_diffusion + + # Inference visualization + - generation/era_cosmo_training \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index 1de83d9..38912ee 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: regression + name: regression_era5_cosmo_mlflow_test_x16 run: - dir: 
/iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} # Get defaults defaults: @@ -18,4 +18,10 @@ defaults: - model_size/normal # Training - - training/era_cosmo_regression \ No newline at end of file + - training/era_cosmo_regression + + # Inference visualization + # - generation/era_cosmo_training + + # Logging + - logging/era_cosmo_regression diff --git a/src/hirad/conf/training_era_cosmo_regression_test.yaml b/src/hirad/conf/training_era_cosmo_regression_test.yaml new file mode 100644 index 0000000..96e88fa --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_regression_test.yaml @@ -0,0 +1,25 @@ +hydra: + job: + chdir: true + name: regression_era5_cosmo_test + run: + dir: ./outputs/training/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_regression + + - model_size/mini + + # Training + # Leave same as prod + - training/era_cosmo_regression + + # Inference visualization + - generation/era_cosmo_training \ No newline at end of file diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py index 53e791e..2fb3d5d 100644 --- a/src/hirad/datasets/__init__.py +++ b/src/hirad/datasets/__init__.py @@ -1,3 +1,3 @@ -from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config +from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config, get_dataset_and_sampler_inference from .era5_cosmo import ERA5_COSMO from .base import DownscalingDataset diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index c09a4f2..924951e 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -37,6 +37,7 @@ def init_train_valid_datasets_from_config( batch_size: int = 1, seed: int = 0, train_test_split: bool = True, + sampler_start_idx: int = 0, ) -> Tuple[ DownscalingDataset, Iterable, @@ -52,6 +53,7 @@ def init_train_valid_datasets_from_config( - batch_size (int): The number of samples in each batch of data. Defaults to 1. - seed (int): The random seed for dataset shuffling. Defaults to 0. - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True. + - sampler_start_idx (int): The initial index of the sampler to use for resuming training. Defaults to 0. Returns: - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True. 
@@ -61,7 +63,7 @@ def init_train_valid_datasets_from_config( if 'validation_path' in config: del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( - config, dataloader_cfg, batch_size=batch_size, seed=seed + config, dataloader_cfg, batch_size=batch_size, seed=seed, sampler_start_idx=sampler_start_idx, ) if train_test_split: valid_dataset_cfg = copy.deepcopy(dataset_cfg) @@ -81,6 +83,7 @@ def init_dataset_from_config( dataloader_cfg: Union[dict, None] = None, batch_size: int = 1, seed: int = 0, + sampler_start_idx: int = 0, ) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) dataset_type = dataset_cfg.pop("type", "era5_cosmo") @@ -97,7 +100,7 @@ def init_dataset_from_config( dist = DistributedManager() dataset_sampler = InfiniteSampler( - dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed + dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed, start_idx=sampler_start_idx, ) dataset_iterator = iter( @@ -111,3 +114,22 @@ def init_dataset_from_config( ) return (dataset_obj, dataset_iterator) + + +def get_dataset_and_sampler_inference(dataset_cfg, times, has_lead_time=False): + """ + Get a dataset and sampler for generation. + """ + (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) + # if has_lead_time: + # plot_times = times + # else: + # plot_times = [ + # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + # for time in times + # ] + all_times = dataset.time() + time_indices = [all_times.index(t) for t in times] + sampler = time_indices + + return dataset, sampler diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index f97dbc6..1db8716 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -5,68 +5,170 @@ from typing import List, Tuple import yaml import torch.nn.functional as F +import time +# import zarr + +from hirad.utils.console import PythonLogger + +logger = PythonLogger(__name__) + +DATASET_ORIG_PATH = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full' class ERA5_COSMO(DownscalingDataset): - def __init__(self, dataset_path: str): + def __init__(self, + dataset_path: str, + input_channel_names: List[str] = [], + output_channel_names: List[str] = [], + static_channel_names: List[str] = [], + transform_channels: List[str] = [], + n_month_hour_channels: int = None, + ): super().__init__() #TODO switch hanbdling paths to Path rather than pure strings + self._n_month_hour_channels = n_month_hour_channels self._dataset_path = dataset_path self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') - self._info_path = os.path.join(dataset_path, 'info') + self._info_path = os.path.join(DATASET_ORIG_PATH, 'info') + # self._static_path = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/static'# os.path.join(dataset_path, 'static') + self._static_path = os.path.join(DATASET_ORIG_PATH, 'static')# os.path.join(dataset_path, 'static') + # self._zarr_path = os.path.join(dataset_path, 'dataset.zarr') # load file list (each file is one date-time state) - self._file_list = os.listdir(self._cosmo_path) + self._file_list = sorted(os.listdir(self._cosmo_path)) + + # open zarr store + # self._zarr_store = zarr.open(self._zarr_path, mode='r') + # self.era5 = self._zarr_store['era5'] + # self.cosmo = self._zarr_store['cosmo'] + + # Load static info and channel names + if 
static_channel_names: + with open(os.path.join(self._static_path, 'cosmo-static.yaml'), 'r') as file: + self._static_info = yaml.safe_load(file) + self._static_indeces = [self._static_info['select'].index(name) for name in static_channel_names] + self._static_channels = [ChannelMetadata(name) if len(name.split('_'))==1 + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._static_info['select'] if name in static_channel_names] + static_data = torch.load(os.path.join(self._static_path,'cosmo-static'), weights_only=False)[self._static_indeces] + orig_shape = self.image_shape() + self.static_data = np.flip(static_data \ + .squeeze() \ + .reshape(-1,*orig_shape), + 1) + self.static_mean = self.static_data.mean(axis=(1,2)) + self.static_std = self.static_data.std(axis=(1,2)) + else: + self.static_data = None # Load cosmo info and channel names with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file: self._cosmo_info = yaml.safe_load(file) - self._cosmo_channels = [ChannelMetadata(name) for name in self._cosmo_info['select']] + if output_channel_names: + self._cosmo_indeces = [self._cosmo_info['select'].index(name) for name in output_channel_names] + else: + self._cosmo_indeces = list(range(len(self._cosmo_info['select']))) + output_channel_names = self._cosmo_info['select'] + self._cosmo_channels = [ChannelMetadata(name) if len(name.split('_'))==1 + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._cosmo_info['select'] if name in output_channel_names] # Load era5 info and channel names with open(os.path.join(self._info_path,'era.yaml'), 'r') as file: self._era_info = yaml.safe_load(file) + if input_channel_names: + self._era_indeces = [self._era_info['select'].index(name) for name in input_channel_names] + else: + self._era_indeces = list(range(len(self._era_info['select']))) + input_channel_names = self._era_info['select'] self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1 - else ChannelMetadata(name.split('_')[0],name.split('_')[1]) - for name in self._era_info['select']] + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._era_info['select'] if name in input_channel_names] # Load stats for normalizing channels of input and output cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False) - self.output_mean = cosmo_stats['mean'] - self.output_std = cosmo_stats['stdev'] + self.output_mean = cosmo_stats['mean'][self._cosmo_indeces] + self.output_std = cosmo_stats['stdev'][self._cosmo_indeces] era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) - self.input_mean = era_stats['mean'] - self.input_std = era_stats['stdev'] + self.input_mean = era_stats['mean'][self._era_indeces] + self.input_std = era_stats['stdev'][self._era_indeces] + if self.static_data is not None: + self.input_mean = np.concatenate((self.input_mean, self.static_mean), axis=0) + self.input_std = np.concatenate((self.input_std, self.static_std), axis=0) + # FEATURE: load the mean and std values for transformed channels and update the normalization statistics + self.input_transforms = {} + self.input_inverse_transforms = {} + self.output_transforms = {} + self.output_inverse_transforms = {} + for transform_descriptor in transform_channels: + channel, transformation = transform_descriptor.split('-') + input_channel_idx = self._era_info['select'].index(channel) if channel in self._era_info['select'] else None + output_channel_idx = 
self._cosmo_info['select'].index(channel) if channel in self._cosmo_info['select'] else None + if transformation.startswith('box_cox'): + lmbda_str = transformation.split('_')[-1] + lmbda = float(transformation.split('_')[-1])/(10**(len(lmbda_str)-1)) + print(f"Applying Box-Cox transformation with lambda={lmbda} to channel {channel} (input idx: {input_channel_idx}, output idx: {output_channel_idx})") + if input_channel_idx is not None: + self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda) + self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda) + self.input_mean[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-mean"), weights_only=False) + self.input_std[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-std"), weights_only=False) + if output_channel_idx is not None: + self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda) + self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda) + self.output_mean[output_channel_idx] = torch.load(os.path.join(self._info_path,f"cosmo-{transform_descriptor}-mean"), weights_only=False) + self.output_std[output_channel_idx] = torch.load(os.path.join(self._info_path,f"cosmo-{transform_descriptor}-std"), weights_only=False) + else: + raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.") + def __getitem__(self, idx): """Get cosmo and era5 interpolated to cosmo grid""" - # get era5 data point + # get data point # squeeze the ensemble dimesnsion # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) # orig_shape = [350,542] #TODO currently padding to be divisible by 16 orig_shape = self.image_shape() - era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ + try: + era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)[self._era_indeces] + except: + logger.error(f"Error loading file {os.path.join(self._era5_path,self._file_list[idx])}") + raise + era5_data = np.flip(era5_data \ .squeeze() \ .reshape(-1,*orig_shape), 1) + era5_data = np.concatenate((era5_data, self.static_data), axis=0) if self.static_data is not None else era5_data era5_data = self.normalize_input(era5_data) - # get cosmo data point - cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ + + try: + cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)[self._cosmo_indeces] + except: + logger.error(f"Error loading file {os.path.join(self._cosmo_path,self._file_list[idx])}") + raise + cosmo_data = np.flip(cosmo_data\ .squeeze() \ .reshape(-1,*orig_shape), 1) cosmo_data = self.normalize_output(cosmo_data) - # return samples + + if self._n_month_hour_channels is not None and self._n_month_hour_channels>0: + # extract month and hour from filename + filename = self._file_list[idx] + date_str, hour_str = filename.split('-') + month = int(date_str[4:6]) + hour = int(hour_str[0:2]) + + time_grid = self.make_time_grids(hour, month) + era5_data = np.concatenate((era5_data, time_grid), axis=0) + return torch.tensor(cosmo_data),\ - torch.tensor(era5_data), - # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ - # 
F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ - # 0 + torch.tensor(era5_data) def __len__(self): return len(self._file_list) @@ -86,8 +188,13 @@ def latitude(self) -> np.ndarray: def input_channels(self) -> List[ChannelMetadata]: """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" - return self._era_channels - + channels = self._era_channels + self._static_channels if self.static_data is not None else self._era_channels + if self._n_month_hour_channels is not None and self._n_month_hour_channels>0: + for i in range(self._n_month_hour_channels): + channels.append(ChannelMetadata("hour-enc",f"{i}")) + for i in range(self._n_month_hour_channels): + channels.append(ChannelMetadata("month-enc",f"{i}")) + return channels def output_channels(self) -> List[ChannelMetadata]: """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" @@ -108,23 +215,84 @@ def image_shape(self) -> Tuple[int, int]: def normalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from physical units to normalized data.""" + for channel_idx, transform in self.input_transforms.items(): + x[channel_idx,::] = transform(x[channel_idx,::]) return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \ / self.input_std.reshape((self.input_std.shape[0],1,1)) def denormalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from normalized data to physical units.""" - return x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + if self._n_month_hour_channels is not None and self._n_month_hour_channels>0: + x = x[:,:-2*self._n_month_hour_channels,:,:] + x = x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + self.input_mean.reshape((self.input_mean.shape[0],1,1)) + for channel_idx, inverse_transform in self.input_inverse_transforms.items(): + x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::]) + return x def normalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from physical units to normalized data.""" + for channel_idx, transform in self.output_transforms.items(): + x[channel_idx,::] = transform(x[channel_idx,::]) return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \ / self.output_std.reshape((self.output_std.shape[0],1,1)) def denormalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from normalized data to physical units.""" - return x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ - + self.output_mean.reshape((self.output_mean.shape[0],1,1)) \ No newline at end of file + x = x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ + + self.output_mean.reshape((self.output_mean.shape[0],1,1)) + for channel_idx, inverse_transform in self.output_inverse_transforms.items(): + x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::]) + return x + + def box_cox_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray: + """Apply Box-Cox transformation to the data.""" + channel_array = np.clip(channel_array, 0, None) + return (np.power(channel_array, lmbda) - 1) / lmbda + + def box_cox_inverse_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray: + """Apply inverse Box-Cox transformation to the data.""" + channel_array = np.clip(channel_array, -1/lmbda, None) + return np.power((lmbda * channel_array) + 1, 1 / lmbda) + + def make_time_grids(self, hour, month): + """ + Create multi-frequency cyclic sin/cos feature grids for hour and month. 
+ + Parameters + ---------- + hour : int + Hour of day, 0-23 + month : int + Month of year, 1-12 + + Returns + ------- + grid : np.ndarray, shape (C, H, W) + Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k frequency] + """ + H, W = self.image_shape() + hour_freqs = np.arange(1, self._n_month_hour_channels//2 + 1) + month_freqs = np.arange(1, self._n_month_hour_channels//2 + 1) + + channels = [] + + # --- hour encodings --- + for k in hour_freqs: + angle = 2 * np.pi * k * (hour % 24) / 24.0 + channels.append(np.sin(angle)) + channels.append(np.cos(angle)) + + # --- month encodings --- + for k in month_freqs: + angle = 2 * np.pi * k * ((month - 1) % 12) / 12.0 + channels.append(np.sin(angle)) + channels.append(np.cos(angle)) + + channels = np.array(channels, dtype=np.float32) + grid = np.tile(channels[:, None, None], (1, H, W)) # (C, H, W) + + return grid \ No newline at end of file diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index eca46c6..75c0fe7 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -529,7 +529,7 @@ def setup( DistributedManager._shared_state["_is_initialized"] = True manager = DistributedManager() - manager._distributed = torch.distributed.is_available() + manager._distributed = torch.distributed.is_available() and world_size > 1 if manager._distributed: # Update rank and world_size if using distributed manager._rank = rank @@ -546,22 +546,23 @@ def setup( #TODO device_id makes the init hang, couldn't figure out why if manager._distributed: # Setup distributed process group - # try: - dist.init_process_group( - backend, - rank=manager.rank, - world_size=manager.world_size, - ) + try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + device_id=manager.device, + ) # rank=manager.rank, # world_size=manager.world_size, # device_id=manager.device, - # except TypeError: - # # device_id only introduced in PyTorch 2.3 - # dist.init_process_group( - # backend, - # rank=manager.rank, - # world_size=manager.world_size, - # ) + except TypeError: + # device_id only introduced in PyTorch 2.3 + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) if torch.cuda.is_available(): # Set device for this process and empty cache to optimize memory usage diff --git a/src/hirad/eval.sh b/src/hirad/eval.sh new file mode 100644 index 0000000..64bb995 --- /dev/null +++ b/src/hirad/eval.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +##SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/eval_compute.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. 
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/eval/compute_eval.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py index 13d9eb3..16eff0a 100644 --- a/src/hirad/eval/__init__.py +++ b/src/hirad/eval/__init__.py @@ -1,2 +1,2 @@ -from .metrics import compute_mae, average_power_spectrum -from .plotting import plot_error_projection, plot_power_spectra +from .metrics import absolute_error, compute_mae, average_power_spectrum, crps +from .plotting import plot_map, plot_error_projection, plot_power_spectra, plot_scores_vs_t diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py new file mode 100644 index 0000000..00e524e --- /dev/null +++ b/src/hirad/eval/compute_eval.py @@ -0,0 +1,229 @@ +import hydra +import logging +import os +import json +from omegaconf import OmegaConf, DictConfig +import torch +import numpy as np +import contextlib +import datetime +from pandas import to_datetime + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from concurrent.futures import ThreadPoolExecutor + +from hirad.eval import absolute_error, crps, plot_scores_vs_t, plot_error_projection +from hirad.models import EDMPrecondSuperResolution, UNet +from hirad.inference import Generator +from hirad.utils.inference_utils import save_images, save_results_as_torch +from hirad.utils.function_utils import get_time_from_range +from hirad.utils.checkpoint import load_checkpoint + +from hirad.datasets import get_dataset_and_sampler_inference + +from hirad.utils.train_helpers import set_patch_shape + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + + + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + # Initialize logger + logger = PythonLogger("generate") # General python logger + + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if "has_lead_time" in cfg.generation: + has_lead_time = cfg.generation["has_lead_time"] + else: + has_lead_time = False + dataset, sampler = get_dataset_and_sampler_inference( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + pred_path = getattr(cfg.generation.io, "output_path", "./outputs") + output_path = './plots/analysis202511' + + compute_crps_per_time(times, dataset, pred_path, output_path) + compute_crps_over_time_and_area(times, output_path) + 
plot_crps_over_time_and_area(times, dataset, output_path) + +def _get_data_path(output_path, time=None, filename=None): + if time: + return os.path.join(output_path, time, filename) + else: + return os.path.join(output_path, filename) + +def load_data(output_path, time=None, filename=None): + return torch.load(_get_data_path(output_path, time, filename), weights_only=False) + +def save_data(data, output_path, time=None, filename=None): + path = _get_data_path(output_path, time, filename) + torch.save(data, path) + +def compute_crps_per_time(times, dataset, pred_path, output_path): + logging.info('Computing CRPS for each time point') + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + + # Load one prediction ensemble to get the shape + prediction_ensemble = torch.load(os.path.join(pred_path, start_time, f'{start_time}-predictions'), weights_only=False) + + # Get a map of output to input channel, for building baseline errors + output_to_input_channel_map = {} + for j in range(len(output_channels)): + index = -1 + for k in range(len(input_channels)): + if input_channels[k].name == output_channels[j].name: + index = k + output_to_input_channel_map[j] = index + + + for i in range(len(times)): + curr_time = times[i] + if i % (24*5) == 0: + logging.info(f'on time {curr_time}') + prediction_ensemble = load_data(pred_path, time=curr_time, filename=f'{curr_time}-predictions') + baseline = load_data(pred_path, time=curr_time, filename=f'{curr_time}-baseline') + target = load_data(pred_path, time=curr_time, filename=f'{curr_time}-target') + + # Calculate ensemble mean error + ensemble_mean = np.mean(prediction_ensemble, 0) + ensemble_mean_error = absolute_error(ensemble_mean, target) + + # Calculate interpolation error (baseline #1) + interpolation_error = np.zeros(target.shape) + for j in range(len(output_channels)): + k = output_to_input_channel_map[j] + if k > -1: + interpolation_error[j,::] = absolute_error(baseline[k,::], target[j,::]) + + # Calculate persistence error (baseline #2) + persistence_error = np.zeros(target.shape) + if i > 0: + prev = load_data(pred_path, time=times[i-1], filename=f'{times[i-1]}-target') + persistence_error = absolute_error(prev, target) + else: + # for the first time point, persist the next-time-point target. + # This is fiction but it keeps the plots from looking weird. 
+ prev = load_data(pred_path, time=times[i+1], filename=f'{times[i+1]}-target') + persistence_error = absolute_error(prev, target) + + + # Calculate CRPS + crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) + + save_data(crps_diffusion_area, output_path, time=curr_time, filename=f'{curr_time}-crps-ensemble') + save_data(ensemble_mean_error, output_path, time=curr_time, filename=f'{curr_time}-ensemble-mean-error') + save_data(interpolation_error, output_path, time=curr_time, filename=f'{curr_time}-interpolation-error') + save_data(persistence_error, output_path, time=curr_time, filename=f'{curr_time}-persistence-error') + +def compute_crps_over_time_and_area(times, output_path): + logging.info('computing crps and errors') + start_time=times[0] + end_time=times[-1] + + # shape = (channels, x, y) + crps_area = load_data(output_path, time=start_time, filename=f'{start_time}-crps-ensemble') + num_channels = crps_area.shape[0] + + # make area and time plot + total_crps_area = np.zeros_like(crps_area) + total_ensemble_mean_area = np.zeros_like(crps_area) + total_interpolation_area = np.zeros_like(crps_area) + total_persistence_area = np.zeros_like(crps_area) + + crps_over_time = np.zeros((num_channels, len(times))) + ensemble_mean_over_time = np.zeros_like(crps_over_time) + interpolation_over_time = np.zeros_like(crps_over_time) + persistence_over_time = np.zeros_like(crps_over_time) + for i in range(len(times)): + curr_time = times[i] + if i % (24*5) == 0: + logging.info(f'on time {times[i]}') + crps_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-crps-ensemble') + total_crps_area = total_crps_area + crps_area + + ensemble_mean_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-ensemble-mean-error') + total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area + interpolation_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-interpolation-error') + total_interpolation_area = total_interpolation_area + interpolation_area + persistence_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-persistence-error') + if i>0: + total_persistence_area = total_persistence_area + persistence_area + + for j in range(num_channels): + crps_over_time[j,i] = np.mean(crps_area[j,::]) + ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) + interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) + persistence_over_time[j,i] = np.mean(persistence_area[j,::]) + mean_crps_area = total_crps_area / len(times) + mean_ensemble_mean_area = total_ensemble_mean_area / len(times) + mean_interpolation_area = total_interpolation_area / len(times) + mean_persistence_area = total_persistence_area / (len(times)-1) + save_data(mean_crps_area, output_path, filename=f'crps-ensemble-area-{start_time}-{end_time}') + save_data(mean_ensemble_mean_area, output_path, filename=f'mae-ensemble-mean-area-{start_time}-{end_time}') + save_data(mean_interpolation_area, output_path, filename=f'mae-interpolation-area-{start_time}-{end_time}') + save_data(mean_persistence_area, output_path, filename=f'mae-persistence-area-{start_time}-{end_time}') + + # Little hack to make the plots look nicer, without having to change dimensions. 
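+    # Persistence is not defined at the first time step (there is no previous target),
+    # so the t=0 entry is filled with the t=1 value purely for plotting.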
+ persistence_over_time[:,0] = persistence_over_time[:,1] + + save_data(crps_over_time, output_path, filename=f'crps-ensemble-time-{start_time}-{end_time}') + save_data(ensemble_mean_over_time, output_path, filename=f'mae-ensemble-mean-time-{start_time}-{end_time}') + save_data(interpolation_over_time, output_path, filename=f'mae-interpolation-time-{start_time}-{end_time}') + save_data(persistence_over_time, output_path, filename=f'mae-persistence-time-{start_time}-{end_time}') + +def plot_crps_over_time_and_area(times, dataset, output_path): + logging.info('plotting crps and errors') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + crps_area = load_data(output_path, filename=f'crps-ensemble-area-{start_time}-{end_time}') + ensemble_mean_area = load_data(output_path, filename=f'mae-ensemble-mean-area-{start_time}-{end_time}') + interpolation_area = load_data(output_path, filename=f'mae-interpolation-area-{start_time}-{end_time}') + persistence_area = load_data(output_path, filename=f'mae-persistence-area-{start_time}-{end_time}') + + crps_ensemble_time = load_data(output_path, filename=f'crps-ensemble-time-{start_time}-{end_time}') + ensemble_mean_time = load_data(output_path, filename=f'mae-ensemble-mean-time-{start_time}-{end_time}') + interpolation_time = load_data(output_path, filename=f'mae-interpolation-time-{start_time}-{end_time}') + persistence_time = load_data(output_path, filename=f'mae-persistence-time-{start_time}-{end_time}') + + for j in range(len(output_channels)): + plot_error_projection(crps_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: CRPS: {output_channels[j].name}') + plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: Ensemble mean: {output_channels[j].name}') + plot_error_projection(interpolation_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: Interpolation: {output_channels[j].name}') + plot_error_projection(persistence_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: Persistence: {output_channels[j].name}') + + maes = {} + maes['interpolation'] = interpolation_time[j,::] + maes['ensemble mean'] = ensemble_mean_time[j,::] + maes['crps'] = crps_ensemble_time[j,:] + maes['persistence'] = persistence_time[j,::] + # TODO: consider casting times to datetime objects to avoid warnings. 
+ # However, this seems to be working OK, and a direct cast causes plotting errors + plot_scores_vs_t(maes, times, + _get_data_path(output_path, filename=f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py new file mode 100644 index 0000000..03e7d70 --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py @@ -0,0 +1,149 @@ +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +import xarray as xr +from omegaconf import DictConfig, OmegaConf + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR, WET_THRESHOLD, LOG_INTERVAL, concat_and_group_diurnal + +def save_plot(hour, means, stds, labels, ylabel, title, out_path): + hrs = np.concatenate([hour.values, [24]]) + plt.figure(figsize=(8,4)) + for mean, std, label in zip(means, stds, labels): + vals = np.append(mean.values, mean.values[0]) + line, = plt.plot(hrs, vals, label=label) + if std is not None: + stdv = np.append(std.values, std.values[0]) + plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computations for diurnal cycle of precipitation amount and wet-hours") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + logger.info(f"Loaded {len(times)} timesteps to process") + + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + + # Location of the output from inference + out_root = Path(cfg.generation.io.output_path or './outputs') + + # Find channel indices + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + land_mask = load_land_sea_mask() + + # Prepare lists to collect DataArrays + target_precip, baseline_precip, pred_precip, mean_pred_precip = [], [], [], [] + target_wet, baseline_wet, pred_wet, mean_pred_wet = [], [], [], [] + + # Collect data + for idx, ts in enumerate(times, 1): + dt = datetimes[idx-1] + target = torch.load(out_root/ts/f"{ts}-target", weights_only=False)[tp_out] * CONV_FACTOR + baseline = torch.load(out_root/ts/f"{ts}-baseline", weights_only=False)[tp_in] * CONV_FACTOR # / 6. 
# 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False)[:, tp_out, :, :] * CONV_FACTOR + try: + mean_pred = torch.load(out_root/ts/f"{ts}-regression-prediction", weights_only=False)[tp_out] * CONV_FACTOR + except: + mean_pred = None + + # DataArrays for spatial means at each timestep + da_target = xr.DataArray(target, dims=("lat","lon"), coords=land_mask.coords) + da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=land_mask.coords) + da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **land_mask.coords}) + if mean_pred is not None: + da_mean_pred = xr.DataArray(mean_pred, dims=("lat","lon"), coords=land_mask.coords) + + # Apply land mask after conversion to xarray + da_target = da_target * land_mask + da_baseline = da_baseline * land_mask + da_preds = da_preds * land_mask + if mean_pred is not None: + da_mean_pred = da_mean_pred * land_mask + + # Spatial mean + target_precip.append(da_target.mean(dim=("lat","lon")).assign_coords(time=dt)) + baseline_precip.append(da_baseline.mean(dim=("lat","lon")).assign_coords(time=dt)) + pred_precip.append(da_preds.mean(dim=("lat","lon")).assign_coords(time=dt)) + if mean_pred is not None: + mean_pred_precip.append(da_mean_pred.mean(dim=("lat","lon")).assign_coords(time=dt)) + + # Wet-hour fraction, i.e., freq(precip) > WET_THRESHOLD + target_wet.append(((da_target / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) + baseline_wet.append(((da_baseline / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) + pred_wet.append(((da_preds / 24> WET_THRESHOLD).mean(dim=("lat","lon")).assign_coords(time=dt))) + if mean_pred is not None: + mean_pred_wet.append(((da_mean_pred / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) + + if idx % LOG_INTERVAL == 0 or idx == len(times): + logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") + + # Compute diurnal means and stds + amount_target_mean, _ = concat_and_group_diurnal(target_precip) + amount_baseline_mean, _ = concat_and_group_diurnal(baseline_precip) + amount_pred_mean, amount_pred_std = concat_and_group_diurnal(pred_precip, is_member=True) + if mean_pred_precip: + amount_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_precip) + + wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to obtain percentages + wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0) + wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0) + if mean_pred_wet: + wet_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wet, scale=100.0) + + # Generate plots + save_plot( + amount_target_mean.hour, + [amount_target_mean, amount_baseline_mean, amount_pred_mean, amount_mean_pred_mean] if mean_pred_precip else [amount_target_mean, amount_baseline_mean, amount_pred_mean], + [None, None, amount_pred_std, None] if mean_pred_precip else [None, None, amount_pred_std], + ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_precip else ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], + 'Precipitation (mm/day)', + 'Diurnal Cycle of Precip Amount', + out_root / 'diurnal_cycle_precip_amount.png' + ) + save_plot( + wet_target_mean.hour, + [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if mean_pred_wet else [wet_target_mean, wet_baseline_mean, wet_pred_mean], + [None, None, wet_pred_std, None] if mean_pred_wet else [None, None, 
wet_pred_std], + ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wet else ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], + 'Wet-Hour Fraction [%]', + 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', + out_root / 'diurnal_cycle_precip_wethours.png' + ) + + logger.info("Plots saved.") + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/diurnal_cycle_precip_p99.py b/src/hirad/eval/diurnal_cycle_precip_p99.py new file mode 100644 index 0000000..9c54eb7 --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_precip_p99.py @@ -0,0 +1,169 @@ +""" +Plots the diurnal cycle of the all-hour 99th percentile of +precipitation, a somewhat reliable measure of the precipitation intensity. + +Each hour, member and type is treaded separately, to conserve memory... but if the +period is long, this can still be a lot of data and thus an OOM error can occur. +""" +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR + + +def save_plot(hours, lines, labels, ylabel, title, out_path): + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.figure(figsize=(8,4)) + for data, label in zip(lines, labels): + if isinstance(data, tuple): # (mean, std) + mean, std = data + lower = np.maximum(np.array(mean) - std, 0) + upper = np.array(mean) + std + line, = plt.plot(hours, mean, label=label) + plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color()) + else: + plt.plot(hours, data, label=label) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig(out_path) + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for diurnal cycle of 99th-percentile of precipitation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root + out_root = Path(cfg.generation.io.output_path or './outputs') + + # Find channel indices + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + land_mask = load_land_sea_mask() + + # Storage for diurnal cycles + pct99_mean = {} + pct99_std = {} + + # -- Process target and baseline -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + + data_list = [] + try: + for ts in times: + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] 
* CONV_FACTOR * land_mask + data_list.append(data) + except: + logger.error(f"Error loading data for mode {mode}. Skipping.") + continue + + da = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Group by hour and compute 99th percentile + hourly_p99 = da.groupby('time.hour').quantile(0.99, dim='time') + + # Apply scaling factor for baseline + # if mode == 'baseline': + # hourly_p99 = hourly_p99 / 6.0 + + pct99_mean[mode] = hourly_p99.mean(dim=['lat', 'lon']) + + # -- Predictions: compute per hour per member, then mean+std across members -- + logger.info("Processing predictions") + + # Load all prediction data at once into xarray + pred_data_list = [] + for ts in times: + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR # [n_members, n_channels, lat, lon] + # Extract precipitation channel and convert to xarray for proper broadcasting + tp_data = preds[:, tp_out] # [n_members, lat, lon] + tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) + pred_data_list.append(tp_da * land_mask) # apply mask + + pred_da = xr.concat(pred_data_list, dim='time') # [n_members, time, lat, lon] + pred_da = pred_da.assign_coords({ + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + }) + # Transpose to get the expected dimension order: [member, time, lat, lon] + pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') + + logger.info('Calculating 99th percentile for predictions') + # Group by hour, compute 99th percentile across time, then spatial mean + hourly_p99_by_member = pred_da.groupby('time.hour').quantile(0.99, dim='time').mean(dim=['lat', 'lon']) + + # Store ensemble statistics as xarray DataArrays + pct99_mean['prediction'] = hourly_p99_by_member.mean(dim='member') + pct99_std['prediction'] = hourly_p99_by_member.std(dim='member') + + # Prepare cyclic lists for plotting + def cycle_fn(x): + return x.values.tolist() + [x.values.tolist()[0]] + + logger.info("Preparing data for plotting") + hrs_c = list(range(24)) + [0 + 24] + pct99_lines = [ + cycle_fn(pct99_mean['target']), + cycle_fn(pct99_mean['baseline']), + ( + cycle_fn(pct99_mean['prediction']), + cycle_fn(pct99_std['prediction']) + ) + ] + if 'regression-prediction' in pct99_mean: + pct99_lines.append(cycle_fn(pct99_mean['regression-prediction'])) + + # Plot combined diurnal 99th-percentile cycle + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff 99th Pct ± Std', 'Regression Prediction'] if 'regression-prediction' in pct99_mean else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff 99th Pct ± Std'] + fn = out_root/'diurnal_cycle_precip_99th_percentile.png' + save_plot( + hrs_c, + pct99_lines, + labels, + 'Precipitation (mm/day)', + 'Diurnal Cycle of 99th-Percentile Precipitation', + fn + ) + logger.info(f"Combined plot saved: {fn}") + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py new file mode 100644 index 0000000..5b4c36f --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -0,0 +1,174 @@ +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +import xarray as xr +from omegaconf import DictConfig, OmegaConf + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from 
hirad.eval.plotting import get_channel_indices, load_land_sea_mask, LOG_INTERVAL, concat_and_group_diurnal + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Initialize + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Load times + logger.info("Starting computation for diurnal cycles of 2m temperature and windspeed") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + logger.info(f"Loaded {len(times)} timesteps to process") + + # Dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + + # Indices for channels + indices = get_channel_indices(dataset) + out_ch = indices['output'] + in_ch = indices['input'] + + # Temperature channel (try '2t' first, fallback to 't2m') + t2m_out = out_ch.get('2t', out_ch.get('t2m')) + t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) + + # Wind channels + u_out = out_ch['10u'] + u_in = in_ch.get('10u', u_out) + v_out = out_ch['10v'] + v_in = in_ch.get('10v', v_out) + + # Output path + out_root = Path(cfg.generation.io.output_path or './outputs') + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) + + # Land-sea mask + land_mask = load_land_sea_mask() + + # Prepare lists to collect DataArrays + target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], [] + target_wind, baseline_wind, pred_wind, mean_pred_wind = [], [], [], [] + + def mean_over_land(data, dims, coords, time_coord): + da = xr.DataArray(data, dims=dims, coords=coords) * land_mask + return da.mean(dim=("lat","lon")).assign_coords(time=time_coord) + + # Loop over timestamps + for idx, ts in enumerate(times, 1): + dt = datetimes[idx-1] + + # Load data + target = load(ts, f"{ts}-target") + baseline = load(ts, f"{ts}-baseline") + predictions = load(ts, f"{ts}-predictions") + try: + regression_pred = load(ts, f"{ts}-regression-prediction") + except: + regression_pred = None + + # Process temperature (convert to Celsius) + target_temp.append(mean_over_land( + target[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) + baseline_temp.append(mean_over_land( + baseline[t2m_in] - 273.15, ("lat","lon"), land_mask.coords, dt)) + pred_temp.append(mean_over_land( + predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"), + {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) + if regression_pred is not None: + mean_pred_temp.append(mean_over_land( + regression_pred[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) + + + # Process wind speed + target_wind.append(mean_over_land( + np.hypot(target[u_out], target[v_out]), ("lat","lon"), land_mask.coords, dt)) + baseline_wind.append(mean_over_land( + np.hypot(baseline[u_in], baseline[v_in]), ("lat","lon"), land_mask.coords, dt)) + pred_wind.append(mean_over_land( + np.hypot(predictions[:, u_out, :, :], predictions[:, v_out, :, :]), + ("member","lat","lon"), {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) + if regression_pred is not None: + mean_pred_wind.append(mean_over_land( + np.hypot(regression_pred[u_out], regression_pred[v_out]), ("lat","lon"), land_mask.coords, dt)) + + if idx % LOG_INTERVAL == 0 or idx == len(times): + logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") + + # Compute diurnal means and stds + 
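+    # concat_and_group_diurnal (from hirad.eval.plotting) is assumed to concatenate the
+    # per-timestep spatial means along time, group them by hour of day, and return the
+    # hourly mean plus, for is_member=True, the spread across ensemble members.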
temp_target_mean, _ = concat_and_group_diurnal(target_temp) + temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp) + temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True) + if mean_pred_temp: + temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp) + + wind_target_mean, _ = concat_and_group_diurnal(target_wind) + wind_baseline_mean, _ = concat_and_group_diurnal(baseline_wind) + wind_pred_mean, wind_pred_std = concat_and_group_diurnal(pred_wind, is_member=True) + if mean_pred_wind: + wind_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wind) + + def save_plot(hour, means, stds, labels, ylabel, title, out_path): + hrs = np.concatenate([hour.values, [24]]) + plt.figure(figsize=(8,4)) + for mean, std, label in zip(means, stds, labels): + vals = np.append(mean.values, mean.values[0]) + line, = plt.plot(hrs, vals, label=label) + if std is not None: + stdv = np.append(std.values, std.values[0]) + plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + + data = [temp_target_mean, temp_baseline_mean, temp_pred_mean, temp_mean_pred_mean] if mean_pred_temp else [temp_target_mean, temp_baseline_mean, temp_pred_mean] + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_temp else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'] + stds = [None, None, temp_pred_std, None] if mean_pred_temp else [None, None, temp_pred_std] + + # Generate plots + save_plot( + temp_target_mean.hour, + data, + stds, + labels, + '2m Temperature [°C]', + 'Diurnal Cycle of 2m Temperature', + out_root / 'diurnal_cycle_2t.png' + ) + + data = [wind_target_mean, wind_baseline_mean, wind_pred_mean, wind_mean_pred_mean] if mean_pred_wind else [wind_target_mean, wind_baseline_mean, wind_pred_mean] + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wind else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'] + stds = [None, None, wind_pred_std, None] if mean_pred_wind else [None, None, wind_pred_std] + + save_plot( + wind_target_mean.hour, + data, + stds, + labels, + 'Windspeed [m/s]', + 'Diurnal Cycle of Windspeed', + out_root / 'diurnal_cycle_windspeed.png' + ) + + logger.info("Plots saved.") + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py new file mode 100644 index 0000000..78d7395 --- /dev/null +++ b/src/hirad/eval/hist.py @@ -0,0 +1,254 @@ +""" +Plots the domain-mean precipitation distribution over land. + +This script computes and visualizes the distribution of precipitation values +over land. 
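+Histogram bin counts are accumulated over all timesteps and normalized to a probability
+density at the end; the 99th/99.9th/99.99th all-hour percentiles of each dataset are
+overlaid as vertical markers on the plot.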
+""" +import logging +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR_HOURLY, LOG_INTERVAL + + +def save_distribution_plot(hist_data_dict, bin_edges, labels, colors, title, ylabel, out_path, percentiles_data=None): + """Save distribution plot with pre-computed histograms.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + plt.figure(figsize=(10, 6)) + + # Plot histograms from pre-computed bin counts + for (key, hist_data), label, color in zip(hist_data_dict.items(), labels, colors): + if isinstance(hist_data, tuple): # Handle ensemble data + # Plot individual members with transparency + for i, member_hist in enumerate(hist_data): + alpha = 0.5 if i > 0 else 0.7 + label_member = label if i == 0 else None + # Plot histogram from bin counts + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + plt.plot(bin_centers, member_hist, alpha=alpha, color=color, + label=label_member, drawstyle='steps-mid') + else: + # Plot histogram from bin counts + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + plt.plot(bin_centers, hist_data, alpha=0.7, color=color, + label=label, linewidth=2, drawstyle='steps-mid') + + plt.xscale('log') + plt.yscale('log') + plt.xlabel(ylabel) + plt.ylabel('Probability Density') + plt.ylim(1e-8, 1) + plt.xlim(bin_edges[1], bin_edges[-1]) + plt.title(title) + plt.grid(True, alpha=0.3) + + # Add percentile lines if provided + if percentiles_data: + # Calculate y-range for percentile lines (lowest 10% of log scale) + y_bottom, y_top = plt.ylim() + log_bottom, log_top = np.log10(y_bottom), np.log10(y_top) + vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom)) + vline_ymin = y_bottom + + # Define line styles for percentiles + percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} + percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} + colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'} + legend_added = set() + + # Plot all percentile lines + for dataset_name, data in percentiles_data.items(): + color = colors[dataset_name] + + if dataset_name in ['target', 'baseline', 'regression-prediction']: + # Single dataset + for percentile, value in data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.8) # No label here + else: + # Ensemble members + for member_data in data.values(): + for percentile, value in member_data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.6) # No label here + + # Add black legend entries for percentiles (override the colored ones) + for percentile in [99, 99.9, 99.99]: + if percentile in legend_added: + plt.plot([], [], color='black', linestyle=percentile_styles[percentile], + label=percentile_labels[percentile]) + + plt.legend() + plt.tight_layout() + 
plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for domain-mean precipitation distribution over land") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root + out_root = Path(cfg.generation.io.output_path or './outputs') + + # Find channel indices + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + land_mask = load_land_sea_mask() + + # Define histogram bins + bins = np.logspace(-1, 3.3, 200) # Log-spaced bins for precipitation + + # Storage for histogram data and land values + hist_data = {} + all_land_values = {} + + # -- Process target and baseline -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + + hist_counts = np.zeros(len(bins) - 1) + total_samples = 0 + all_values = [] + + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target', 'regression-prediction'] else tp_in] * CONV_FACTOR_HOURLY * land_mask + + # Apply scaling factor for baseline + # if mode == 'baseline': + # data = data / 6.0 + + land_values = data.values[~np.isnan(data.values)] + all_values.extend(land_values) + + counts, _ = np.histogram(land_values, bins=bins) + hist_counts += counts + total_samples += len(land_values) + except: + logger.warning(f"{mode} not available, skipping") + continue + # Normalize to probability density + bin_widths = np.diff(bins) + hist_data[mode] = hist_counts / (total_samples * bin_widths) + all_land_values[mode] = np.array(all_values) + logger.info(f"Processed {total_samples} land values for {mode}") + + # -- Process predictions: compute histogram for each ensemble member -- + logger.info("Processing predictions") + + n_members = None + member_hist_data = [] + all_member_values = [] + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR_HOURLY # [n_members, n_channels, lat, lon] + + if n_members is None: + n_members = preds.shape[0] + member_hist_data = [np.zeros(len(bins) - 1) for _ in range(n_members)] + member_sample_counts = [0 for _ in range(n_members)] + all_member_values = [[] for _ in range(n_members)] + + for member_idx in range(n_members): + data = preds[member_idx, tp_out] * land_mask + land_values = data.values[~np.isnan(data.values)] + all_member_values[member_idx].extend(land_values) + + counts, _ = np.histogram(land_values, bins=bins) + member_hist_data[member_idx] += counts + member_sample_counts[member_idx] += len(land_values) + + # Normalize member histograms to probability density + bin_widths = np.diff(bins) + 
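+    # density[i] = counts[i] / (N_samples * bin_width[i]), so that sum(density * bin_width) == 1;
+    # this matches np.histogram(..., density=True) while letting counts be accumulated per timestep.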
normalized_member_hists = [] + for member_idx in range(n_members): + normalized_hist = member_hist_data[member_idx] / (member_sample_counts[member_idx] * bin_widths) + normalized_member_hists.append(normalized_hist) + + hist_data['predictions'] = tuple(normalized_member_hists) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # Compute percentiles for all datasets + percentiles_data = {} + percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} + + # Target and baseline percentiles + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in all_land_values: + data_array = xr.DataArray(all_land_values[mode]) + percentiles_data[mode] = { + key: data_array.quantile(p).item() + for key, p in percentiles.items() + } + + # Ensemble member percentiles + percentiles_data['predictions'] = {} + for member_idx in range(n_members): + member_data_array = xr.DataArray(all_member_values[member_idx]) + percentiles_data['predictions'][f'member_{member_idx}'] = { + key: member_data_array.quantile(p).item() + for key, p in percentiles.items() + } + + # Create distribution plots + labels = ['COSMO-2 Analysis', 'ERA5', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in hist_data else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in hist_data else ['blue', 'orange', 'green'] + + fn = out_root / 'precipitation_distribution_over_land.png' + save_distribution_plot( + hist_data, + bins, + labels, + colors, + 'Domain-Mean Precip. Over Land (Pooled Data)', + 'Precipitation (mm/h)', + fn, + percentiles_data + ) + logger.info(f"Distribution plot saved: {fn}") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py new file mode 100644 index 0000000..257cfcd --- /dev/null +++ b/src/hirad/eval/map_precip_stats.py @@ -0,0 +1,202 @@ +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import ( + plot_map_precipitation, plot_map, get_channel_indices, + CONV_FACTOR, LOG_INTERVAL, WET_THRESHOLD +) + + +def consecutive_spell(condition): + """Return longest consecutive spell where condition is True (per gridpoint).""" + def _spell_length(x): + x = np.asarray(x, dtype=bool) + if len(x) == 0: + return 0 + runs = np.diff(np.concatenate(([False], x, [False])).astype(int)) + starts = np.where(runs == 1)[0] + ends = np.where(runs == -1)[0] + return int(np.max(ends - starts)) if len(starts) > 0 else 0 + return xr.apply_ufunc(_spell_length, condition, input_core_dims=[['time']], vectorize=True) + + +def apply_statistic(data, stat_type, stat_param): + """Apply a statistic to the data along the time dimension.""" + if stat_type == 'mean': + return data.mean(dim='time') + if stat_type == 'quantile': + return data.quantile(stat_param, dim='time') + if stat_type == 'Rx1hr': + return data.max(dim='time') + if stat_type == 'Rx1day': + daily = data.resample(time="1D").sum("time") + return daily.max(dim='time') + if stat_type == 'Rx5day': + daily = data.resample(time="1D").sum("time") + return daily.rolling(time=5, center=False).sum().max(dim='time') + if stat_type == 'cdd': + 
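+        # CDD: longest run of consecutive days with daily precipitation below 1 mm;
+        # the 'cwd' branch below uses the complementary condition (>= 1 mm).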
daily = data.resample(time="1D").sum("time") + return consecutive_spell(daily < 1.0) + if stat_type == 'cwd': + daily = data.resample(time="1D").sum("time") + return consecutive_spell(daily >= 1.0) + if stat_type == 'weth_freq': + return (data / 24 > WET_THRESHOLD).mean(dim='time') * 100 + raise ValueError(f"Unsupported statistic type: {stat_type}") + + +def plot_stat_map(data, filename, stat_config, label): + """Plot a single statistic map with appropriate styling.""" + if stat_config['type'] == 'weth_freq': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]} (%)', + label='Wet-Hour Frequency [%]', vmin=0, vmax=30, cmap='PuBu', extend='max' + ) + elif stat_config['type'] == 'cdd': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Days', vmin=0, vmax=60, cmap='viridis', extend='max' + ) + elif stat_config['type'] == 'cwd': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Days', vmin=0, vmax=20, cmap='viridis', extend='max' + ) + else: + plot_map_precipitation( + data, filename, + title=f'{label}: {stat_config["title_stat"]} Precipitation', + threshold=stat_config['threshold'], rfac=1.0 + ) + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup and config + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting precipitation statistics generation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Processing {len(times)} timesteps") + + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + out_root = Path(cfg.generation.io.output_path or './outputs') + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + + # Statistic configuration + STATISTICS_CONFIG = { + 'mean': {'type': 'mean', 'threshold': 0.01, 'title': 'Mean'}, + 'p99': {'type': 'quantile', 'param': 0.99, 'threshold': 0.1, 'title': '99th Percentile'}, + 'p99.9': {'type': 'quantile', 'param': 0.999, 'threshold': 0.1, 'title': '99.9th Percentile'}, + 'p99.99': {'type': 'quantile', 'param': 0.9999, 'threshold': 0.1, 'title': '99.99th Percentile'}, + 'Rx1hr': {'type': 'Rx1hr', 'threshold': 0.1, 'title': 'Maximum (Rx1hr)'}, + 'Rx1day': {'type': 'Rx1day', 'threshold': 0.1, 'title': 'Maximum 1-day Amount (Rx1day)'}, + 'Rx5day': {'type': 'Rx5day', 'threshold': 0.1, 'title': 'Maximum 5-day Total (Rx5day)'}, + 'cdd': {'type': 'cdd', 'threshold': 0.1, 'title': 'Consecutive Dry Days (CDD)'}, + 'cwd': {'type': 'cwd', 'threshold': 0.1, 'title': 'Consecutive Wet Days (CWD)'}, + 'weth_freq': {'type': 'weth_freq', 'threshold': 0.01, 'title': 'Wet-Hour Frequency'} + } + stat_configs = [ + { + 'stat_name': name, + 'title_stat': config['title'], + 'param': config.get('param'), + **config + } + for name, config in STATISTICS_CONFIG.items() + ] + + # Target and baseline modes + basic_modes = { + 'target': (tp_out, 'COSMO-2 Analysis'), + 'baseline': (tp_in, 'ERA5'), + 'regression-prediction': (tp_out, 'Regression Prediction') + } + logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} basic modes + predictions") + + for mode, (tp_channel, label) in basic_modes.items(): + logger.info(f"Processing mode: {mode}") + # Load all timesteps for this mode + data_list = [] + 
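+        # Modes whose files are missing (e.g. no regression-prediction output) are
+        # skipped via the try/except below instead of aborting the whole run.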
try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) * CONV_FACTOR + data_list.append(data[tp_channel]) + except: + logger.warning(f"{mode} not available, skipping") + continue + mode_data = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + # if mode == 'baseline': + # mode_data = mode_data / 6.0 + # Compute and plot all statistics for this mode + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for {mode}...") + result = apply_statistic(mode_data, stat_config['type'], stat_config['param']) + map_output_dir = out_root / f"maps_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_stat_map(result.values, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label) + + # Predictions mode: process each member separately to save memory + logger.info("Processing predictions mode...") + data = torch.load(out_root/times[0]/f"{times[0]}-predictions", weights_only=False) + n_members = data.shape[0] + logger.info(f"Found {n_members} ensemble members") + + for member_idx in range(n_members): + logger.info(f"Processing prediction member {member_idx+1}/{n_members}") + # Load all timesteps for this member + data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading prediction member {member_idx} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR + data_list.append(pred_data[member_idx, tp_out]) + member_data = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Compute and plot all statistics for this member + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") + member_result = apply_statistic(member_data, stat_config['type'], stat_config['param']) + + # Create map + map_output_dir = out_root / f"maps_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + member_label = f'CorrDiff Member {member_idx+1}' + plot_stat_map(member_result.values, member_filename, stat_config, member_label) + + logger.info("All precipitation statistics maps generated successfully") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/map_wind_stats.py b/src/hirad/eval/map_wind_stats.py new file mode 100644 index 0000000..ac63224 --- /dev/null +++ b/src/hirad/eval/map_wind_stats.py @@ -0,0 +1,433 @@ +import logging +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import plot_map, get_channel_indices, LOG_INTERVAL + + +def compute_wind_speed(u, v): + """Compute wind speed from U and V.""" + return np.hypot(u, v) + + +def compute_wind_direction(u, v, calm_threshold=0.0): + """Compute wind direction in degrees from N.""" + dir_deg = (np.degrees(np.arctan2(-u, -v)) % 360) + if calm_threshold > 
0: + speed = np.hypot(u, v) + dir_deg = np.where(speed <= calm_threshold, np.nan, dir_deg) + return dir_deg + + +def apply_wind_statistic_streaming(times, out_root, mode, u_channel, v_channel, stat_type, stat_param=None): + """Compute wind statistic by streaming through timesteps.""" + accumulator = None + count = 0 + sin_acc = cos_acc = speed_acc = None + + for ts in times: + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) + u = data[u_channel].cpu().numpy() if torch.is_tensor(data[u_channel]) else data[u_channel] + v = data[v_channel].cpu().numpy() if torch.is_tensor(data[v_channel]) else data[v_channel] + + if stat_type == 'mean_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed + elif stat_type == 'max_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.full_like(speed, -np.inf) + accumulator = np.maximum(accumulator, speed) + elif stat_type == 'wind_power': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed**3 + elif stat_type == 'mean_u': + if accumulator is None: + accumulator = np.zeros_like(u) + accumulator += u + elif stat_type == 'mean_v': + if accumulator is None: + accumulator = np.zeros_like(v) + accumulator += v + elif stat_type in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', + 'strong_breeze_freq', 'gale_freq']: + speed = compute_wind_speed(u, v) + thresholds = { + 'calm_freq': 2.0, + 'light_breeze_freq': 1.6, + 'moderate_breeze_freq': 5.5, + 'strong_breeze_freq': 10.8, + 'gale_freq': 17.2 + } + threshold = thresholds[stat_type] + if accumulator is None: + accumulator = np.zeros_like(speed) + if stat_type == 'calm_freq': + accumulator += (speed < threshold).astype(float) + else: + accumulator += (speed > threshold).astype(float) + elif stat_type == 'prevailing_direction': + speed = compute_wind_speed(u, v) + direction = compute_wind_direction(u, v, calm_threshold=1.0) + rad = np.deg2rad(direction) + weighted_sin = np.sin(rad) * speed + weighted_cos = np.cos(rad) * speed + + if sin_acc is None: + sin_acc = np.zeros_like(weighted_sin) + cos_acc = np.zeros_like(weighted_cos) + speed_acc = np.zeros_like(speed) + + sin_acc += np.nan_to_num(weighted_sin, 0) + cos_acc += np.nan_to_num(weighted_cos, 0) + speed_acc += speed + elif stat_type == 'direction_variability': + direction = compute_wind_direction(u, v) + rad = np.deg2rad(direction) + + if sin_acc is None: + sin_acc = np.zeros_like(np.sin(rad)) + cos_acc = np.zeros_like(np.cos(rad)) + + sin_acc += np.sin(rad) + cos_acc += np.cos(rad) + + count += 1 + del data, u, v + + if stat_type == 'prevailing_direction': + mean_dir = np.arctan2(sin_acc / (speed_acc + 1e-10), cos_acc / (speed_acc + 1e-10)) + return np.mod(np.rad2deg(mean_dir), 360) + elif stat_type == 'direction_variability': + R = np.clip(np.hypot(sin_acc / count, cos_acc / count), 1e-10, 1.0) + return np.rad2deg(np.sqrt(-2 * np.log(R))) + elif stat_type == 'max_speed': + return accumulator + elif stat_type in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', + 'strong_breeze_freq', 'gale_freq']: + return (accumulator / count) * 100 + else: + return accumulator / count + + +def plot_wind_stat_map(data, filename, stat_config, label): + """Plot wind statistic map.""" + if stat_config['type'] == 'mean_speed': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Speed [m/s]', vmin=0, vmax=10, cmap='inferno', 
extend='max' + ) + elif stat_config['type'] == 'max_speed': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Speed [m/s]', vmin=0, vmax=30, cmap='inferno', extend='max' + ) + elif stat_config['type'] == 'wind_power': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Power Density [m³/s³]', vmin=0, vmax=1000, cmap='plasma', extend='max' + ) + elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', 'strong_breeze_freq', 'gale_freq']: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Frequency [%]', vmin=0, vmax=80, cmap='GnBu', extend='max' + ) + elif stat_config['type'] == 'prevailing_direction': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Direction [degrees from N]', vmin=0, vmax=360, cmap='twilight', extend='neither' + ) + elif stat_config['type'] == 'direction_variability': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Circular Std Dev [degrees]', vmin=20, vmax=140, cmap='viridis', extend='max' + ) + elif stat_config['type'] in ['mean_u', 'mean_v']: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Component [m/s]', vmin=-5, vmax=5, cmap='RdBu_r', extend='both' + ) + else: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Value', vmin=None, vmax=None, cmap='viridis', extend='neither' + ) + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting wind statistics generation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Processing {len(times)} timesteps") + + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + out_root = Path(cfg.generation.io.output_path or './outputs') + indices = get_channel_indices(dataset) + + u10_out = indices['output'].get('10u') + v10_out = indices['output'].get('10v') + u10_in = indices['input'].get('10u', u10_out) + v10_in = indices['input'].get('10v', v10_out) + + + WIND_STATISTICS_CONFIG = { + 'mean_speed': { + 'type': 'mean_speed', + 'title': 'Mean Wind Speed' + }, + 'max_speed': { + 'type': 'max_speed', + 'title': 'Maximum Wind Speed' + }, + 'wind_power': { + 'type': 'wind_power', + 'title': 'Mean Wind Power Density' + }, + 'calm_freq': { + 'type': 'calm_freq', + 'title': 'Calm Frequency (<2 m/s, Beaufort 0-1)' + }, + 'light_breeze_freq': { + 'type': 'light_breeze_freq', + 'title': 'Light Breeze Frequency (>1.6 m/s, Beaufort 2+)' + }, + 'moderate_breeze_freq': { + 'type': 'moderate_breeze_freq', + 'title': 'Moderate Breeze Frequency (>5.5 m/s, Beaufort 4+)' + }, + 'strong_breeze_freq': { + 'type': 'strong_breeze_freq', + 'title': 'Strong Breeze Frequency (>10.8 m/s, Beaufort 6+)' + }, + 'gale_freq': { + 'type': 'gale_freq', + 'title': 'Gale Frequency (>17.2 m/s, Beaufort 8+)' + }, + 'prevailing_dir': { + 'type': 'prevailing_direction', + 'title': 'Prevailing Wind Direction' + }, + 'dir_variability': { + 'type': 'direction_variability', + 'title': 'Wind Direction Variability' + }, + 'mean_u': { + 'type': 'mean_u', + 'title': 'Mean U-Component' + }, + 'mean_v': { + 'type': 'mean_v', + 'title': 
'Mean V-Component' + } + } + + stat_configs = [ + { + 'stat_name': name, + 'title_stat': config['title'], + 'param': config.get('param'), + **config + } + for name, config in WIND_STATISTICS_CONFIG.items() + ] + + basic_modes = { + 'target': ((u10_out, v10_out), 'COSMO-2 Analysis'), + 'baseline': ((u10_in, v10_in), 'ERA5'), + 'regression-prediction': ((u10_out, v10_out), 'Regression Prediction') + } + logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} modes + predictions") + + for mode, (wind_channels, label) in basic_modes.items(): + logger.info(f"Processing mode: {mode}") + u_channel, v_channel = wind_channels + + try: + test_data = torch.load(out_root/times[0]/f"{times[0]}-{mode}", weights_only=False) + del test_data + except Exception as e: + logger.warning(f"{mode} not available: {e}") + continue + + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for {mode}...") + try: + result = apply_wind_statistic_streaming( + times, out_root, mode, u_channel, v_channel, + stat_config['type'], stat_config.get('param') + ) + + map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_wind_stat_map( + result, + str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), + stat_config, + label + ) + del result + except Exception as e: + logger.error(f"Failed {stat_config['title_stat']} for {mode}: {e}") + continue + + logger.info("Processing predictions mode...") + try: + data = torch.load(out_root/times[0]/f"{times[0]}-predictions", weights_only=False) + n_members = data.shape[0] + del data + logger.info(f"Found {n_members} ensemble members") + + for member_idx in range(n_members): + logger.info(f"Processing member {member_idx+1}/{n_members}") + + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") + try: + def load_member_data(ts): + pred_data = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) + u_data = pred_data[member_idx, u10_out] + v_data = pred_data[member_idx, v10_out] + u = u_data.cpu().numpy() if torch.is_tensor(u_data) else u_data + v = v_data.cpu().numpy() if torch.is_tensor(v_data) else v_data + del pred_data + return u, v + + accumulator = None + count = 0 + sin_acc = cos_acc = speed_acc = None + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading prediction member {member_idx} timestep {i+1}/{len(times)}: {ts}") + + u, v = load_member_data(ts) + + if stat_config['type'] == 'mean_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed + elif stat_config['type'] == 'max_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.full_like(speed, -np.inf) + accumulator = np.maximum(accumulator, speed) + elif stat_config['type'] == 'wind_power': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed**3 + elif stat_config['type'] == 'mean_u': + if accumulator is None: + accumulator = np.zeros_like(u) + accumulator += u + elif stat_config['type'] == 'mean_v': + if accumulator is None: + accumulator = np.zeros_like(v) + accumulator += v + elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', + 'moderate_breeze_freq', 'strong_breeze_freq', + 'gale_freq']: + speed = compute_wind_speed(u, v) + thresholds = { + 'calm_freq': 2.0, + 'light_breeze_freq': 1.6, + 
'moderate_breeze_freq': 5.5, + 'strong_breeze_freq': 10.8, + 'gale_freq': 17.2 + } + threshold = thresholds[stat_config['type']] + if accumulator is None: + accumulator = np.zeros_like(speed) + if stat_config['type'] == 'calm_freq': + accumulator += (speed < threshold).astype(float) + else: + accumulator += (speed > threshold).astype(float) + elif stat_config['type'] == 'prevailing_direction': + speed = compute_wind_speed(u, v) + direction = compute_wind_direction(u, v, calm_threshold=1.0) + rad = np.deg2rad(direction) + weighted_sin = np.sin(rad) * speed + weighted_cos = np.cos(rad) * speed + + if sin_acc is None: + sin_acc = np.zeros_like(weighted_sin) + cos_acc = np.zeros_like(weighted_cos) + speed_acc = np.zeros_like(speed) + + sin_acc += np.nan_to_num(weighted_sin, 0) + cos_acc += np.nan_to_num(weighted_cos, 0) + speed_acc += speed + elif stat_config['type'] == 'direction_variability': + direction = compute_wind_direction(u, v) + rad = np.deg2rad(direction) + + if sin_acc is None: + sin_acc = np.zeros_like(np.sin(rad)) + cos_acc = np.zeros_like(np.cos(rad)) + + sin_acc += np.sin(rad) + cos_acc += np.cos(rad) + + count += 1 + del u, v + + if stat_config['type'] == 'prevailing_direction': + mean_dir = np.arctan2(sin_acc / (speed_acc + 1e-10), cos_acc / (speed_acc + 1e-10)) + member_result = np.mod(np.rad2deg(mean_dir), 360) + elif stat_config['type'] == 'direction_variability': + R = np.clip(np.hypot(sin_acc / count, cos_acc / count), 1e-10, 1.0) + member_result = np.rad2deg(np.sqrt(-2 * np.log(R))) + elif stat_config['type'] == 'max_speed': + member_result = accumulator + elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', + 'moderate_breeze_freq', 'strong_breeze_freq', + 'gale_freq']: + member_result = (accumulator / count) * 100 + else: + member_result = accumulator / count + + map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + plot_wind_stat_map(member_result, member_filename, stat_config, f'CorrDiff Member {member_idx+1}') + del member_result + + except Exception as e: + logger.error(f"Failed {stat_config['title_stat']} for member {member_idx+1}: {e}") + continue + + except Exception as e: + logger.warning(f"Predictions not available: {e}") + + logger.info("Wind statistics generation complete") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index b73caf0..af9d5a9 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -3,9 +3,10 @@ import numpy as np import torch import xskillscore -import scoringrules as sr from scipy.signal import periodogram +import xskillscore +import xarray as xr # set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) @@ -58,6 +59,41 @@ def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default return freqs, power_spectra -def crps(): - # Time, variable, ensemble, x, y - xskillscore.crps_ensemble() \ No newline at end of file +def crps(prediction_ensemble, target, average_over_area=True, average_over_channels=True, average_over_time=True): + # Assumes that prediction_ensemble is in form: + # (member, channel, x, y) or + # (time, member, channel, x, y) + # Returns: a k-dimensional array of continuous ranked probability scores, + # where k is the number of dimensions that were not averaged over. 
+ # For example, if average_over_area is False (and all others true), will + # return an ndarray of shape (X,Y) + target_coords = [('channel', np.arange(target.shape[-3])), + ('x', np.arange(target.shape[-2])), + ('y', np.arange(target.shape[-1]))] + + + forecasts_coords = [('member', np.arange(prediction_ensemble.shape[-4])), + ('channel', np.arange(prediction_ensemble.shape[-3])), + ('x', np.arange(prediction_ensemble.shape[-2])), + ('y', np.arange(prediction_ensemble.shape[-1]))] + + if prediction_ensemble.ndim > 4 and target.ndim > 3: + forecasts_coords.insert(0, ('time', np.arange(prediction_ensemble.shape[-5]))) + target_coords.insert(0, ('time', np.arange(target.shape[-4]))) + + + + forecasts = xr.DataArray(prediction_ensemble, coords = forecasts_coords) + observations = xr.DataArray(target, coords = target_coords) + + dim = [] + if prediction_ensemble.ndim > 4 and average_over_time: + dim.append('time') + if average_over_area: + dim.append('x') + dim.append('y') + if average_over_channels: + dim.append('channel') + crps = xskillscore.crps_ensemble(observations=observations, forecasts=forecasts, dim=dim) + crps = crps.to_numpy() + return crps diff --git a/src/hirad/eval/plot_maps.py b/src/hirad/eval/plot_maps.py new file mode 100644 index 0000000..a2230a6 --- /dev/null +++ b/src/hirad/eval/plot_maps.py @@ -0,0 +1,163 @@ +import torch +import yaml +import numpy as np +import os +from pathlib import Path +import argparse +from hirad.eval import plotting +from hirad.utils.inference_utils import calculate_bounds, transform_channel +import hydra +from omegaconf import DictConfig, OmegaConf + +channel_plot_args = { + "2t": {"label": "°C"}, + "10u": {"label": "m/s"}, + "10v": {"label": "m/s"}, + "tp": {"label": "boxcox(mm/h)"}, +} + + +def get_available_time_steps(results_dir): + """Find all available time step directories""" + results_path = Path(results_dir) + time_step_dirs = [d.name for d in results_path.iterdir() if d.is_dir() and not d.name.startswith('plots') and not d.name.startswith('maps')] + return sorted(time_step_dirs) + +def plot_time_step(results_dir, output_dir, time_step, input_channels, output_channels, cfg): + """Plot all channels for a single time step""" + print(f"Processing time step: {time_step}") + + # Set up paths for this time step + ts_results_dir = Path(results_dir) / time_step + + # Load tensors + try: + target = torch.load(ts_results_dir / f"{time_step}-target", weights_only=False) + baseline = torch.load(ts_results_dir / f"{time_step}-baseline", weights_only=False) + predictions = torch.load(ts_results_dir / f"{time_step}-predictions", weights_only=False) + + try: + mean_pred = torch.load(ts_results_dir / f"{time_step}-regression-prediction", weights_only=False) + except FileNotFoundError: + mean_pred = None + print(f" Warning: No mean prediction found for {time_step}") + + except FileNotFoundError as e: + print(f" Error: Missing required file for {time_step}: {e}") + return + + # Create output directory for this time step + ts_output_dir = Path(output_dir) / time_step + ts_output_dir.mkdir(parents=True, exist_ok=True) + + # Plot each channel + for idx, channel in enumerate(output_channels): + print(f" Plotting channel: {channel}") + + # Get input channel index (handle case where input/output channels differ) + try: + input_idx = input_channels.index(channel) + except ValueError: + print(f" Warning: Channel {channel} not in input channels, skipping baseline") + continue + + # Transform data + if channel != "tp" or not cfg.get("plot_box_precipitation", False): 
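A short usage sketch for the `crps()` helper added in `metrics.py` above. The import path and the random array shapes are assumptions for illustration, and the call requires the package plus `xskillscore` and `xarray` to be installed:

```python
import numpy as np
from hirad.eval.metrics import crps  # assumed import path for the helper above

ensemble = np.random.rand(8, 4, 32, 32)  # (member, channel, x, y)
target = np.random.rand(4, 32, 32)       # (channel, x, y)

fully_averaged = crps(ensemble, target)                             # 0-d array
per_pixel = crps(ensemble, target, average_over_area=False)         # shape (32, 32)
per_channel = crps(ensemble, target, average_over_channels=False)   # shape (4,)
print(fully_averaged, per_pixel.shape, per_channel.shape)
```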
+ tgt = transform_channel(target[idx], channel) + base = transform_channel(baseline[input_idx], channel) + preds = transform_channel(predictions[:, idx], channel) + mean = transform_channel(mean_pred[idx], channel) if mean_pred is not None else None + if channel == "tp": + threshold = transform_channel(np.array([cfg.get("tp_threshold", 0.002)]), "tp")[0] # Transform threshold too + tgt = np.ma.masked_where(tgt <= threshold, tgt) + base = np.ma.masked_where(base <= threshold, base) + preds = np.ma.masked_where(preds <= threshold, preds) + if mean is not None: + mean = np.ma.masked_where(mean <= threshold, mean) + else: + # For precipitation, use raw values if plotting box precipitation + tgt = target[idx] + base = baseline[input_idx] + preds = predictions[:, idx] + mean = mean_pred[idx] if mean_pred is not None else None + + # Calculate consistent bounds (skip for precipitation) + if channel != "tp" or not cfg.get("plot_box_precipitation", False): + arrays = [tgt, base] + [preds[i] for i in range(preds.shape[0])] + if mean is not None: + arrays.append(mean) + vmin, vmax = calculate_bounds(*arrays) + else: + vmin, vmax = None, None + + base_channel_dir = ts_output_dir / channel + base_channel_dir.mkdir(parents=True, exist_ok=True) + + # Plot target + fname = ts_output_dir / channel / f"target" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation(tgt, str(fname), title=f"Target - {channel}") + else: + plotting.plot_map(tgt, str(fname), vmin=vmin, vmax=vmax, + title=f"Target - {channel}", **channel_plot_args.get(channel, {})) + + # Plot baseline + fname = ts_output_dir / channel / "baseline" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation(base, str(fname), title=f"Baseline - {channel}") + else: + plotting.plot_map(base, str(fname), vmin=vmin, vmax=vmax, + title=f"Baseline - {channel}", **channel_plot_args.get(channel, {})) + + # Plot mean prediction if available + if mean is not None: + fname = ts_output_dir / channel / "mean-prediction" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation(mean, str(fname), title=f"Mean Prediction - {channel}") + else: + plotting.plot_map(mean, str(fname), vmin=vmin, vmax=vmax, + title=f"Mean Prediction - {channel}", **channel_plot_args.get(channel, {})) + + # Plot ensemble members + for member_idx in range(preds.shape[0]): + fname = ts_output_dir / channel / f"prediction_{member_idx:02d}" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation( + preds[member_idx], str(fname), + title=f"Prediction {member_idx} - {channel}" + ) + else: + plotting.plot_map( + preds[member_idx], str(fname), + vmin=vmin, vmax=vmax, + title=f"Prediction {member_idx} - {channel}", + **channel_plot_args.get(channel, {}) + ) + +@hydra.main(version_base=None, config_path="../conf", config_name="plotting") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + + input_channels = cfg.dataset.input_channel_names + output_channels = cfg.dataset.output_channel_names + + # Set up directories + results_dir = Path(cfg.results_dir) + output_dir = Path(cfg.output_dir) if "output_dir" in cfg and cfg.output_dir else results_dir / "plots" + output_dir.mkdir(exist_ok=True) + + # Determine time steps to process + if cfg.time_steps: + time_steps = cfg.time_steps + else: + time_steps = get_available_time_steps(results_dir) + print(f"Found {len(time_steps)} time steps: {time_steps}") + + # Process each 
time step + for time_step in time_steps: + plot_time_step(results_dir, output_dir, time_step, input_channels, output_channels, cfg) + + print(f"Plotting complete. Results saved to: {output_dir}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 1ca11c2..e83a625 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -1,22 +1,198 @@ import logging import cartopy.crs as ccrs +import cartopy.feature as cfeature import matplotlib.pyplot as plt import numpy as np +import torch +import xarray as xr +from matplotlib.colors import BoundaryNorm, ListedColormap +from pathlib import Path +from datetime import datetime -def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str): + +# COSMO‑2 GRID: TODO: Add to dataset config +LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) +LON = np.arange(-6.82, 4.80 + 0.02, 0.02) +RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) + +# Constants for data processing +CONV_FACTOR_HOURLY = 1000 # Convert precip of ERA5 from meters to mm/h +CONV_FACTOR = CONV_FACTOR_HOURLY * 24 # Convert precip of ERA5 from from meters to mm/day +WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h +LOG_INTERVAL = 24 # Log progress every N timesteps + +LAND_SEA_MASK_PATH = '/capstor/store/mch/msopr/hirad-gen/eval/lsm.npy' + +def get_channel_indices(dataset, channels=None): + """ + Get channel indices for input and output channels from dataset. + + Args: + dataset: Dataset object with input_channels() and output_channels() methods + channels: Optional list of channel names to look up. If None, returns all channel mappings. + + Returns: + dict: Dictionary with 'input' and 'output' keys, each containing channel name -> index mapping + + Example: + indices = get_channel_indices(dataset, ['tp', '2t', '10u', '10v']) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) # Fallback to output index if not in input + """ + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + + if channels is None: + return {'input': in_ch, 'output': out_ch} + + # Filter to requested channels only + filtered_out = {ch: out_ch[ch] for ch in channels if ch in out_ch} + filtered_in = {ch: in_ch[ch] for ch in channels if ch in in_ch} + + return {'input': filtered_in, 'output': filtered_out} + +def load_land_sea_mask(path=LAND_SEA_MASK_PATH): + """Load and retrun a land-sea mask as xarray DataArray.""" + lsm_data = np.load(path).reshape(352, 544) + return xr.DataArray( + np.where(lsm_data >= 0.5, 1.0, np.nan), + dims=['lat', 'lon'], + coords={"lat": np.arange(352), "lon": np.arange(544)} + ) + +def concat_and_group_diurnal(list_of_da, is_member=False, scale=1.0): + """Helper to concatenate DataArrays and compute diurnal statistics.""" + da = xr.concat(list_of_da, dim="time").groupby("time.hour") + if is_member: + timmean = da.mean(dim='time') * scale + mean = timmean.mean(dim='member') + std = da.std(dim='member').mean(dim='time') * scale + else: + mean = da.mean(dim='time') * scale + std = None + return mean, std + + +def plot_map(values: np.array, + filename: str, + label='', + title='', + vmin=None, + vmax=None, + cmap=None, + extend='neither', + norm=None, + ticks=None): + """Plot observed or interpolated data in a scatter plot.""" + logging.info(f'Creating map: {filename}') + + latitudes = LAT[RELAX_ZONE : RELAX_ZONE + 352] + 
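`concat_and_group_diurnal()` above relies on xarray's `groupby("time.hour")` to collapse a long hourly series into a 24-value diurnal cycle. A small illustration with synthetic data (the grid size and date range are arbitrary):

```python
import numpy as np
import xarray as xr

# Ten days of hourly values on a small grid, grouped into a 24-hour cycle
# the same way concat_and_group_diurnal() does above.
times = np.datetime64("2016-01-01T00:00") + np.arange(240) * np.timedelta64(1, "h")
da = xr.DataArray(
    np.random.rand(240, 4, 4),
    dims=("time", "lat", "lon"),
    coords={"time": times},
)
diurnal_mean = da.groupby("time.hour").mean(dim="time")
print(diurnal_mean.sizes)  # hour: 24, lat: 4, lon: 4
```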
longitudes = LON[RELAX_ZONE : RELAX_ZONE + 544] + lon2d, lat2d = np.meshgrid(longitudes, latitudes) + + fig, ax = plt.subplots( + figsize=(8, 6), + subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0, + pole_latitude= 43.0)} + ) + contour = ax.pcolormesh( + lon2d, lat2d, values, + cmap=cmap, shading="auto", + norm=norm if norm else None, + vmin=None if norm else vmin, + vmax=None if norm else vmax, + ) + ax.coastlines() + ax.add_feature(cfeature.BORDERS, linewidth=1) + ax.gridlines(visible=False) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.title(title) + cbar = plt.colorbar( + contour, + label=label, + orientation="horizontal", + extend=extend, + shrink=0.75, + pad=0.02 + ) + if ticks is not None: + cbar.set_ticks(ticks) + cbar.set_ticklabels([f'{tick:g}' for tick in ticks]) + + plt.tight_layout() + fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") + plt.close(fig) + +def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=1000.0): + """Plot precipitation data with specific colormap and thresholds.""" + # Scale and mask values below threshold + values = rfac * values # m/h --> mm/h + values = np.ma.masked_where(values <= threshold, values) + + # Predefined colors and bounds specific for precipitation + colors = ['none', 'powderblue', 'dodgerblue', 'mediumblue', + 'forestgreen', 'limegreen', 'lawngreen', + 'yellow', 'gold', 'darkorange', 'red', + 'darkviolet', 'violet', 'thistle'] + bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000] + + cmap = ListedColormap(colors) + norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) + + plot_map( + values, filename, + cmap=cmap, + norm=norm, + ticks=bounds, + title=title, + label='mm/h', + extend='max' + ) + +def wind_direction(u, v): + """Compute wind direction from u and v components.""" + return(np.arctan2(-u, -v) * 180 / np.pi) % 360 + +@DeprecationWarning +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label='', title='', vmin=None, vmax=None): + """Plot observed or interpolated data in a scatter plot.""" fig = plt.figure() fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) logging.info(f'plotting values to {filename}') - p = ax.scatter(x=longitudes, y=latitudes, c=values) + p = ax.scatter(x=longitudes, y=latitudes, c=values, vmin=vmin, vmax=vmax) ax.coastlines() ax.gridlines(draw_labels=True) - plt.colorbar(p, label="absolute error", orientation="horizontal") + plt.colorbar(p, label=label, orientation="horizontal") + plt.savefig(filename) + plt.close('all') + +def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str, xlabel='', ylabel='', title=''): + + ax = plt.subplot() + colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'] # TODO, add more + i=0 + for k in scores.keys(): + style = colors[i] + # If more than 50 points, don't connect lines + if len(times) > 50: + style = style + '.' 
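`wind_direction()` above returns the meteorological "blowing from" direction in degrees. A quick sanity check of that convention on a few cardinal cases, using the same formula:

```python
import numpy as np

def wind_direction(u, v):
    """Same formula as the helper above: direction the wind blows from, in degrees."""
    return (np.arctan2(-u, -v) * 180 / np.pi) % 360

# u > 0 is flow toward the east, v > 0 is flow toward the north.
print(wind_direction(0.0, -5.0))  # 0.0   -> northerly (from the north)
print(wind_direction(-5.0, 0.0))  # 90.0  -> easterly
print(wind_direction(0.0, 5.0))   # 180.0 -> southerly
print(wind_direction(5.0, 0.0))   # 270.0 -> westerly
```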
+ else: + style = style + '-' + p, = ax.plot(times, scores[k], style) + i=i+1 + p.set_label(k) + ax.legend() + ax.set_xticks([times[0],times[-1]]) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) plt.savefig(filename) plt.close('all') def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): - fig = plt.figure() for k in freqs.keys(): plt.loglog(freqs[k], spec[k], label=k) plt.title(channel_name) diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py new file mode 100644 index 0000000..f56b551 --- /dev/null +++ b/src/hirad/eval/probability_of_exceedance.py @@ -0,0 +1,247 @@ +""" +Plots the probability of exceedance for precipitation over land. + +This script computes and visualizes the complementary cumulative distribution +(probability of exceeding x mm/h) over land). +""" +import logging +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR_HOURLY, LOG_INTERVAL + + +def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): + """Save probability of exceedance plot with pre-computed data.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + plt.figure(figsize=(10, 6)) + + # Plot exceedance curves + for (key, exceedance_data), label, color in zip(exceedance_data_dict.items(), labels, colors): + if isinstance(exceedance_data, tuple): # Handle ensemble data + # Plot individual members with transparency + for i, member_exceedance in enumerate(exceedance_data): + alpha = 0.5 if i > 0 else 0.7 + label_member = label if i == 0 else None + plt.plot(thresholds, member_exceedance, alpha=alpha, color=color, + label=label_member, linewidth=1) + else: + # Plot single dataset + plt.plot(thresholds, exceedance_data, alpha=0.7, color=color, + label=label, linewidth=2) + + plt.xscale('log') + plt.yscale('log') + plt.xlabel(ylabel) + plt.ylabel('Probability of Exceedance') + plt.ylim(1e-8, 1) + plt.xlim(thresholds[1], thresholds[-1]) + plt.title(title) + plt.grid(True, alpha=0.3) + + # Add percentile lines if provided + if percentiles_data: + # Calculate y-range for percentile lines (lowest 10% of log scale) + y_bottom, y_top = plt.ylim() + log_bottom, log_top = np.log10(y_bottom), np.log10(y_top) + vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom)) + vline_ymin = y_bottom + + # Define line styles for percentiles + percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} + percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} + colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'} + legend_added = set() + + # Plot all percentile lines + for dataset_name, data in percentiles_data.items(): + color = colors_perc[dataset_name] + + if dataset_name in ['target', 'baseline', 'regression-prediction']: + # Single dataset + for percentile, value in data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + 
linestyles=linestyle, alpha=0.8) # No label here + else: + # Ensemble members + for member_data in data.values(): + for percentile, value in member_data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.6) # No label here + + # Add black legend entries for percentiles (override the colored ones) + for percentile in [99, 99.9, 99.99]: + if percentile in legend_added: + plt.plot([], [], color='black', linestyle=percentile_styles[percentile], + label=percentile_labels[percentile]) + + plt.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for probability of exceedance over land") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root + out_root = Path(cfg.generation.io.output_path or './outputs') + + # Find channel indices + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + land_mask = load_land_sea_mask() + + # Define thresholds for exceedance calculation + thresholds = np.logspace(-2, 2.1, 200) # From 0.01 to 100 mm/h + + # Storage for exceedance data and land values + exceedance_data = {} + all_land_values = {} + + # -- Process target and baseline -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + + all_values = [] + + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * CONV_FACTOR_HOURLY * land_mask + + # Apply scaling factor for baseline + # if mode == 'baseline': + # data = data / 6.0 + + land_values = data.values[~np.isnan(data.values)] + all_values.extend(land_values) + except: + logger.warning(f"{mode} data not found, skipping") + continue + + # Compute exceedance probabilities + all_values = np.array(all_values) + exceedance_probs = [] + for threshold in thresholds: + prob_exceed = np.mean(all_values > threshold) + exceedance_probs.append(prob_exceed) + + exceedance_data[mode] = np.array(exceedance_probs) + all_land_values[mode] = all_values + logger.info(f"Processed {len(all_values)} land values for {mode}") + + # -- Process predictions: compute exceedance for each ensemble member -- + logger.info("Processing predictions") + + n_members = None + all_member_values = [] + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR_HOURLY # [n_members, n_channels, lat, lon] + + if 
n_members is None: + n_members = preds.shape[0] + all_member_values = [[] for _ in range(n_members)] + + for member_idx in range(n_members): + data = preds[member_idx, tp_out] * land_mask + land_values = data.values[~np.isnan(data.values)] + all_member_values[member_idx].extend(land_values) + + # Compute exceedance probabilities for each ensemble member + member_exceedance_data = [] + for member_idx in range(n_members): + member_values = np.array(all_member_values[member_idx]) + member_exceedance = [] + for threshold in thresholds: + prob_exceed = np.mean(member_values > threshold) + member_exceedance.append(prob_exceed) + member_exceedance_data.append(np.array(member_exceedance)) + + exceedance_data['predictions'] = tuple(member_exceedance_data) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # Compute percentiles for all datasets + percentiles_data = {} + percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} + + # Target and baseline percentiles + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in all_land_values: + data_array = xr.DataArray(all_land_values[mode]) + percentiles_data[mode] = { + key: data_array.quantile(p).item() + for key, p in percentiles.items() + } + + # Ensemble member percentiles + percentiles_data['predictions'] = {} + for member_idx in range(n_members): + member_data_array = xr.DataArray(all_member_values[member_idx]) + percentiles_data['predictions'][f'member_{member_idx}'] = { + key: member_data_array.quantile(p).item() + for key, p in percentiles.items() + } + + # Create exceedance plots + labels = ['COSMO-2 Analysis', 'ERA5', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in exceedance_data else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in exceedance_data else ['blue', 'orange', 'green'] + + fn = out_root / 'precipitation_exceedance_over_land.png' + save_exceedance_plot( + exceedance_data, + thresholds, + labels, + colors, + 'Probability of Exceedance', + 'All-hour Precipitation Over Land [mm/h] (Pooled Data)', + fn, + percentiles_data + ) + logger.info(f"Exceedance plot saved: {fn}") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/probability_of_exceedance_wind.py b/src/hirad/eval/probability_of_exceedance_wind.py new file mode 100644 index 0000000..6fe3ff7 --- /dev/null +++ b/src/hirad/eval/probability_of_exceedance_wind.py @@ -0,0 +1,344 @@ +"""Probability of exceedance for wind speed and components.""" +import logging +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices, LOG_INTERVAL + + +def compute_wind_speed(u, v): + """Compute wind speed from U and V components.""" + return np.hypot(u, v) + + +def compute_exceedance_probs(values, thresholds, use_abs=False): + """Compute exceedance probabilities.""" + if use_abs: + return np.array([np.mean(np.abs(values) > t) for t in thresholds]) + else: + return np.array([np.mean(values > t) for t in thresholds]) + + +def update_exceedance_counts(counts, total, values, thresholds, use_abs=False): + """Update exceedance counts incrementally.""" + data = np.abs(values) if use_abs else values + for i, 
threshold in enumerate(thresholds): + counts[i] += np.sum(data > threshold) + total += len(values) + return counts, total + + +def compute_percentiles(values, percentile_dict, use_abs=False): + """Compute percentiles.""" + data = np.abs(values) if use_abs else values + data_array = xr.DataArray(data) + return {key: data_array.quantile(p).item() for key, p in percentile_dict.items()} + + +def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): + """Save probability of exceedance plot.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + plt.figure(figsize=(10, 6)) + + # Plot exceedance curves + for (key, exceedance_data), label, color in zip(exceedance_data_dict.items(), labels, colors): + if isinstance(exceedance_data, tuple): # Handle ensemble data + # Plot individual members with transparency + for i, member_exceedance in enumerate(exceedance_data): + alpha = 0.5 if i > 0 else 0.7 + label_member = label if i == 0 else None + plt.plot(thresholds, member_exceedance, alpha=alpha, color=color, + label=label_member, linewidth=1) + else: + # Plot single dataset + plt.plot(thresholds, exceedance_data, alpha=0.7, color=color, + label=label, linewidth=2) + + plt.xscale('log') + plt.xlim(thresholds[1], thresholds[-1]) + plt.yscale('log') + plt.xlabel(ylabel) + plt.ylabel('Probability of Exceedance') + plt.ylim(1e-8, 1) + plt.title(title) + plt.grid(True, alpha=0.3) + + # Add percentile lines if provided + if percentiles_data: + # Calculate y-range for percentile lines (lowest 10% of log scale) + y_bottom, y_top = plt.ylim() + log_bottom, log_top = np.log10(y_bottom), np.log10(y_top) + vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom)) + vline_ymin = y_bottom + + # Define line styles for percentiles + percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} + percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} + colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'} + legend_added = set() + + # Plot all percentile lines + for dataset_name, data in percentiles_data.items(): + color = colors_perc[dataset_name] + + if dataset_name in ['target', 'baseline', 'regression-prediction']: + # Single dataset + for percentile, value in data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.8) # No label here + else: + # Ensemble members + for member_data in data.values(): + for percentile, value in member_data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.6) # No label here + + # Add black legend entries for percentiles (override the colored ones) + for percentile in [99, 99.9, 99.99]: + if percentile in legend_added: + plt.plot([], [], color='black', linestyle=percentile_styles[percentile], + label=percentile_labels[percentile]) + + plt.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + 
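`update_exceedance_counts()` above accumulates per-threshold exceedance counts one timestep at a time, so the full pooled sample never has to be held in memory; dividing by the running total at the end yields the exceedance probabilities. A compact illustration with synthetic batches (the distribution, batch sizes, and thresholds are arbitrary stand-ins):

```python
import numpy as np

thresholds = np.logspace(-1, 1.5, 200)   # 0.1 to ~31.6, same spacing idea as above
counts = np.zeros(len(thresholds), dtype=np.int64)
total = 0

rng = np.random.default_rng(0)
for _ in range(10):                                       # stand-in for the timestep loop
    values = rng.gamma(shape=2.0, scale=2.0, size=5000)   # fake pooled wind speeds
    for i, t in enumerate(thresholds):
        counts[i] += np.sum(values > t)
    total += len(values)

exceedance_prob = counts / total          # P(X > t) for every threshold
print(exceedance_prob[0], exceedance_prob[-1])  # close to 1 at 0.1, near 0 at ~31.6
```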
logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for probability of exceedance for wind speed") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root + out_root = Path(cfg.generation.io.output_path or './outputs') + + # Find channel indices for wind components + indices = get_channel_indices(dataset) + u10_out = indices['output'].get('10u') + v10_out = indices['output'].get('10v') + u10_in = indices['input'].get('10u', u10_out) + v10_in = indices['input'].get('10v', v10_out) + + if u10_out is None or v10_out is None: + logger.error("Wind components (10u, 10v) not found in dataset!") + return + + logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, input: 10u={u10_in}, 10v={v10_in}") + + # Define thresholds for exceedance calculation (same for all variables) + thresholds = np.logspace(-1, 1.5, 200) # From 0.1 to ~31.6 m/s + n_thresholds = len(thresholds) + + # Storage for exceedance counts (incremental computation) + exceedance_counts = { + 'speed': {}, 'u': {}, 'v': {} + } + totals = {'speed': {}, 'u': {}, 'v': {}} + + # Storage for percentile computation (collect samples) + percentile_samples = {'speed': {}, 'u': {}, 'v': {}} + + # -- Process target and baseline -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + + # Initialize counts + for var in ['speed', 'u', 'v']: + exceedance_counts[var][mode] = np.zeros(n_thresholds, dtype=np.int64) + totals[var][mode] = 0 + percentile_samples[var][mode] = [] + + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) + + # Extract wind components + if mode in ['target', 'regression-prediction']: + u = data[u10_out] + v = data[v10_out] + else: # baseline + u = data[u10_in] + v = data[v10_in] + + wind_speed = compute_wind_speed(u, v) + + # Get valid values + valid_mask = ~np.isnan(wind_speed) + speed_vals = wind_speed[valid_mask].flatten() + u_vals = u[valid_mask].flatten() + v_vals = v[valid_mask].flatten() + + # Update exceedance counts incrementally + exceedance_counts['speed'][mode], totals['speed'][mode] = update_exceedance_counts( + exceedance_counts['speed'][mode], totals['speed'][mode], speed_vals, thresholds, use_abs=False + ) + exceedance_counts['u'][mode], totals['u'][mode] = update_exceedance_counts( + exceedance_counts['u'][mode], totals['u'][mode], u_vals, thresholds, use_abs=True + ) + exceedance_counts['v'][mode], totals['v'][mode] = update_exceedance_counts( + exceedance_counts['v'][mode], totals['v'][mode], v_vals, thresholds, use_abs=True + ) + + # Collect samples for percentiles (subsample to save memory) + sample_rate = max(1, len(speed_vals) // 10000) # Keep ~10k samples per timestep + percentile_samples['speed'][mode].extend(speed_vals[::sample_rate]) + percentile_samples['u'][mode].extend(u_vals[::sample_rate]) + percentile_samples['v'][mode].extend(v_vals[::sample_rate]) + + except Exception as e: + logger.warning(f"{mode} data not found or error occurred, skipping: {e}") + continue + + logger.info(f"Processed 
{totals['speed'][mode]} values for {mode}") + + # -- Process predictions: compute exceedance for each ensemble member -- + logger.info("Processing predictions") + + n_members = None + member_counts = {'speed': [], 'u': [], 'v': []} + member_totals = {'speed': [], 'u': [], 'v': []} + member_samples = {'speed': [], 'u': [], 'v': []} + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) # [n_members, n_channels, lat, lon] + + if n_members is None: + n_members = preds.shape[0] + for var in ['speed', 'u', 'v']: + member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)] + member_totals[var] = [0 for _ in range(n_members)] + member_samples[var] = [[] for _ in range(n_members)] + + for member_idx in range(n_members): + u = preds[member_idx, u10_out] + v = preds[member_idx, v10_out] + wind_speed = compute_wind_speed(u, v) + + valid_mask = ~np.isnan(wind_speed) + speed_vals = wind_speed[valid_mask].flatten() + u_vals = u[valid_mask].flatten() + v_vals = v[valid_mask].flatten() + + # Update counts + member_counts['speed'][member_idx], member_totals['speed'][member_idx] = update_exceedance_counts( + member_counts['speed'][member_idx], member_totals['speed'][member_idx], speed_vals, thresholds, use_abs=False + ) + member_counts['u'][member_idx], member_totals['u'][member_idx] = update_exceedance_counts( + member_counts['u'][member_idx], member_totals['u'][member_idx], u_vals, thresholds, use_abs=True + ) + member_counts['v'][member_idx], member_totals['v'][member_idx] = update_exceedance_counts( + member_counts['v'][member_idx], member_totals['v'][member_idx], v_vals, thresholds, use_abs=True + ) + + # Collect samples for percentiles + sample_rate = max(1, len(speed_vals) // 10000) + member_samples['speed'][member_idx].extend(speed_vals[::sample_rate]) + member_samples['u'][member_idx].extend(u_vals[::sample_rate]) + member_samples['v'][member_idx].extend(v_vals[::sample_rate]) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # Convert counts to probabilities + exceedance_data = {'speed': {}, 'u': {}, 'v': {}} + + for var in ['speed', 'u', 'v']: + # Single datasets + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in exceedance_counts[var] and totals[var][mode] > 0: + exceedance_data[var][mode] = exceedance_counts[var][mode] / totals[var][mode] + + # Ensemble members + member_probs = [] + for member_idx in range(n_members): + if member_totals[var][member_idx] > 0: + member_probs.append(member_counts[var][member_idx] / member_totals[var][member_idx]) + exceedance_data[var]['predictions'] = tuple(member_probs) + + # Compute percentiles for all datasets and variables + percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} + percentiles_data = {'speed': {}, 'u': {}, 'v': {}} + + # Single datasets (target, baseline, regression-prediction) + for var in ['speed', 'u', 'v']: + use_abs = (var in ['u', 'v']) + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in percentile_samples[var] and len(percentile_samples[var][mode]) > 0: + percentiles_data[var][mode] = compute_percentiles( + np.array(percentile_samples[var][mode]), percentiles, use_abs + ) + + # Ensemble members + percentiles_data[var]['predictions'] = {} + for member_idx in range(n_members): + if len(member_samples[var][member_idx]) > 0: + percentiles_data[var]['predictions'][f'member_{member_idx}'] = 
compute_percentiles( + np.array(member_samples[var][member_idx]), percentiles, use_abs + ) + + # Create exceedance plots + labels = ['COSMO-2 Analysis', 'ERA5', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in exceedance_data['speed'] else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in exceedance_data['speed'] else ['blue', 'orange', 'green'] + + # Define plot configurations + plot_configs = [ + ('windspeed_exceedance.png', 'speed', 'Probability of Exceedance for Wind Speed', + 'All-hour Wind Speed [m/s] (Pooled Data)'), + ('wind_u_exceedance.png', 'u', 'Probability of Exceedance for abs(10u)', + 'All-hour 10u Component [m/s] (Pooled Data)'), + ('wind_v_exceedance.png', 'v', 'Probability of Exceedance for abs(10v)', + 'All-hour 10v Component [m/s] (Pooled Data)'), + ] + + for filename, var, title, ylabel in plot_configs: + fn = out_root / filename + save_exceedance_plot( + exceedance_data[var], + thresholds, + labels, + colors, + title, + ylabel, + fn, + percentiles_data[var] + ) + logger.info(f"{var.capitalize()} exceedance plot saved: {fn}") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/snapshots.py b/src/hirad/eval/snapshots.py new file mode 100644 index 0000000..f6c7892 --- /dev/null +++ b/src/hirad/eval/snapshots.py @@ -0,0 +1,303 @@ +"""Generates maps of precipitation, temperature, and wind components/speed/direction.""" +import logging +from dataclasses import dataclass, field, replace +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.eval import compute_mae, plot_map +from hirad.eval.plotting import plot_map_precipitation, wind_direction +from hirad.utils.function_utils import get_time_from_range +from hirad.utils.inference_utils import calculate_bounds + +@dataclass +class ChannelMeta: + """Metadata for a channel.""" + name: str + cmap: str = "viridis" + me_cmap: str | None = None + unit: str = "" + norm: any = None + err_vmin: float = None + err_vmax: float = None + vmin: float = None + vmax: float = None + extend: str = "both" + precip_kwargs: dict = field(default_factory=lambda: {"threshold": 0.1, "rfac": 1000.0}) + + @classmethod + def get(cls, ch_or_name: "ChannelMeta | str | None", *, vmin=None, vmax=None) -> "ChannelMeta": + name = getattr(ch_or_name, "name", ch_or_name or "") + base = CHANNELS.get(name) or cls(name=name) + if vmin is not None or vmax is not None: + return replace(base, vmin=vmin, vmax=vmax) + return base + +CHANNELS = { + "tp": ChannelMeta(name="tp", cmap=None, unit="mm/h", extend="max", precip_kwargs={"threshold": 0.1, "rfac": 1000.0}), + "2t": ChannelMeta(name="2t", cmap="RdYlBu_r", me_cmap="RdBu", unit="K", err_vmin=-4.5, err_vmax=4.5), + "10u": ChannelMeta(name="10u", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10), + "10v": ChannelMeta(name="10v", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10), +} + +def format_time_str(dt_str, input_fmt="%Y%m%d-%H%M", output_fmt="%d-%m-%Y %H:%M"): + """Convert time string from input_fmt to output_fmt.""" + dt = datetime.strptime(dt_str, input_fmt) + return dt.strftime(output_fmt) + +class FileRepository: + def __init__(self, root_path): + self.root = Path(root_path) + + def load(self, 
time, filename): + return torch.load(self.root / time / filename, weights_only=False) + + def _ensure_dir(self, *subdirs): + """Make (and return) root_path/subdir1/subdir2/….""" + d = self.root.joinpath(*subdirs) + d.mkdir(parents=True, exist_ok=True) + return d + + def _make_fname(self, curr_time, prefix, suffix, member_idx): + """Build a filename like '20250724-1230-prefix-suffix[_member]'.""" + base = f"{curr_time}-{prefix}-{suffix}" + if member_idx is not None: + base += f"_{member_idx}" + return base + + def output_file(self, channel, curr_time, suffix, member_idx=None): + # decide on the folder name: e.g. 'tp_100m' or just 'tp' + folder = f"{channel.name}_{channel.level}" if getattr(channel, "level", None) else channel.name + fname = self._make_fname(curr_time, channel.name, suffix, member_idx) + return self._ensure_dir(folder) / fname + + def wind_file(self, wind_type, curr_time, suffix, member_idx=None): + # e.g. wind_type = "FF10m" or "DD10m" + fname = self._make_fname(curr_time, wind_type, suffix, member_idx) + return self._ensure_dir(wind_type) / fname + +def map_output_to_input_channels(output_channels, input_channels): + """ + Maps output channels to input channels based on their names. + """ + return { + j: next((k for k, input_channel in enumerate(input_channels) if input_channel.name == output_channel.name), -1) + for j, output_channel in enumerate(output_channels) + } + +def save_field(name, data, meta, files, channel, t, member=None, kind=None, cmap=None, vmin=None, vmax=None, custom_path=None, plot_func=None, title=None, **plot_kwargs): + """Save a field by plotting it with the appropriate function and parameters. Handles precipitation specially and supports custom paths.""" + # Determine output path + suffix = f"{name}-{kind}" if kind else f"{name}" + out_path = custom_path or files.output_file(channel, t, suffix, member) + + # Choose plotting function + plot = plot_func or plot_map + + # Case dependent plot parameters + title = title or meta.name + extend = 'max' if kind == 'mae' else meta.extend + common = {'title': title, 'norm': meta.norm, 'extend': extend, **plot_kwargs} + + # Precipitation case + if plot.__name__ == 'plot_map_precipitation': + precip_args = {**meta.precip_kwargs, **{k: plot_kwargs[k] for k in meta.precip_kwargs if k in plot_kwargs}} + plot(data, out_path, title=title, **precip_args) + else: + plot(data, out_path, vmin=vmin if vmin is not None else meta.vmin, vmax=vmax if vmax is not None else meta.vmax, cmap=cmap or meta.cmap, label=meta.unit, **common) + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + + # Initialize distributed manager + DistributedManager.initialize() + + # Initialize logger + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger("plot_maps") + + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + else: + times = cfg.generation.times + + dataset_cfg = OmegaConf.to_container(cfg.dataset) + has_lead_time = cfg.generation.get("has_lead_time", False) + dataset, sampler = get_dataset_and_sampler_inference( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + + output_path = getattr(cfg.generation.io, "output_path", "./outputs") + files = FileRepository(output_path) + + for curr_time in times: + prediction = files.load(curr_time, f'{curr_time}-predictions') + baseline = files.load(curr_time, f'{curr_time}-baseline') + target = files.load(curr_time, 
f'{curr_time}-target') + try: + mean_pred = files.load(curr_time, f'{curr_time}-regression-prediction') + except: + mean_pred = None + + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + output_to_input_channel_map = map_output_to_input_channels(output_channels, input_channels) + + for idx, channel in enumerate(output_channels): + input_channel_idx = output_to_input_channel_map[idx] + plot_title = f"{format_time_str(curr_time)}: {getattr(channel, 'title', channel.name)}" + vmin, vmax = calculate_bounds( + target[idx,:,:], + prediction[:,idx,:,:], + baseline[input_channel_idx,:,:], + mean_pred[idx,:,:] if mean_pred is not None else None + ) + metadata = ChannelMeta.get(channel, vmin=vmin, vmax=vmax) + + if channel.name == "tp": + save_field( + "target", target[idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + save_field( + "baseline", baseline[input_channel_idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + if prediction.shape[0] > 1: + for member_idx in range(prediction.shape[0]): + save_field( + "prediction", prediction[member_idx, idx, :, :], metadata, files, channel, curr_time, + member=member_idx, plot_func=plot_map_precipitation, title=plot_title + ) + else: + save_field( + "prediction", prediction[0, idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + if mean_pred is not None: + save_field( + "mean-prediction", mean_pred[idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + continue + + # Plot target and baseline and regression prediction if available + save_field("target", target[idx, :, :], metadata, files, channel, curr_time, title=plot_title) + save_field("baseline", baseline[input_channel_idx, :, :], metadata, files, channel, curr_time, title=plot_title) + if mean_pred is not None: + save_field("mean-prediction", mean_pred[idx, :, :], metadata, files, channel, curr_time, title=plot_title) + + # Baseline MAE and ME + _, baseline_mae = compute_mae(baseline[input_channel_idx, :, :], target[idx, :, :]) + baseline_me = (baseline[input_channel_idx, :, :] - target[idx, :, :]) + save_field("baseline", baseline_mae.reshape(baseline[input_channel_idx, :, :].shape), metadata, files, channel, curr_time, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) + save_field("baseline", baseline_me, metadata, files, channel, curr_time, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + + # Regression prediction MAE and ME + if mean_pred is not None: + _, mean_mae = compute_mae(mean_pred[idx, :, :], target[idx, :, :]) + mean_me = (mean_pred[idx, :, :] - target[idx, :, :]) + save_field("mean-prediction", mean_mae.reshape(mean_pred[idx, :, :].shape), metadata, files, channel, curr_time, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) + save_field("mean-prediction", mean_me, metadata, files, channel, curr_time, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + + # Ensemble predictions + for member_idx in range(prediction.shape[0]): + member = prediction[member_idx, idx, :, :] + save_field("prediction", member, metadata, files, channel, curr_time, 
member=member_idx, title=plot_title) + _, prediction_mae = compute_mae(member, target[idx, :, :]) + save_field("prediction", prediction_mae.reshape(member.shape), metadata, files, channel, curr_time, member=member_idx, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) + prediction_me = (member - target[idx, :, :]) + save_field("prediction", prediction_me, metadata, files, channel, curr_time, member=member_idx, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + + # Plot Windspeed and direction + wind_channels = {ch.name: idx for idx, ch in enumerate(output_channels) if ch.name in ("10u", "10v")} + if "10u" in wind_channels and "10v" in wind_channels: + idx_10u = wind_channels["10u"] + idx_10v = wind_channels["10v"] + input_idx_10u = output_to_input_channel_map[idx_10u] + input_idx_10v = output_to_input_channel_map[idx_10v] + + # Compute windspeed and direction for target, baseline, prediction and mean prediction + target_wind_speed = np.hypot(target[idx_10u, :, :], target[idx_10v, :, :]) + target_wind_dir = wind_direction(target[idx_10u, :, :], target[idx_10v, :, :]) + baseline_wind_speed = np.hypot(baseline[input_idx_10u, :, :], baseline[input_idx_10v, :, :]) + baseline_wind_dir = wind_direction(baseline[input_idx_10u, :, :], baseline[input_idx_10v, :, :]) + prediction_wind_speed = np.hypot(prediction[:, idx_10u, :, :], prediction[:, idx_10v, :, :]) + prediction_wind_dir = wind_direction(prediction[:, idx_10u, :, :], prediction[:, idx_10v, :, :]) + if mean_pred is not None: + mean_wind_speed = np.hypot(mean_pred[idx_10u, :, :], mean_pred[idx_10v, :, :]) + mean_wind_dir = wind_direction(mean_pred[idx_10u, :, :], mean_pred[idx_10v, :, :]) + + plot_title_speed = f"{format_time_str(curr_time)}: FF10m" + plot_title_dir = f"{format_time_str(curr_time)}: DD10m" + + wind_meta = ChannelMeta.get("10u", vmin=0, vmax=10) + dir_meta = ChannelMeta.get("10u", vmin=0, vmax=360) + + # Save windspeed plots + save_field( + "FF10m-target", target_wind_speed, wind_meta, files, None, curr_time, + cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-target"), + plot_func=plot_map, title=plot_title_speed + ) + save_field( + "FF10m-baseline", baseline_wind_speed, wind_meta, files, None, curr_time, + cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-baseline"), + plot_func=plot_map, title=plot_title_speed + ) + for member_idx in range(prediction.shape[0]): + save_field( + "FF10m-prediction", prediction_wind_speed[member_idx], wind_meta, files, None, curr_time, + member=member_idx, cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-prediction", member_idx), + plot_func=plot_map, title=plot_title_speed + ) + if mean_pred is not None: + save_field( + "FF10m-mean-prediction", mean_wind_speed, wind_meta, files, None, curr_time, + cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-mean-prediction"), + plot_func=plot_map, title=plot_title_speed + ) + + # Save wind direction plots + save_field( + "DD10m-target", target_wind_dir, dir_meta, files, None, curr_time, + cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-target"), + plot_func=plot_map, title=plot_title_dir + ) + save_field( + "DD10m-baseline", baseline_wind_dir, dir_meta, 
files, None, curr_time, + cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-baseline"), + plot_func=plot_map, title=plot_title_dir + ) + for member_idx in range(prediction.shape[0]): + save_field( + "DD10m-prediction", prediction_wind_dir[member_idx], dir_meta, files, None, curr_time, + member=member_idx, cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-prediction", member_idx), + plot_func=plot_map, title=plot_title_dir + ) + if mean_pred is not None: + save_field( + "DD10m-mean-prediction", mean_wind_dir, dir_meta, files, None, curr_time, + cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-mean-prediction"), + plot_func=plot_map, title=plot_title_dir + ) + + logger.info("Image loading and plotting completed.") + +if __name__ == "__main__": + main() diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh new file mode 100644 index 0000000..ec522d9 --- /dev/null +++ b/src/hirad/eval_precip.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +#SBATCH --job-name="eval_precip" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=05:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_precip.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . 
--no-dependencies + + # Diurnal cycle + python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/diurnal_cycle_precip_p99.py --config-name=generate_era_cosmo.yaml + + # Histograms + python src/hirad/eval/hist.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/probability_of_exceedance.py --config-name=generate_era_cosmo.yaml + + # Maps + python src/hirad/eval/map_precip_stats.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh new file mode 100644 index 0000000..df113e8 --- /dev/null +++ b/src/hirad/eval_wind.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +#SBATCH --job-name="eval_wind" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=12:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_wind.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . 
--no-dependencies + + # Diurnal cycle + python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml + + # Maps + python src/hirad/eval/map_wind_stats.py --config-name=generate_era_cosmo.yaml + + # Probability of exceedance + python src/hirad/eval/probability_of_exceedance_wind.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 87c8979..1e3e9f5 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -5,21 +5,18 @@ ### HARDWARE ### #SBATCH --partition=debug #SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 #SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.err +#SBATCH --output=./logs/regression_generation.log ### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a122 +#SBATCH -A a161 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -34,18 +31,19 @@ export MASTER_PORT=29500 echo "Master port: $MASTER_PORT" # Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# Use SLURM_NTASKS (number of processes to be launched by torchrun) -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute threads per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS -echo "Physical cores: $PHYSICAL_CORES" -echo "Local processes: $LOCAL_PROCS" -echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/generate_test.sh b/src/hirad/generate_test.sh new file mode 100644 index 0000000..0f12f9e --- /dev/null +++ b/src/hirad/generate_test.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="corrdiff-test-genreate" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/generation_test.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. 
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/inference/generate.py --config-name=generate_era_cosmo_test.yaml +" \ No newline at end of file diff --git a/src/hirad/inference/README.md b/src/hirad/inference/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/inference/__init__.py b/src/hirad/inference/__init__.py new file mode 100644 index 0000000..1593b3a --- /dev/null +++ b/src/hirad/inference/__init__.py @@ -0,0 +1,3 @@ +from .deterministic_sampler import deterministic_sampler +from .stochastic_sampler import stochastic_sampler +from .generator import Generator \ No newline at end of file diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/inference/deterministic_sampler.py similarity index 100% rename from src/hirad/utils/deterministic_sampler.py rename to src/hirad/inference/deterministic_sampler.py diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 35f856f..f90f974 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -4,48 +4,30 @@ from omegaconf import OmegaConf, DictConfig import torch import torch._dynamo -import nvtx import numpy as np import contextlib from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor -from functools import partial -import cartopy.crs as ccrs -from matplotlib import pyplot as plt -from einops import rearrange -from torch.distributed import gather - - -from hydra.utils import to_absolute_path from hirad.models import EDMPrecondSuperResolution, UNet -from hirad.utils.patching import GridPatching2D -from hirad.utils.stochastic_sampler import stochastic_sampler -from hirad.utils.deterministic_sampler import deterministic_sampler -from hirad.utils.inference_utils import ( - get_time_from_range, - regression_step, - diffusion_step, -) +from hirad.inference import Generator +from hirad.utils.inference_utils import save_results_as_torch +from hirad.utils.function_utils import get_time_from_range from hirad.utils.checkpoint import load_checkpoint - -from hirad.utils.generate_utils import ( - get_dataset_and_sampler -) +from hirad.datasets import get_dataset_and_sampler_inference from hirad.utils.train_helpers import set_patch_shape -from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig) -> None: """Generate random dowscaled atmospheric states using the techniques described in the paper "Elucidating the Design Space of Diffusion-Based Generative Models". 
""" - torch.backends.cudnn.enabled = False + # torch.backends.cudnn.enabled = False # Initialize distributed manager DistributedManager.initialize() dist = DistributedManager() @@ -55,14 +37,6 @@ def main(cfg: DictConfig) -> None: logger = PythonLogger("generate") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) - # Handle the batch size - seeds = list(np.arange(cfg.generation.num_ensembles)) - num_batches = ( - (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1 - ) * dist.world_size - all_batches = torch.as_tensor(seeds).tensor_split(num_batches) - rank_batches = all_batches[dist.rank :: dist.world_size] - # Synchronize if dist.world_size > 1: torch.distributed.barrier() @@ -81,32 +55,12 @@ def main(cfg: DictConfig) -> None: has_lead_time = cfg.generation["has_lead_time"] else: has_lead_time = False - dataset, sampler = get_dataset_and_sampler( + dataset, sampler = get_dataset_and_sampler_inference( dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time ) img_shape = dataset.image_shape() img_out_channels = len(dataset.output_channels()) - # Parse the patch shape - if cfg.generation.patching: - patch_shape_x = cfg.generation.patch_shape_x - patch_shape_y = cfg.generation.patch_shape_y - else: - patch_shape_x, patch_shape_y = None, None - patch_shape = (patch_shape_y, patch_shape_x) - use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if use_patching: - patching = GridPatching2D( - img_shape=img_shape, - patch_shape=patch_shape, - boundary_pix=cfg.generation.boundary_pix, - overlap_pix=cfg.generation.overlap_pix, - ) - logger0.info("Patch-based training enabled") - else: - patching = None - logger0.info("Patch-based training disabled") - # Parse the inference mode if cfg.generation.inference_mode == "regression": load_net_reg, load_net_res = True, False @@ -136,7 +90,6 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - #TODO fix to use channels_last which is optimal for H100 net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True @@ -177,118 +130,47 @@ def main(cfg: DictConfig) -> None: else: net_reg = None - # Reset since we are using a different mode. + # Reset since we are using a different mode. 
if cfg.generation.perf.use_torch_compile: torch._dynamo.reset() # Only compile residual network # Overhead of compiling regression network outweights any benefits if net_res: - net_res = torch.compile(net_res, mode="reduce-overhead") - - # Partially instantiate the sampler based on the configs - if cfg.sampler.type == "deterministic": - if cfg.generation.hr_mean_conditioning: - raise NotImplementedError( - "High-res mean conditioning is not yet implemented for the deterministic sampler" - ) - sampler_fn = partial( - deterministic_sampler, - num_steps=cfg.sampler.num_steps, - # num_ensembles=cfg.generation.num_ensembles, - solver=cfg.sampler.solver, + net_res = torch.compile(net_res) #, mode="reduce-overhead") + # removed reduce-overhead because it was breaking cuda graph compilation + generator = Generator( + net_reg=net_reg, + net_res=net_res, + batch_size=cfg.generation.seed_batch_size, + ensemble_size=cfg.generation.num_ensembles, + hr_mean_conditioning=cfg.generation.hr_mean_conditioning, + n_out_channels=img_out_channels, + inference_mode=cfg.generation.inference_mode, + dist=dist, ) - elif cfg.sampler.type == "stochastic": - sampler_fn = partial(stochastic_sampler, patching=patching) - else: - raise ValueError(f"Unknown sampling method {cfg.sampling.type}") - - - # Main generation definition - def generate_fn(image_lr, lead_time_label): - with nvtx.annotate("generate_fn", color="green"): - # (1, C, H, W) - image_lr = image_lr.to(memory_format=torch.channels_last) - - if net_reg: - with nvtx.annotate("regression_model", color="yellow"): - image_reg = regression_step( - net=net_reg, - img_lr=image_lr, - latents_shape=( - cfg.generation.seed_batch_size, - img_out_channels, - img_shape[0], - img_shape[1], - ), # (batch_size, C, H, W) - lead_time_label=lead_time_label, - ) - if net_res: - if cfg.generation.hr_mean_conditioning: - mean_hr = image_reg[0:1] - else: - mean_hr = None - with nvtx.annotate("diffusion model", color="purple"): - image_res = diffusion_step( - net=net_res, - sampler_fn=sampler_fn, - img_shape=img_shape, - img_out_channels=img_out_channels, - rank_batches=rank_batches, - img_lr=image_lr.expand( - cfg.generation.seed_batch_size, -1, -1, -1 - ), #.to(memory_format=torch.channels_last), - rank=dist.rank, - device=device, - mean_hr=mean_hr, - lead_time_label=lead_time_label, - ) - if cfg.generation.inference_mode == "regression": - image_out = image_reg - elif cfg.generation.inference_mode == "diffusion": - image_out = image_res - else: - image_out = image_reg[0:1,::] + image_res - - # Gather tensors on rank 0 - if dist.world_size > 1: - if dist.rank == 0: - gathered_tensors = [ - torch.zeros_like( - image_out, dtype=image_out.dtype, device=image_out.device - ) - for _ in range(dist.world_size) - ] - else: - gathered_tensors = None - torch.distributed.barrier() - gather( - image_out, - gather_list=gathered_tensors if dist.rank == 0 else None, - dst=0, - ) - - if dist.rank == 0: - if cfg.generation.inference_mode != "regression": - return torch.cat(gathered_tensors), image_reg[0:1,::] - return torch.cat(gathered_tensors), None - else: - return None, None - else: - #TODO do this for multi-gpu setting above too - if cfg.generation.inference_mode != "regression": - return image_out, image_reg - return image_out, None + # Parse the patch shape + if cfg.generation.patching: + patch_shape_x = cfg.generation.patch_shape_x + patch_shape_y = cfg.generation.patch_shape_y + else: + patch_shape_x, patch_shape_y = None, None + patch_shape = (patch_shape_y, patch_shape_x) + 
use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + generator.initialize_patching(img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) + sampler_params = cfg.sampler.params if "params" in cfg.sampler else {} + generator.initialize_sampler(cfg.sampler.type, **sampler_params) # generate images output_path = getattr(cfg.generation.io, "output_path", "./outputs") logger0.info(f"Generating images, saving results to {output_path}...") batch_size = 1 warmup_steps = min(len(times) - 1, 2) - # Generates model predictions from the input data using the specified - # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates - # through the dataset using a data loader, computes predictions, and saves them along - # with associated metadata. torch_cuda_profiler = ( torch.cuda.profiler.profile() @@ -338,11 +220,13 @@ def elapsed_time(self, _): ): time_index += 1 if dist.rank == 0: - logger0.info(f"starting index: {time_index}") + logger0.info(f"starting index: {time_index} time: {times[sampler[time_index]]}") if time_index == warmup_steps: start.record() + savedir = os.path.join(output_path,f"{times[sampler[time_index]]}") + os.makedirs(savedir,exist_ok=True) # continue if lead_time_label: lead_time_label = lead_time_label[0].to(dist.device).contiguous() @@ -354,14 +238,24 @@ def elapsed_time(self, _): .to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - image_out, image_reg = generate_fn(image_lr,lead_time_label) + # image_out, image_reg = generate_fn(image_lr,lead_time_label) + random_seed = cfg.generation.get("random_seed", None)+index if cfg.generation.get("randomize", False) and cfg.generation.get("random_seed", None) is not None else None + # print(f"On rank {dist.rank} using base random seed: {random_seed} for time index {time_index}") + image_out, image_reg = generator.generate( + image_lr, + lead_time_label, + randomize=cfg.generation.get("randomize", False), + random_seed=random_seed + ) + if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing + writer_threads.append( writer_executor.submit( - save_images, - output_path, + save_results_as_torch, + savedir, times[sampler[time_index]], dataset, image_out.cpu().numpy(), @@ -397,90 +291,5 @@ def elapsed_time(self, _): logger0.info("Generation Completed.") -def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): - - os.makedirs(output_path, exist_ok=True) - - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - - target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) - baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) - if mean_pred is not None: - mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - - - freqs = {} - power = {} - for idx, channel in enumerate(output_channels): - input_channel_idx = input_channels.index(channel) - - if channel.name=="tp": - target[idx,::] = prepare_precipitaiton(target[idx,:,:]) - prediction[idx,::] = 
prepare_precipitaiton(prediction[idx,:,:]) - baseline[input_channel_idx,:,:] = prepare_precipitaiton(baseline[input_channel_idx]) - if mean_pred is not None: - mean_pred[idx,::] = prepare_precipitaiton(mean_pred[idx,::]) - - _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg')) - _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg')) - _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg')) - if mean_pred is not None: - _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-mean_prediction.jpg')) - - _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) - _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) - if mean_pred is not None: - _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) - - - plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) - plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) - if mean_pred is not None: - plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) - - b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) - freqs['baseline'] = b_freq - power['baseline'] = b_power - #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) - t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) - freqs['target'] = t_freq - power['target'] = t_power - p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) - freqs['prediction'] = p_freq - power['prediction'] = p_power - if mean_pred is not None: - mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0) - freqs['mean_prediction'] = mp_freq - power['mean_prediction'] = mp_power - plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) - - -def prepare_precipitaiton(precip_array): - precip_array = np.clip(precip_array, 0, None) - epsilon = 1e-2 - precip_array = precip_array + epsilon - precip_array = np.log(precip_array) - # log_min, log_max = precip_array.min(), precip_array.max() - # precip_array = (precip_array-log_min)/(log_max-log_min) - return precip_array - - -def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): - - """Plot observed or interpolated data in a scatter plot.""" - # TODO: Refactor this somehow, it's not really generalizing well across variables. 
- fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - ax.coastlines() - ax.gridlines(draw_labels=True) - plt.colorbar(p, label="K", orientation="horizontal") - plt.savefig(filename) - plt.close('all') - if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/inference/generator.py b/src/hirad/inference/generator.py new file mode 100644 index 0000000..d3c95c5 --- /dev/null +++ b/src/hirad/inference/generator.py @@ -0,0 +1,147 @@ +from typing import Callable +from functools import partial +import nvtx +import numpy as np +import random +import torch +from torch.distributed import gather +from hirad.utils.inference_utils import regression_step, diffusion_step +from hirad.distributed import DistributedManager +from hirad.utils.patching import GridPatching2D +from hirad.inference import stochastic_sampler, deterministic_sampler + +class Generator(): + def __init__(self, + net_reg: torch.nn.Module, + net_res: torch.nn.Module, + batch_size: int, + ensemble_size: int, + hr_mean_conditioning: bool, + n_out_channels: int, + inference_mode: str, + dist: DistributedManager, + ): + + self.net_reg = net_reg + self.net_res = net_res + self.batch_size = batch_size + self.hr_mean_conditioning = hr_mean_conditioning + self.n_out_channels = n_out_channels + self.inference_mode = inference_mode + self.ensemble_size = ensemble_size + self.dist = dist + self.get_rank_batches() + self.patching = None + + def get_rank_batches(self, seeds=None): + if seeds is None: + seeds = list(np.arange(self.ensemble_size)) + num_batches = ( + (len(seeds) - 1) // (self.batch_size * self.dist.world_size) + 1 + ) * self.dist.world_size + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + self.rank_batches = all_batches[self.dist.rank :: self.dist.world_size] + + def initialize_sampler(self, sampler_type, **sampler_args): + if sampler_type == "deterministic": + if self.hr_mean_conditioning: + raise NotImplementedError( + "High-res mean conditioning is not yet implemented for the deterministic sampler" + ) + self.sampler = partial( + deterministic_sampler, + **sampler_args + ) + elif sampler_type == "stochastic": + self.sampler = partial(stochastic_sampler, patching=self.patching) + else: + raise ValueError(f"Unknown sampling method {sampler_type}") + + def initialize_patching(self, img_shape, patch_shape, boundary_pix, overlap_pix): + self.patching = GridPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=boundary_pix, + overlap_pix=overlap_pix, + ) + + def generate(self, image_lr, lead_time_label=None, randomize=False, random_seed=None): + with nvtx.annotate("generate_fn", color="green"): + # (1, C, H, W) + img_shape = image_lr.shape[-2:] + + if self.net_reg: + with nvtx.annotate("regression_model", color="yellow"): + image_reg = regression_step( + net=self.net_reg, + img_lr=image_lr, + latents_shape=( + self.batch_size, + self.n_out_channels, + img_shape[0], + img_shape[1], + ), # (batch_size, C, H, W) + lead_time_label=lead_time_label, + ) + if self.net_res: + if self.hr_mean_conditioning: + mean_hr = image_reg[0:1] + else: + mean_hr = None + if randomize: + # Set random seed for numpy + if random_seed is not None: + np.random.seed((random_seed) % (1 << 31)) + seeds = np.random.randint(0, 1<<31, size=self.ensemble_size) + self.get_rank_batches(seeds=seeds) + with nvtx.annotate("diffusion model", color="purple"): + image_res = 
diffusion_step( + net=self.net_res, + sampler_fn=self.sampler, + img_shape=img_shape, + img_out_channels=self.n_out_channels, + rank_batches=self.rank_batches, + img_lr=image_lr.expand( + self.batch_size, -1, -1, -1 + ).to(memory_format=torch.channels_last), #.to(memory_format=torch.channels_last), + rank=self.dist.rank, + device=image_lr.device, + mean_hr=mean_hr, + lead_time_label=lead_time_label, + ) + if self.inference_mode == "regression": + image_out = image_reg[0:1,::] + elif self.inference_mode == "diffusion": + image_out = image_res + else: + image_out = image_reg[0:1,::] + image_res + + # Gather tensors on rank 0 + if self.dist.world_size > 1: + if self.dist.rank == 0: + gathered_tensors = [ + torch.zeros_like( + image_out, dtype=image_out.dtype, device=image_out.device + ) + for _ in range(self.dist.world_size) + ] + else: + gathered_tensors = None + + torch.distributed.barrier() + gather( + image_out, + gather_list=gathered_tensors if self.dist.rank == 0 else None, + dst=0, + ) + + if self.dist.rank == 0: + if self.inference_mode != "regression": + return torch.cat(gathered_tensors), image_reg[0:1,::] + return torch.cat(gathered_tensors)[0:1,::], None + else: + return None, None + else: + if self.inference_mode != "regression": + return image_out, image_reg[0:1,::] + return image_out, None diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/inference/stochastic_sampler.py similarity index 94% rename from src/hirad/utils/stochastic_sampler.py rename to src/hirad/inference/stochastic_sampler.py index 198fde4..24c5f7a 100644 --- a/src/hirad/utils/stochastic_sampler.py +++ b/src/hirad/inference/stochastic_sampler.py @@ -130,8 +130,8 @@ def stochastic_sampler( # Adjust noise levels based on what's supported by the network. # Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution. - sigma_min = max(sigma_min, net.sigma_min) - sigma_max = min(sigma_max, net.sigma_max) + sigma_min = max(sigma_min, net.module.sigma_min if hasattr(net, "module") else net.sigma_min) + sigma_max = min(sigma_max, net.module.sigma_max if hasattr(net, "module") else net.sigma_max) if patching is not None and not isinstance(patching, GridPatching2D): raise ValueError("patching must be an instance of GridPatching2D.") @@ -162,7 +162,8 @@ def stochastic_sampler( * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) ) ** rho t_steps = torch.cat( - [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + [net.module.round_sigma(t_steps) if hasattr(net, "module") else net.round_sigma(t_steps), + torch.zeros_like(t_steps[:1])] ) # t_N = 0 batch_size = img_lr.shape[0] @@ -179,10 +180,11 @@ def stochastic_sampler( # input and position padding + patching if patching: + # print(f"Input for generator beofre patching {x_lr.shape}") # Patched conditioning [x_lr, mean_hr] # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) x_lr = patching.apply(input=x_lr, additional_input=img_lr) - + # print(f"Input for generator after patching {x_lr.shape}") # Function to select the correct positional embedding for each patch def patch_embedding_selector(emb): # emb: (N_pe, image_shape_y, image_shape_x) @@ -198,7 +200,7 @@ def patch_embedding_selector(emb): x_cur = x_next # Increase noise temporarily. 
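The seed handling in the new `Generator` class in `src/hirad/inference/generator.py` (`get_rank_batches` together with the `randomize`/`random_seed` path used from `generate.py`) decides which ensemble members each rank computes. The same logic, restated as a self-contained sketch with small illustrative numbers:

```python
import numpy as np
import torch


def rank_batches(ensemble_size, batch_size, world_size, rank, randomize=False, base_seed=None):
    """Split ensemble-member seeds into per-rank batches, as Generator.get_rank_batches does."""
    if randomize and base_seed is not None:
        # generate.py derives base_seed as the configured random_seed plus the time index,
        # then one seed per ensemble member is drawn from it.
        np.random.seed(base_seed % (1 << 31))
        seeds = np.random.randint(0, 1 << 31, size=ensemble_size)
    else:
        seeds = list(np.arange(ensemble_size))
    # Round the number of batches up to a multiple of world_size so every rank
    # receives the same number of (possibly empty) batches.
    num_batches = ((len(seeds) - 1) // (batch_size * world_size) + 1) * world_size
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    return all_batches[rank::world_size]


if __name__ == "__main__":
    # 8 ensemble members, seed_batch_size=2, 2 ranks -> 2 batches of 2 seeds per rank.
    for r in range(2):
        print(r, [b.tolist() for b in rank_batches(8, 2, 2, r)])
```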
gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 - t_hat = net.round_sigma(t_cur + gamma * t_cur) + t_hat = net.module.round_sigma(t_cur + gamma * t_cur) if hasattr(net, "module") else net.round_sigma(t_cur + gamma * t_cur) x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) diff --git a/src/hirad/input_data/__init__.py b/src/hirad/input_data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/input_data/copernicus-tp.sh b/src/hirad/input_data/copernicus-tp.sh new file mode 100644 index 0000000..943812f --- /dev/null +++ b/src/hirad/input_data/copernicus-tp.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +pip install -e . +pip install anemoi.datasets +python src/hirad/input_data/read_tp.py diff --git a/src/hirad/input_data/copyanemoi.sh b/src/hirad/input_data/copyanemoi.sh new file mode 100644 index 0000000..361a9fb --- /dev/null +++ b/src/hirad/input_data/copyanemoi.sh @@ -0,0 +1,11 @@ +#!/bin/bash -l +# +#SBATCH --time=23:59:00 +#SBATCH --ntasks=1 +#SBATCH --partition=xfer + +echo -e "$SLURM_JOB_NAME started on $(date):\n $command $1 $2" +cp -rvn $1 $2 + +echo -e "$SLURM_JOB_NAME finished on $(date)\n" + diff --git a/src/hirad/input_data/cosmo-static.yaml b/src/hirad/input_data/cosmo-static.yaml new file mode 100644 index 0000000..9bece97 --- /dev/null +++ b/src/hirad/input_data/cosmo-static.yaml @@ -0,0 +1,8 @@ +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr' +select: ['hsurf', 'lsm',] +# Static cosmo channels +trim_edge: 19 # Removes boundary +start: 2016-01-01 +# start: 2015-11-29 +end: 2016-01-01 +# end: 2020-12-31 \ No newline at end of file diff --git a/src/hirad/input_data/download_copernicus_tp.py b/src/hirad/input_data/download_copernicus_tp.py index 55031ee..6eee50c 100644 --- a/src/hirad/input_data/download_copernicus_tp.py +++ b/src/hirad/input_data/download_copernicus_tp.py @@ -5,10 +5,14 @@ "product_type": ["reanalysis"], "variable": ["total_precipitation"], "year": [ - "2016" + "2015", "2016", "2017", + "2018", "2019", "2020", ], "month": [ - "01", "02" + "01", "02", "03", + "04", "05", "06", + "07", "08", "09", + "10", "11", "12", ], "day": [ "01", "02", "03", @@ -34,7 +38,9 @@ "21:00", "22:00", "23:00" ], "data_format": "netcdf", - "download_format": "unarchived" + "download_format": "unarchived", + "grid": "N320", + "area": [60, 0, 40, 20] } client = cdsapi.Client() diff --git a/src/hirad/input_data/era.yaml b/src/hirad/input_data/era.yaml index 3234321..9018031 100644 --- a/src/hirad/input_data/era.yaml +++ b/src/hirad/input_data/era.yaml @@ -1,4 +1,7 @@ -dataset: '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' +#dataset: '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' +dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr' select: ['2t', '10u', '10v', 'tcw', 't_850', 'z_850', 'u_850', 'v_850', 't_500', 'z_500', 'u_500', 'v_500', 'tp'] # See table S2 from corrdiff paper for the inputs. -# Note: Bounding dates/area will be done in .py code. \ No newline at end of file +# Note: Bounding dates/area will be done in .py code. 
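The `stochastic_sampler` edits above repeat the pattern `net.module.x if hasattr(net, "module") else net.x` for `sigma_min`, `sigma_max` and `round_sigma`, which is the usual way to reach attributes of a network that may be wrapped (for example by `torch.nn.parallel.DistributedDataParallel`). A tiny helper would express that once; `unwrap` is a suggested name and is not defined anywhere in this patch:

```python
import torch


def unwrap(net: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying module if `net` is wrapped (DDP exposes it as `.module`)."""
    return net.module if hasattr(net, "module") else net


# Usage sketch mirroring the sampler code:
# sigma_min = max(sigma_min, unwrap(net).sigma_min)
# sigma_max = min(sigma_max, unwrap(net).sigma_max)
# t_steps = torch.cat([unwrap(net).round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
```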
+start: 2020-01-01 +end: 2020-01-31 \ No newline at end of file diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py index 7ec5801..5a77429 100644 --- a/src/hirad/input_data/interpolate_basic.py +++ b/src/hirad/input_data/interpolate_basic.py @@ -14,39 +14,45 @@ from scipy.interpolate import griddata import torch import multiprocessing +import xarray +from earthkit.geo.rotate import unrotate # Margin to use for ERA dataset (to avoid nans from interpolation at boundary) ERA_MARGIN_DEGREES = 1.0 -def _read_input(era_config_file: str, cosmo_config_file: str, bound_to_cosmo_area=True) -> tuple[Dataset, Dataset]: +def _read_era5_cosmo(era_config_file: str, cosmo_config_file: str) -> tuple[Dataset, Dataset]: """ Read both ERA and COSMO data, optionally bounding to the COSMO data area, and return the 2m temperature values for the time range under COSMO. """ # trim edge removes boundary + cosmo = read_cosmo_anemoi(cosmo_config_file) + # area = N, W, S, E + min_lat = min(cosmo.latitudes) - ERA_MARGIN_DEGREES + max_lat = max(cosmo.latitudes) + ERA_MARGIN_DEGREES + min_lon = min(cosmo.longitudes) - ERA_MARGIN_DEGREES + max_lon = max(cosmo.longitudes) + ERA_MARGIN_DEGREES + start_date = cosmo.metadata()['start_date'] + end_date = cosmo.metadata()['end_date'] + era = read_era5_anemoi(era_config_file, + start_date = start_date, end_date = end_date, + area=(max_lat, min_lon, min_lat, max_lon)) + return (era, cosmo) + +def read_cosmo_anemoi(cosmo_config_file: str): with open(cosmo_config_file) as cosmo_file: cosmo_config = yaml.safe_load(cosmo_file) cosmo = open_dataset(cosmo_config) + return cosmo + +def read_era5_anemoi(era_config_file: str, start_date = None, + end_date = None, area=None): with open(era_config_file) as era_file: era_config = yaml.safe_load(era_file) era = open_dataset(era_config) - # Subset the ERA dataset to have COSMO area/dates. 
- start_date = cosmo.metadata()['start_date'] - end_date = cosmo.metadata()['end_date'] - # load era5 2m-temperature in the time-range of cosmo - # area = N, W, S, E - if bound_to_cosmo_area: - min_lat = min(cosmo.latitudes) - ERA_MARGIN_DEGREES - max_lat = max(cosmo.latitudes) + ERA_MARGIN_DEGREES - min_lon = min(cosmo.longitudes) - ERA_MARGIN_DEGREES - max_lon = max(cosmo.longitudes) + ERA_MARGIN_DEGREES - era = open_dataset(era, start=start_date, end=end_date, - area=(max_lat, min_lon, min_lat, max_lon)) - else: - era = open_dataset(era, start=start_date, end=end_date) - - return (era, cosmo) - + era = open_dataset(era, start=start_date, end=end_date, + area=area) + return era def regrid(era_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.ndarray): # shape (channel, ensemble, grid) @@ -57,7 +63,7 @@ def regrid(era_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.nda interpolated_data[j,0,:] = regrid return interpolated_data -def _interpolate_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarray, output_grid: np.ndarray, intermediate_files_path: str, outfile_plots_path: str = None, plot_indices=[0]): +def _interpolate_era5_cosmo_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarray, output_grid: np.ndarray, intermediate_files_path: str, outfile_plots_path: str = None, plot_indices=[0]): logging.info('interpolating time point ' + _format_date(cosmo.dates[i])) interpolated_data = np.empty([era.shape[1], 1, cosmo.shape[3]]) for j in range(era.shape[1]): @@ -76,15 +82,15 @@ def _interpolate_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarr logging.info(f'plotting {datestr} to {outfile_plots_path}') for j,var in enumerate(era.variables): # plot era original - _plot_and_save_projection(era.longitudes, era.latitudes, era[i, j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era.jpg') + plot_and_save_projection(era.longitudes, era.latitudes, era[i, j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era.jpg') - _plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, interpolated_data[j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era-interpolated.jpg') + plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, interpolated_data[j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era-interpolated.jpg') for j,var in enumerate(cosmo.variables): - _plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, cosmo[i, j, 0, :], f'{outfile_plots_path}{cosmo.variables[j]}-{datestr}-cosmo.jpg') + plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, cosmo[i, j, 0, :], f'{outfile_plots_path}{cosmo.variables[j]}-{datestr}-cosmo.jpg') -def _interpolate_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: str, threaded = True, outfile_plots_path: str =None, plot_indices=[0]): +def _interpolate_era5_cosmo_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: str, threaded = True, outfile_plots_path: str =None, plot_indices=[0]): """Perform simple interpolation from ERA5 to COSMO grid for all data points in the COSMO date range. 
Parameters: @@ -112,13 +118,13 @@ def _interpolate_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: st if (threaded): pool = multiprocessing.Pool() for i in dates: - pool.apply_async(_interpolate_task, (i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices)) + pool.apply_async(_interpolate_era5_cosmo_task, (i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices)) pool.close() pool.join() else: for i in dates: - _interpolate_task(i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices) + _interpolate_era5_cosmo_task(i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices) return @@ -130,54 +136,30 @@ def _save_datetime_file(values: np.ndarray[np.intp], variables: np.ndarray, date filename = filepath + _format_date(date) torch.save(values, filename) -def _save_latlon_grid(dataset: Dataset, filename: str): +def save_anemoi_latlon_grid(dataset: Dataset, filename: str): grid = np.column_stack((dataset.latitudes, dataset.longitudes)) torch.save(grid, filename) -def _save_stats(dataset: Dataset, filename: str): +def save_anemoi_stats(dataset: Dataset, filename: str): torch.save(dataset.statistics, filename) -def _save_interpolation(values: np.ndarray[np.intp], filename: str): - """Output interpolated data to a given filename, in PyTorch tensor format.""" - torch_data = torch.from_numpy(values) - torch.save(torch_data, filename) - -def _get_plot_indices(era: Dataset, cosmo: Dataset) -> np.ndarray[np.intp]: - """ - Get indices of ERA5 data that is in the bounding rectangle of COSMO data. - This is useful for plotting in the case where read_input(..., bound_to_cosmo_area=False) was used. - In this case, one would then feed e.g. era.latitudes[indices] into _plot_projection. - """ - min_lat_cosmo = min(cosmo.latitudes) - max_lat_cosmo = max(cosmo.latitudes) - min_lon_cosmo = min(cosmo.longitudes) - max_lon_cosmo = max(cosmo.longitudes) - box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) - box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) - indices = np.where(box_lon*box_lat) - return indices - -def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None): - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) +def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None, s = None): + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax, s=s) ax.coastlines() - ax.gridlines(draw_labels=False) + ax.gridlines(draw_labels=True) plt.colorbar(p, orientation="horizontal") -def _plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): +def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, projection=ccrs.PlateCarree(), cmap=None, vmin = None, vmax = None, s = None): """Plot observed or interpolated data in a scatter plot.""" # TODO: Refactor this somehow, it's not really generalizing well across variables. 
fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + fig, ax = plt.subplots(subplot_kw={"projection": projection}) logging.info(f'plotting values to {filename}') - plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax) - #p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - #ax.coastlines() - #ax.gridlines(draw_labels=True) - #plt.colorbar(p, orientation="horizontal") + plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax, s) plt.savefig(filename) plt.close('all') -def interpolate_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: str, threaded=True, outfile_plots_path: str = None, plot_indices=[0]): +def interpolate_era5_cosmo_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: str, threaded=True, outfile_plots_path: str = None, plot_indices=[0]): """Read both ERA and COSMO data and perform basic interpolation. Save output into Pytorch format, and (optionally) plot ERA, COSMO, and interpolated data. @@ -210,30 +192,19 @@ def interpolate_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: logging.info('Successfully read input') # Output stats and grid - _save_stats(era, os.path.join(outfile_data_path, "info/era-stats")) - _save_stats(cosmo, os.path.join(outfile_data_path, "info/cosmo-stats")) - _save_latlon_grid(cosmo, os.path.join(outfile_data_path, "info/cosmo-lat-lon")) - _save_latlon_grid(era, os.path.join(outfile_data_path, "info/era-lat-lon")) + save_anemoi_stats(era, os.path.join(outfile_data_path, "info/era-stats")) + save_anemoi_stats(cosmo, os.path.join(outfile_data_path, "info/cosmo-stats")) + save_anemoi_latlon_grid(cosmo, os.path.join(outfile_data_path, "info/cosmo-lat-lon")) + save_anemoi_latlon_grid(era, os.path.join(outfile_data_path, "info/era-lat-lon")) # Copy the .yaml files over for recording purposes shutil.copy(infile_cosmo, os.path.join(outfile_data_path, "info/cosmo.yaml")) shutil.copy(infile_era, os.path.join(outfile_data_path, "info/era.yaml")) # generate interpolated data - _interpolate_basic(era, cosmo, outfile_data_path, threaded=threaded, outfile_plots_path=outfile_plots_path, plot_indices=plot_indices) - -def plot_tp(path_6h: str, path_1h: str): - fig, axs = plt.subplots(2, 3, subplot_kw={"projection": ccrs.PlateCarree()}) + _interpolate_era5_cosmo_basic(era, cosmo, outfile_data_path, threaded=threaded, outfile_plots_path=outfile_plots_path, plot_indices=plot_indices) - logging.info(f'plotting values to {filename}') - p = ax.scatter(x=longitudes, y=latitudes, c=values) - ax.coastlines() - ax.gridlines(draw_labels=True) - plt.colorbar(p, label="absolute error", orientation="horizontal") - plt.savefig(filename) - plt.close('all') - def main(): # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs. 
if len(sys.argv) < 4: @@ -242,13 +213,13 @@ def main(): infile_cosmo = sys.argv[2] output_directory = sys.argv[3] - logging.basicConfig( filename=os.path.join(output_directory, 'interpolate_basic.log'), format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') - interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=os.path.join(output_directory, "plots/")) + + interpolate_era5_cosmo_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=None) if __name__ == "__main__": main() diff --git a/src/hirad/input_data/interpolate_realch1.py b/src/hirad/input_data/interpolate_realch1.py new file mode 100644 index 0000000..e1a17bc --- /dev/null +++ b/src/hirad/input_data/interpolate_realch1.py @@ -0,0 +1,153 @@ +import hirad.input_data.interpolate_basic as interpolate_basic +import hirad.input_data.regrid_copernicus_tp as regrid_copernicus_tp + +import datetime +import logging +import os +import shutil +import sys +import yaml +import array + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.dataset import Dataset +import netCDF4 +import numpy as np +from pandas import to_datetime +from scipy.interpolate import griddata +from meteodatalab.operators import regrid +import torch +import multiprocessing +import xarray + +# Margin to use for ERA dataset (to avoid nans from interpolation at boundary) +ERA_MARGIN_DEGREES = 1.0 +COPERNICUS_FILES = ['/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2015-2016.nc', + '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2017-2018.nc', + '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2019-2020.nc'] + +def _read_input(era_config_file: str, realch1_latlon_file: str) -> tuple[Dataset, Dataset, array.array, np.ndarray]: + """ + Read ERA data, and return the values for the area under REA-L-CH1 (plus a margin). + """ + # read the lat/lon data for REA-L-CH1 + realch1_latlon = torch.load(realch1_latlon_file) + # we expect the start and end dates to be specified in config. + with open(era_config_file) as era_file: + era_config = yaml.safe_load(era_file) + era = open_dataset(era_config) + # Subset the ERA dataset to have REAL-CH-1 area. 
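The renamed helpers in `interpolate_basic.py` all reduce to `scipy.interpolate.griddata` over point lists built with `np.column_stack((lon, lat))`. A self-contained sketch of that regridding step, using small synthetic grids in place of the real ERA5 and COSMO point sets:

```python
import numpy as np
from scipy.interpolate import griddata

# Synthetic stand-ins for the ERA5 (coarse) and COSMO (fine) point sets.
src_lon, src_lat = np.meshgrid(np.linspace(0.0, 4.0, 9), np.linspace(44.0, 48.0, 9))
dst_lon, dst_lat = np.meshgrid(np.linspace(0.5, 3.5, 31), np.linspace(44.5, 47.5, 31))

# Grids are stored as (n_points, 2) arrays of (lon, lat), as in regrid().
input_grid = np.column_stack((src_lon.ravel(), src_lat.ravel()))
output_grid = np.column_stack((dst_lon.ravel(), dst_lat.ravel()))

# One channel of data on the coarse grid (e.g. 2 m temperature in K).
values = 270.0 + src_lat.ravel() - 44.0

# Linear barycentric interpolation onto the fine grid; points outside the convex
# hull of the source grid come back as NaN, which is why the code keeps an
# ERA_MARGIN_DEGREES buffer around the COSMO domain.
interpolated = griddata(input_grid, values, output_grid, method="linear")
print(interpolated.shape, np.isnan(interpolated).sum())
```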
+ # area = N, W, S, E + min_lat = min(realch1_latlon[:,0]) - ERA_MARGIN_DEGREES + max_lat = max(realch1_latlon[:,0]) + ERA_MARGIN_DEGREES + min_lon = min(realch1_latlon[:,1]) - ERA_MARGIN_DEGREES + max_lon = max(realch1_latlon[:,1]) + ERA_MARGIN_DEGREES + era = open_dataset(era, + area=(max_lat, min_lon, min_lat, max_lon)) + copernicus_netcdf = [] + for f in COPERNICUS_FILES: + netcdf_data = netCDF4.Dataset(f) + copernicus_netcdf.append(netcdf_data) + return (era, copernicus_netcdf, realch1_latlon) + + +def main(): + # read REA-L-CH1 latlon grid + era_config_file = sys.argv[1] + realch1_latlon_file = sys.argv[2] + netcdf_file = sys.argv[3] + output_directory = sys.argv[4] + + logging.basicConfig( + filename=os.path.join(output_directory, 'interpolate_realch1.log'), + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + + # Copy ERA yml file + shutil.copy(era_config_file, os.path.join(output_directory, 'info')) + + logging.info('reading realch1 lat/lon') + realch1_latlon = torch.load(realch1_latlon_file, weights_only=False) + realch1_lat = realch1_latlon[:,0] + realch1_lon = realch1_latlon[:,1] + # read ERA input + min_lat = min(realch1_lat) - interpolate_basic.ERA_MARGIN_DEGREES + max_lat = max(realch1_lat) + interpolate_basic.ERA_MARGIN_DEGREES + min_lon = min(realch1_lon) - interpolate_basic.ERA_MARGIN_DEGREES + max_lon = max(realch1_lon) + interpolate_basic.ERA_MARGIN_DEGREES + logging.info('reading era') + + era = interpolate_basic.read_era5_anemoi(era_config_file, + area=(max_lat, min_lon, min_lat, max_lon)) + era_grid = np.column_stack((era.longitudes, era.latitudes)) + realch1_grid = np.column_stack((realch1_lon, realch1_lat)) + logging.info(f'lat lon area is {min_lat}-{max_lat} {min_lon}-{max_lon}') + + # save era stats and lat lon + interpolate_basic.save_anemoi_latlon_grid(era, os.path.join(output_directory, 'info', 'era-lat-lon')) + interpolate_basic.save_anemoi_stats(era, os.path.join(output_directory, 'info', 'era-stats')) + + # read copernicus input for tp variable + logging.info('reading copernicus') + netcdf_data = netCDF4.Dataset(netcdf_file) + logging.info('processing netcdf data') + netcdf_latitudes, netcdf_longitudes = regrid_copernicus_tp.extract_lat_lon_025(netcdf_data) + netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + + # TODO: start and end date functionality + netcdf_tp_values = regrid_copernicus_tp.extract_values(netcdf_data, 'tp', start_date=era.start_date, end_date=era.end_date) + assert(netcdf_tp_values.shape[0] == era.shape[0]) + # todo: incorporate this somehow + netcdf_tp_values = netcdf_tp_values.reshape((netcdf_tp_values.shape[0], 1,1, netcdf_tp_values.shape[1])) + + # save copernicus stats and lat lon + torch.save(np.column_stack((netcdf_grid[:,1], netcdf_grid[:,0])), + os.path.join(output_directory, 'info', 'copernicus-lat-lon')) + regrid_copernicus_tp.make_stats(os.path.join(output_directory, 'info'), + os.path.join(output_directory, 'info'), + netcdf_tp_values) + + # Iterate over ERA time range, which should be subsetted in configuration. 
+ tp_index = era.variables.index('tp') + logging.info(f'tp index {tp_index}') + + plot_indices = {12} + + logging.info('interpolating') + #for i in plot_indices: + for i in range(era.shape[0]): + # T + t = era.dates[i] + # Get everything but the tp variable + era_for_time = era[i,:,:,:] + era_regridded = interpolate_basic.regrid(era_for_time, era_grid, realch1_grid) + # Regrid TP from copernicus + copernicus_regridded = interpolate_basic.regrid(netcdf_tp_values[i,:], netcdf_grid, realch1_grid) + # Concatenate and save + era_regridded[tp_index,:] = copernicus_regridded + #output=np.concatenate((era_regridded, copernicus_regridded), axis=0) + datefmt = interpolate_basic._format_date(t) + filename = os.path.join(output_directory, 'era-copernicus-interpolated', + datefmt) + torch.save(era_regridded, filename) + + if i in plot_indices: + realch1var = ['t2m', '10u', '10v', 'tp'] + realch1_data = torch.load(os.path.join(output_directory, 'realch1', datefmt), weights_only=False) + for j in range(realch1_data.shape[0]): + interpolate_basic.plot_and_save_projection(realch1_lon, realch1_lat, realch1_data[j,:], + os.path.join(output_directory, 'plots', + f'{datefmt}-{realch1var[j]}-realch1')) + for j in range(era_regridded.shape[0]): + interpolate_basic.plot_and_save_projection(era.longitudes, era.latitudes, era_for_time[j,:], + os.path.join(output_directory, 'plots', + f'{datefmt}-{era.variables[j]}-era')) + interpolate_basic.plot_and_save_projection(realch1_lon, realch1_lat, + era_regridded[j,0,:], + os.path.join(output_directory, 'plots', + f'{datefmt}-{era.variables[j]}-interpolated')) + return 0 + +if __name__ == "__main__": + main() diff --git a/src/hirad/input_data/read_tp.py b/src/hirad/input_data/read_tp.py deleted file mode 100644 index a13fed3..0000000 --- a/src/hirad/input_data/read_tp.py +++ /dev/null @@ -1,183 +0,0 @@ -import logging -import netCDF4 -from anemoi.datasets import open_dataset -import numpy as np -import yaml - -import matplotlib.pyplot as plt -import cartopy.crs as ccrs -import cartopy.feature as cfeature -from matplotlib.colors import BoundaryNorm, ListedColormap - -import interpolate_basic - -import sys -from pathlib import Path - -import os -print (os.getcwd()) - -sys.path.insert(0, Path(__file__).parent.as_posix()) - -ANEMOI_1H_FILENAME = "/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr" -ANEMOI_6H_FILENAME = "/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr" -COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr" -COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr" -COSMO_CONFIG_FILE="src/input_data/cosmo.yaml" -CDF_FILENAME = "8e49f064d738154bed136666ff72ae1c.nc" - - -LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) -LON = np.arange(-6.82, 4.80 + 0.02, 0.02) -RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) - - -def extract_values(netcdf_data): - netcdf_lat = netcdf_data['latitude'][:] - netcdf_lon = netcdf_data['longitude'][:] - netcdf_tp = netcdf_data['tp'][:,:] - values = np.zeros((netcdf_tp.shape[0], netcdf_tp.shape[1]*netcdf_tp.shape[2])) - latitudes = np.zeros(values.shape[1]) - longitudes = np.zeros(values.shape[1]) - # You could probably get this by reshaping, but I can't be bothered. 
- for i in range(len(netcdf_lat)): - if i % 10 == 0: - print(i) - for j in range(len(netcdf_lon)): - grid_index = i * len(netcdf_lon) + j - values[:,grid_index] = netcdf_tp[:,i,j] - latitudes[grid_index] = netcdf_lat[i] - longitudes[grid_index] = netcdf_lon[j] - return values, latitudes, longitudes - -def plot_map(values: np.array, - filename: str, - label='', - title='', - vmin=None, - vmax=None, - cmap=None, - extend='neither', - norm=None, - ticks=None): - """Plot observed or interpolated data in a scatter plot.""" - logging.info(f'Creating map: {filename}') - - latitudes = LAT[RELAX_ZONE : RELAX_ZONE + 352] - longitudes = LON[RELAX_ZONE : RELAX_ZONE + 544] - lon2d, lat2d = np.meshgrid(longitudes, latitudes) - - fig, ax = plt.subplots( - figsize=(8, 6), - subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0, - pole_latitude= 43.0)} - ) - values = values.reshape((len(latitudes), len(longitudes))) - contour = ax.pcolormesh( - lon2d, lat2d, values, - cmap=cmap, shading="auto", - norm=norm if norm else None, - vmin=None if norm else vmin, - vmax=None if norm else vmax, - ) - ax.coastlines() - ax.add_feature(cfeature.BORDERS, linewidth=1) - ax.gridlines(visible=False) - ax.set_xticks([]) - ax.set_yticks([]) - - plt.title(title) - cbar = plt.colorbar( - contour, - label=label, - orientation="horizontal", - extend=extend, - shrink=0.75, - pad=0.02 - ) - if ticks is not None: - cbar.set_ticks(ticks) - cbar.set_ticklabels([f'{tick:g}' for tick in ticks]) - - plt.tight_layout() - fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") - plt.close(fig) - -def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=1000.0): - """Plot precipitation data with specific colormap and thresholds.""" - # Scale and mask values below threshold - values = rfac * values # m/h --> mm/h - values = np.ma.masked_where(values <= threshold, values) - - # Predefined colors and bounds specific for precipitation - colors = ['none', 'powderblue', 'dodgerblue', 'mediumblue', - 'forestgreen', 'limegreen', 'lawngreen', - 'yellow', 'gold', 'darkorange', 'red', - 'darkviolet', 'violet', 'thistle'] - bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 30, 50, 70, 100, 150, 200] - bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000] - - cmap = ListedColormap(colors) - norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) - - plot_map( - values, filename, - cmap=cmap, - norm=norm, - ticks=bounds, - title=title, - label='mm/h', - extend='max' - ) - - -print(interpolate_basic.regrid) - -file_id = netCDF4.Dataset(CDF_FILENAME) -#anemoi1 = open_dataset(ANEMOI_1H_FILENAME) -#anemoi6 = open_dataset(ANEMOI_6H_FILENAME) -#with open(COSMO_CONFIG_FILE) as cosmo_file: -# cosmo_config = yaml.safe_load(cosmo_file) -#cosmo = open_dataset(cosmo_config) -cosmo1 = open_dataset(COSMO_1H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') -cosmo6 = open_dataset(COSMO_6H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') - - -output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes)) -print(output_grid.shape) -print(cosmo1[0,0,0,:].shape) - -plot_map_precipitation(values=cosmo1[0,:], filename="cosmo1.png") -plot_map_precipitation(values=cosmo6[0,:], filename="cosmo6.png") - -#fig = plt.figure() -#fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -#interpolate_basic.plot_projection(ax, longitudes=cosmo1.longitudes, latitudes=cosmo1.latitudes, values=cosmo1[0,:]) -#fig.savefig('cosmo1.png') - -#fig = plt.figure() -#fig, ax = 
plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -#interpolate_basic.plot_projection(ax, longitudes=cosmo1.longitudes, latitudes=cosmo1.latitudes, values=cosmo6[0,:]) -#fig.savefig('cosmo6.png') - - -values, latitudes, longitudes = extract_values(netcdf_data=file_id) -input_grid=np.column_stack((longitudes, latitudes)) -vals = values[0,:].reshape((1,1,values.shape[1])) -regrid=interpolate_basic.regrid(vals, input_grid, output_grid) -plot_map_precipitation(regrid, 'netcdf.png') - - -era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -era_grid = np.column_stack((era1.longitudes, era1.latitudes)) -era1_regrid = interpolate_basic.regrid(era1[0,:], era_grid, output_grid) -plot_map_precipitation(era1_regrid, "era1.png") - - - -era6_regrid = interpolate_basic.regrid(era6[0,:], era_grid, output_grid) -plot_map_precipitation(era1_regrid, "era6.png") - - - diff --git a/src/hirad/input_data/realch1-all.yaml b/src/hirad/input_data/realch1-all.yaml new file mode 100644 index 0000000..6b0f776 --- /dev/null +++ b/src/hirad/input_data/realch1-all.yaml @@ -0,0 +1,26 @@ +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr' +select: ['CLCH', 'CLCL', 'CLCM', 'CLCT', + 'FI_100', 'FI_1000', 'FI_150', 'FI_200', 'FI_250', 'FI_300', 'FI_400', + 'FI_50', 'FI_500', 'FI_600', 'FI_700', 'FI_850', 'FI_925', + 'FR_LAND', 'HSURF', + 'OMEGA_100', 'OMEGA_1000', 'OMEGA_150', 'OMEGA_200', 'OMEGA_250', + 'OMEGA_300', 'OMEGA_400', 'OMEGA_50', 'OMEGA_500', 'OMEGA_600', + 'OMEGA_700', 'OMEGA_850', 'OMEGA_925', + 'PLCOV', 'PMSL', 'PS', + 'QV_100', 'QV_1000', 'QV_150', 'QV_200', 'QV_250', 'QV_300', 'QV_400', + 'QV_50', 'QV_500', 'QV_600', 'QV_700', 'QV_850', 'QV_925', + 'SKC', 'SKT', 'SOILTYP', 'SSO_GAMMA', 'SSO_SIGMA', 'SSO_STDH', + 'SSO_THETA', 'TD_2M', 'TOT_PREC', 'TOT_PREC_6H', + 'T_100', 'T_1000', 'T_150', 'T_200', 'T_250', 'T_2M', 'T_300', 'T_400', + 'T_50', 'T_500', 'T_600', 'T_700', 'T_850', 'T_925', + 'U_100', 'U_1000', 'U_10M', 'U_150', 'U_200', 'U_250', 'U_300', 'U_400', + 'U_50', 'U_500', 'U_600', 'U_700', 'U_850', 'U_925', + 'V_100', 'V_1000', 'V_10M', 'V_150', 'V_200', 'V_250', 'V_300', 'V_400', + 'V_50', 'V_500', 'V_600', 'V_700', 'V_850', 'V_925', + 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', + 'insolation', 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude'] +# ALL REALCH1 CHANNELS. 
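The deleted `read_tp.py` built its flat latitude/longitude arrays with an explicit double loop (its own comment notes "You could probably get this by reshaping"), and `extract_lat_lon_025` in `regrid_copernicus_tp.py` below keeps the same loop. The reshaping variant, which reproduces the identical `grid_index = i * len(lon) + j` ordering, would look like this (illustrative only, not part of the patch):

```python
import numpy as np


def extract_lat_lon_flat(lat: np.ndarray, lon: np.ndarray):
    """Flatten a regular lat/lon grid into per-point arrays, latitude varying slowest."""
    lon2d, lat2d = np.meshgrid(lon, lat)   # both have shape (len(lat), len(lon))
    return lat2d.ravel(), lon2d.ravel()    # C-order ravel == i * len(lon) + j


# Quick check against the loop-based ordering used in extract_lat_lon_025:
lat = np.arange(40.0, 60.25, 0.25)
lon = np.arange(0.0, 20.25, 0.25)
flat_lat, flat_lon = extract_lat_lon_flat(lat, lon)
i, j = 3, 7
assert flat_lat[i * len(lon) + j] == lat[i]
assert flat_lon[i * len(lon) + j] == lon[j]
```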
+start: 2020-01-01 +# start: 2015-11-29 +end: 2020-01-01 +# end: 2020-12-31 diff --git a/src/hirad/input_data/realch1.yaml b/src/hirad/input_data/realch1.yaml new file mode 100644 index 0000000..c536704 --- /dev/null +++ b/src/hirad/input_data/realch1.yaml @@ -0,0 +1,4 @@ +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr' +select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1 +#start: 2020-01-01 +#end: 2020-01-01 diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py new file mode 100644 index 0000000..adcea16 --- /dev/null +++ b/src/hirad/input_data/regrid_copernicus_tp.py @@ -0,0 +1,274 @@ +import logging +import netCDF4 +import xarray +import numpy as np +import torch +import datetime +from scipy.interpolate import griddata + +from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t +from hirad.eval.metrics import absolute_error + +import interpolate_basic + +import sys +from pathlib import Path + +import os +print (os.getcwd()) + +sys.path.insert(0, Path(__file__).parent.as_posix()) + + +CDF_FILENAME_BALFRIN = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" +#CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc" +#CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018-n320.nc" +CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc" + + +BASE_FILEPATH = "/capstor/store/" +INPUT_DATA_FILEPATH = "mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/" +OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp" +OUTPUT_DATA_FILEPATH_ERA = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-with-copernicus-tp" +TP_INDEX = 12 + +LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) +LON = np.arange(-6.82, 4.80 + 0.02, 0.02) +RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) + +def extract_grib_values(grib_data): + grib_lat = grib_data['latitude'][:] + grib_lon = grib_data['longitude'][:] + grib_t2m = grib_data['t2m'][:] + +def extract_lat_lon_025(data): + logging.info('extracting lat/lon') + lat = data['latitude'][:] + lon = data['longitude'][:] + output_lat = np.zeros(len(lat)* len(lon)) + output_lon = np.zeros(len(lat) * len(lon)) + for i in range(len(lat)): + if i % 10 == 0: + print(i) + for j in range(len(lon)): + grid_index = i * len(lon) + j + output_lat[grid_index] = lat[i] + output_lon[grid_index] = lon[j] + return output_lat, output_lon + +def extract_lat_lon_n320(data): + lat = data['latitudes'][:] + lon = data['longitudes'][:] + logging.info('extracting lat/lon') + logging.info(f'lat lon shapes {lat.shape} {lon.shape}') + +# Get values for a given date range (inclusive) +def extract_values(data: netCDF4.Dataset, variable, start_date=None, end_date=None, area=None): + values = data[variable][:] + #if area: + # Not sure this is working. 
+ # lat = data['latitude'][:] + # lon = data['longitude'][:] + # https://stackoverflow.com/questions/29135885/netcdf4-extract-for-subset-of-lat-lon + # latli = np.argmin( np.abs(lat - area[2])) + # latui = np.argmin( np.abs(lat - area[0])) + # lonli = np.argmin( np.abs(lon - area[1])) + # lonui = np.argmin( np.abs(lon - area[3])) + # lat = data['latitude'][latli:latui] + # lon = data['longitude'][lonli:lonui] + # values = data[variable][latli:latui,lonli:lonui] + date_indices = range(values.shape[0]) + if start_date and end_date: + date_indices = np.intersect1d( + np.where(data['valid_time'][:] >= start_date.astype(np.int64)), + np.where(data['valid_time'][:] <= end_date.astype(np.int64))) + values = values[date_indices,:] + if len(date_indices) == 0: + raise KeyError(f'{start_date} and {end_date} not valid range') + + return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) + +def reshape_to_cosmo(vals): + return vals.reshape((len(LAT)-RELAX_ZONE*2, len(LON)-RELAX_ZONE*2)) + +def calc_errors(cosmo1, era1): + make_plots = True + + prev_netcdf_regrid = [] + + netcdf_error = np.zeros(cosmo1.dates.shape) + era_norm_error = np.zeros(cosmo1.dates.shape) + netcdf_early_error = np.zeros(cosmo1.dates.shape) + netcdf_late_error = np.zeros(cosmo1.dates.shape) + + output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes)) + + for t in range(4): + #for t in range(len(cosmo1.dates)): + date = cosmo1.dates[t] + era_date = era1.dates[t] + if date != era_date: + logging.error('dates do not match: cosmo date: {date}, era date: {era_date}') + if date != netcdf_data['valid_time'][t]: + logging.error(f'dates do not match: cosmo date: {date}, netcdf: {netcdf_data["valid_time"][t]}') + + + # plot cosmo + if make_plots: + plot_map_precipitation(values=reshape_to_cosmo(cosmo1[t,:]), filename=f'plots/tp/{date}-cosmo1') + + # plot netcdf + netcdf_vals = netcdf_values[t,:].reshape((1,1,netcdf_values.shape[1])) + netcdf_regrid=interpolate_basic.regrid(netcdf_vals, netcdf_grid, output_grid) + if make_plots: + plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf-refactor') + + # plot era + era_grid = np.column_stack((era1.longitudes, era1.latitudes)) + era1_regrid = interpolate_basic.regrid(era1[t,:], era_grid, output_grid) + if make_plots: + plot_map_precipitation(reshape_to_cosmo(era1_regrid/6), f'plots/tp/{date}-era1-norm') + + #if t % 6 == 0: + # if era6.dates[t//6] != date: + # logging.error(f'dates do not match: era1: {date}, era6: {era6.dates[t//6]}') + # era6_regrid = interpolate_basic.regrid(era6[t//6,:], era_grid, output_grid) + # plot_map_precipitation(reshape_to_cosmo(era6_regrid), f'plots/tp/{date}-era6') + + era_norm_error[t] = np.mean(absolute_error(era1_regrid/6, cosmo1[t,:])) + netcdf_error[t] = np.mean(absolute_error(netcdf_regrid, cosmo1[t,:])) + logging.info(f'era norm error: {era_norm_error[t]} netcdf err: {netcdf_error[t]}') + if t>0: + netcdf_early_error[t] = np.mean(absolute_error(prev_netcdf_regrid, cosmo1[t,:])) + netcdf_late_error[t-1] = np.mean(absolute_error(netcdf_regrid, cosmo1[t-1,:])) + logging.info(f'netcdf early err: {netcdf_early_error[t]}, netcdf late err: {netcdf_late_error[t-1]}') + prev_netcdf_regrid = netcdf_regrid + + maes = {} + maes['era normalized'] = era_norm_error + maes['copernicus'] = netcdf_error + maes['copernicus-early'] = netcdf_early_error + maes['copernicus-late'] = netcdf_late_error + plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png') + +def process_era_interpolated(netcdf_data, 
netcdf_tp_values, input_data_filepath, output_interpolated_filepath, netcdf_grid, cosmo_grid): + make_plots = False + #for t in range(100): + for t in range(netcdf_tp_values.shape[0]): + netcdf_date = netcdf_data['valid_time'][t] + date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') + t1 = datetime.datetime.now() + era_filename = os.path.join(input_data_filepath, "era-interpolated", date_filename) + output_filename = os.path.join(output_interpolated_filepath, date_filename) + if os.path.exists(era_filename): + requires_processing = False + if date_filename == '20200615-2100': + requires_processing = True + #if os.path.exists(output_filename): + # test the output to make sure it is not corrupted. + #requires_processing = False + #try: + # torch.load(output_filename, weights_only=False) + #except: + # requires_processing = True + if requires_processing: + era_data = torch.load(era_filename, weights_only=False) + t2 = datetime.datetime.now() + #if t % 100 == 0: + logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') + interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear') + t3 = datetime.datetime.now() + if make_plots and max > 0.0002: + nans = np.count_nonzero(np.isnan(interpolated_tp)) + nonzeros = np.count_nonzero(interpolated_tp) + max = np.max(interpolated_tp) + logging.info(f'nonzeros: {nonzeros} nans: {nans} max: {max}') + if max > 0.0002: + cosmo_filename = os.path.join(input_data_filepath, "cosmo", date_filename) + cosmo_data = torch.load(cosmo_filename, weights_only=False) + plot_map_precipitation(reshape_to_cosmo(interpolated_tp), f'plots/tp/{date_filename}-netcdf-regrid') + # 3 is tp index in cosmo data + plot_map_precipitation(reshape_to_cosmo(cosmo_data[3,:]), f'plots/tp/{date_filename}-cosmo') + plot_map_precipitation(reshape_to_cosmo(era_data[TP_INDEX,0,:]/6), f'plots/tp/{date_filename}-era-norm') + era_data[TP_INDEX,0,:] = interpolated_tp + torch.save(era_data, output_filename) + t4 = datetime.datetime.now() + +def process_era(netcdf_data, netcdf_tp_values): + for t in range(netcdf_tp_values.shape[0]): + netcdf_date = netcdf_data['valid_time'][t] + date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') + era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era", date_filename) + if os.path.exists(era_filename): + era_data = torch.load(era_filename, weights_only=False) + t2 = datetime.datetime.now() + logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') + interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], era_grid, method='linear') + t3 = datetime.datetime.now() + era_data[TP_INDEX,0,:] = interpolated_tp + torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA, date_filename)) + t4 = datetime.datetime.now() + +def extract_all_values(): + set1 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc") + set2 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018.nc") + set3 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc") + set1_tp = extract_values(set1, 'tp') + set2_tp = extract_values(set2, 'tp') + set3_tp = extract_values(set3, 'tp') + all_tp = np.row_stack((set1_tp, set2_tp, set3_tp)) + print(all_tp.shape) + return all_tp + +# Get stats from ERA and replace the TP variable with stats from Copernicus +def make_stats(input_stats_directory: str, output_stats_directory: str, 
extracted_tp_values: np.ndarray): + stats = torch.load(os.path.join(input_stats_directory, 'era-stats'), weights_only=False) + print(stats) + #extracted_tp_values = extracted_tp_values.reshape(extracted_tp_values.shape[0] * extracted_tp_values.shape[1], 1) + flat_values = extracted_tp_values.flatten() + mean = np.mean(flat_values) + max = np.max(flat_values) + min = np.min(flat_values) + stdev = np.std(flat_values) + stats['mean'][TP_INDEX] = mean + stats['maximum'][TP_INDEX] = max + stats['minimum'][TP_INDEX] = min + stats['stdev'][TP_INDEX] = stdev + print(stats) + torch.save(stats, os.path.join(output_stats_directory, 'era-copernicus-stats')) + + +#process_era(netcdf_data, netcdf_tp_values) + + +def main(): + root = logging.getLogger() + root.setLevel(logging.INFO) + + logging.info('loading data') + netcdf_file = sys.argv[1] + input_data_filepath = sys.argv[2] + output_interpolated_filepath = sys.argv[3] + + netcdf_data = netCDF4.Dataset(netcdf_file) + logging.info(netcdf_data) + + logging.info('processing netcdf data') + netcdf_latitudes, netcdf_longitudes = extract_lat_lon_025(netcdf_data) + netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + cosmo_grid = torch.load(os.path.join(input_data_filepath, 'info/cosmo-lat-lon'), weights_only=False) + cosmo_grid = np.column_stack((cosmo_grid[:,1], cosmo_grid[:,0])) + logging.info(f'netcdf grid shape {netcdf_grid.shape}') + logging.info(f'{netcdf_grid[1:10,:]}') + logging.info(f'cosmo grid shape {cosmo_grid.shape}') + logging.info(f'{cosmo_grid[1:10,:]}') + + netcdf_tp_values = extract_values(netcdf_data, 'tp') + + + process_era_interpolated(netcdf_data, netcdf_tp_values, input_data_filepath, output_interpolated_filepath, netcdf_grid, cosmo_grid) + + +if __name__ == "__main__": + main() + diff --git a/src/hirad/input_data/regrid_copernicus_tp.sh b/src/hirad/input_data/regrid_copernicus_tp.sh new file mode 100755 index 0000000..f7fc550 --- /dev/null +++ b/src/hirad/input_data/regrid_copernicus_tp.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +pip install -e . +pip install anemoi.datasets + +python src/hirad/input_data/regrid_copernicus_tp.py \ + /capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc \ + /capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/ \ + /capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp/ \ No newline at end of file diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py new file mode 100644 index 0000000..8e84593 --- /dev/null +++ b/src/hirad/input_data/regrid_realch1.py @@ -0,0 +1,183 @@ + +import datetime +import logging +import os +import shutil +import sys + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.dataset import Dataset +import numpy as np +from meteodatalab.operators import regrid +import xarray as xr +from meteodatalab import ogd_api +from hirad.input_data.interpolate_basic import plot_and_save_projection +import yaml +import torch +from pandas import to_datetime + +import matplotlib.pyplot as plt +import cartopy.crs as ccrs +from earthkit.geo.rotate import unrotate + +TRIM_EDGE = 41 +XARRAY_BATCH = 4 + +# Take anemoi dataset and provide xarray dataarrays for a set of variables. 
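For reference, the core operation in `regrid_copernicus_tp.py` above — interpolating a flattened precipitation field from the 0.25° grid onto the COSMO point cloud — boils down to a single `scipy.interpolate.griddata` call over `(lon, lat)` point arrays. A self-contained sketch follows; the `regrid_linear` helper and the nearest-neighbour fill for points outside the source hull are illustrative assumptions, not part of the script.

```python
import numpy as np
from scipy.interpolate import griddata

def regrid_linear(src_values, src_grid, dst_grid):
    """Interpolate a flattened field from src_grid (n_src, 2) onto dst_grid (n_dst, 2)."""
    out = griddata(src_grid, src_values, dst_grid, method="linear")
    # griddata returns NaN outside the convex hull of the source points;
    # fall back to nearest-neighbour values there (an assumption, the
    # original script keeps whatever griddata returns).
    missing = np.isnan(out)
    if missing.any():
        out[missing] = griddata(src_grid, src_values, dst_grid[missing], method="nearest")
    return out

# Toy usage on a synthetic field
rng = np.random.default_rng(0)
src = np.column_stack((rng.uniform(0, 12, 500), rng.uniform(43, 49, 500)))  # lon, lat
dst = np.column_stack((rng.uniform(2, 10, 200), rng.uniform(44, 48, 200)))
field = np.sin(src[:, 0]) + np.cos(src[:, 1])
print(regrid_linear(field, src, dst).shape)  # (200,)
```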
+# returns: list of xarray dataarrays +def anemoi_to_xarray(anemoi_data: Dataset, start_date_index=-1, end_date_index=-1): + if start_date_index == -1: + start_date_index = 0 + if end_date_index == -1: + end_date_index = len(anemoi_data.dates) + lon = anemoi_data.longitudes + lat = anemoi_data.latitudes + eps = [0] # deterministic + time = anemoi_data.dates[start_date_index:end_date_index] + metadata = getMetadataFromOGD() + dataarrays = [] + variables = anemoi_data.variables + for var_index in range(anemoi_data.shape[1]): + logging.info(f'building xarray for {variables[var_index]}') + ds = xr.Dataset( + data_vars=dict( + variable=(["time", "eps", "cell"], + np.array(anemoi_data[start_date_index:end_date_index,var_index,:,:])), + ), + coords=dict( + eps=eps, + time=time, + lon=("cell", lon), + lat=("cell", lat), + ), + attrs=dict(description=f'xarray from anemoi dataset for {variables[var_index]}', + metadata=metadata), + ) + dataarrays.append(ds.to_dataarray()) + return dataarrays + +# Run a request to get the metadata, so that we can fake out an xarray. +def getMetadataFromOGD(): + lead_times = ["P0DT0H"] + req = ogd_api.Request( + collection="ogd-forecasting-icon-ch1", + variable="TOT_PREC", #assuming this won't cause problems; we're only using grid info + ref_time="latest", + perturbed=False, + lead_time=lead_times, + ) + tot_prec = ogd_api.get_from_ogd(req) + return tot_prec.metadata + +# get the geo coordinates for the rotated lat/lon dataset. +# returns np.array of lats and array of lons +def get_geo_coords(regridded_data: xr.Dataset, trim_edge=0): + xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees") + xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees") + dx = regridded_data.metadata.get("iDirectionIncrementInDegrees") + ymin = regridded_data.metadata.get("latitudeOfFirstGridPointInDegrees") + ymax = regridded_data.metadata.get("latitudeOfLastGridPointInDegrees") + dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") + y = np.arange(ymin,ymax+dy,dy) + x = np.arange(xmin,xmax+dx,dx) + # trim x and y according to trim_edge. + # (Have manually verified that when doing this, the outputs are the same as + # trimming post-projection) + y = y[trim_edge:len(y)-trim_edge] + x = x[trim_edge:len(x)-trim_edge] + sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") # -43.0. north_pole_lat = 43.0 + sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") # 10.0. north_pole_lon = 190.0 + xcoords = np.meshgrid(x,y)[0].flatten() + ycoords = np.meshgrid(x,y)[1].flatten() + # Expect south pole rotation of lon=10, latitude=-43 + logging.info(f'sp_lat = {sp_lat}, sp_lon = {sp_lon}') + rotated_crs = ccrs.RotatedPole( + pole_longitude=(sp_lon + 180) % 360, pole_latitude=sp_lat * -1 # 190, 43 + ) + # Project onto PlateCarree. 
Geodetic produces similar coordinates (within 10 nanometers) + dst_grid = ccrs.PlateCarree() + geo_coords = dst_grid.transform_points(rotated_crs, xcoords, ycoords) + lats = geo_coords[:,1] + lons = geo_coords[:,0] + return lats, lons + +def regridded_to_numpy(regridded: xr.DataArray, trim_edge=0): + # regridded is in shape (eps, time, variable, x, y) + # want this in shape (time, channel, ensemble, grid) + # First, trim the edge + data = regridded.data[:,:,:, + trim_edge:regridded.data.shape[3]-trim_edge, + trim_edge:regridded.data.shape[4]-trim_edge] + # reshape to (time,channel,ensemble,grid) + data = data.reshape(data.shape[1], data.shape[0], data.shape[3]*data.shape[4]) + return data + +def main(): + # yml format + realch1_config_file = sys.argv[1] + output_directory = sys.argv[2] + if not os.path.exists(output_directory): + os.mkdir(output_directory) + for subdir in ['info', 'plots', 'realch1']: + if not os.path.exists(os.path.join(output_directory, subdir)): + os.mkdir(os.path.join(output_directory, subdir)) + + logging.basicConfig( + filename=os.path.join(output_directory, 'regrid_realch1.log'), + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + + # Copy the realch1.yml file to the info directory + shutil.copy(realch1_config_file, os.path.join(output_directory, 'info')) + + with open(realch1_config_file) as realch1_file: + realch1_config = yaml.safe_load(realch1_file) + realch1 = open_dataset(realch1_config) + variables = realch1.variables + + # Get the lat/lon info by regridding one variable + xarrays = anemoi_to_xarray(realch1, 0, 1) + regridded=regrid.icon2rotlatlon(xarrays[0]) + logging.info('getting geo coords') + lats, lons = get_geo_coords(regridded, trim_edge=TRIM_EDGE) + + # Save grid to file + grid = np.column_stack((lats, lons)) + torch.save(grid, os.path.join(output_directory, 'info', 'realch1-lat-lon')) + + # Split regridding into batches; too many time points seems to not scale well. 
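A compact, standalone illustration of the unrotation performed in `get_geo_coords` above: points in the rotated-pole frame are transformed to geographic longitude/latitude by projecting from a `RotatedPole` CRS onto `PlateCarree`. The pole values follow the comments above (southern pole of rotation at lon 10, lat -43); the toy grid below is an assumption.

```python
import numpy as np
import cartopy.crs as ccrs

sp_lon, sp_lat = 10.0, -43.0                    # southern pole of rotation (from metadata)
rotated = ccrs.RotatedPole(pole_longitude=(sp_lon + 180) % 360,  # 190
                           pole_latitude=-sp_lat)                # 43
geographic = ccrs.PlateCarree()

# Toy rotated-frame grid (degrees); the real extent comes from the GRIB metadata.
x = np.arange(-1.0, 1.01, 0.25)
y = np.arange(-0.5, 0.51, 0.25)
xx, yy = np.meshgrid(x, y)
pts = geographic.transform_points(rotated, xx.ravel(), yy.ravel())
lons, lats = pts[:, 0], pts[:, 1]
print(lats.min(), lats.max())  # roughly 46-48°N, the Alpine domain
```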
+ for i in range(0, len(realch1.dates), XARRAY_BATCH): + start_index = i + end_index = min(i+XARRAY_BATCH, len(realch1.dates)) + torch_data = np.zeros([end_index-start_index, len(realch1.variables), 1, len(lats)]) + + xarrays = anemoi_to_xarray(realch1, start_index, end_index) + for j in range(len(xarrays)): + logging.info(f'regridding {variables[j]} for time {realch1.dates[start_index]} to {realch1.dates[end_index]}') + xarray = xarrays[j] + start = datetime.datetime.now() + regridded=regrid.icon2rotlatlon(xarray) + end = datetime.datetime.now() + logging.info(f' regridding took {end-start} seconds') + torch_data[0:end_index-start_index,j,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE) + + logging.info('saving torch data') + for k in range(torch_data.shape[0]): + fmtdate = to_datetime(realch1.dates[start_index + k]).strftime('%Y%m%d-%H%M') + torch.save(torch_data[k,:], os.path.join(output_directory, 'realch1', fmtdate)) + + # Output plots for each variable, for first time point + for i in range(torch_data.shape[1]): + plot_and_save_projection(realch1.longitudes, realch1.latitudes, + realch1[0,i,0,:], + os.path.join(output_directory, 'plots', f'{variables[i]}-iconnative.png'), + s=0.005) + plot_and_save_projection(lons, lats, + torch_data[0,i,0,:], + os.path.join(output_directory, 'plots', f'{variables[i]}-rotlatlon.png'), + s=0.005) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/input_data/reprocess_change_tp_accum.py b/src/hirad/input_data/reprocess_change_tp_accum.py new file mode 100644 index 0000000..5642300 --- /dev/null +++ b/src/hirad/input_data/reprocess_change_tp_accum.py @@ -0,0 +1,58 @@ +import logging +import os +import sys + +import torch +import numpy as np + +# Reprocess ERA-interpolated data to exclude the tp variable. 
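The preprocessing scripts above and the two reprocessing scripts that follow share one on-disk convention, inferred from their indexing and therefore an assumption: a torch-saved array of shape `(channels, ensemble, grid_points)` per time step, named `%Y%m%d-%H%M`, plus an `era-stats` dict of per-channel statistics. A hedged sketch of the channel-swap operation that results (essentially what `process()` below does for the 6 h / 1 h precipitation accumulations); the paths and indices in the usage comment are hypothetical.

```python
import torch

def swap_channel(src_file: str, donor_file: str, out_file: str,
                 dst_index: int, donor_index: int) -> None:
    """Replace channel `dst_index` of src_file with channel `donor_index` of donor_file."""
    data = torch.load(src_file, weights_only=False)     # (C, E, N)
    donor = torch.load(donor_file, weights_only=False)  # (C', E, N)
    data[dst_index, :] = donor[donor_index, :]
    torch.save(data, out_file)

# Hypothetical usage mirroring the 6h -> 1h total-precipitation swap below:
# swap_channel("era-6h/20200101-0000", "era-1h/20200101-0000",
#              "out/20200101-0000", dst_index=34, donor_index=12)
```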
+ +# 6H data is all channels, but with 6h accumulation +DATA_SOURCE_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/era-interpolated" +STATS_FILEPATH_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/info" +# 1h data is the updated +DATA_SOURCE_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/validation/era-interpolated-with-copernicus-tp/" +STATS_FILEPATH_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/validation/info" +OUTPUT_DIR = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/validation/era-interpolated" +OUTPUT_STATS_FILEPATH = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/validation/info/" +TP_INDEX_6H = 34 # in era-all.yaml +TP_INDEX_1H = 12 # in era.yaml + +def process(input_directory_6h: str, input_directory_1h: str, output_directory: str): + input_1h_filepath = os.path.join(input_directory_1h) + files = os.listdir(input_1h_filepath) + files.sort() + for f in range(len(files)): + if f % 100 == 0: + logging.info(f) + input_1h_file = os.path.join(input_directory_1h, files[f]) + input_6h_file = os.path.join(input_directory_6h, files[f]) + outfile = os.path.join(output_directory, files[f]) + in_data_6h = torch.load(input_6h_file, weights_only=False) + in_data_1h = torch.load(input_1h_file, weights_only=False) + in_data_6h[TP_INDEX_6H,:] = in_data_1h[TP_INDEX_1H,:] + torch.save(in_data_6h, outfile) + +def edit_info(info_6h_filepath: str, info_1h_filepath: str, output_filepath: str): + stats_6h = torch.load(os.path.join(info_6h_filepath, 'era-stats'), weights_only=False) + stats_1h = torch.load(os.path.join(info_1h_filepath, 'era-stats'), weights_only=False) + logging.info(f'6h stats: {stats_6h}') + logging.info(f'1h stats: {stats_1h}') + for k in stats_6h.keys(): + logging.info(k, stats_6h[k]) + tmp = stats_6h[k] + tmp[TP_INDEX_6H] = stats_1h[k][TP_INDEX_1H] + stats_6h[k] = tmp + logging.info(stats_6h) + torch.save(stats_6h, os.path.join(output_filepath, 'era-stats')) + +def main(): + logging.basicConfig( + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + process(DATA_SOURCE_6H, DATA_SOURCE_1H, OUTPUT_DIR) + edit_info(STATS_FILEPATH_6H, STATS_FILEPATH_1H, OUTPUT_STATS_FILEPATH) + +if __name__ == "__main__": + main() diff --git a/src/hirad/input_data/reprocess_exclude_tp.py b/src/hirad/input_data/reprocess_exclude_tp.py new file mode 100644 index 0000000..190a474 --- /dev/null +++ b/src/hirad/input_data/reprocess_exclude_tp.py @@ -0,0 +1,43 @@ +import logging +import os +import sys + +import torch +import numpy as np + +# Reprocess ERA-interpolated data to exclude the tp variable. 
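`reprocess_exclude_tp.py` below keeps only the channels before `TP_INDEX` in each sample, so the per-channel statistics need the same truncation. One detail worth noting when adapting `edit_info`: `os.path.join` discards all earlier components once a later component is an absolute path, so passing `'/info'` resolves to `/info/era-stats` rather than a subdirectory of the input path. A minimal sketch with relative components is shown here; the `truncate_stats` helper is an assumption, not part of the script.

```python
import os
import torch

def truncate_stats(input_dir: str, output_dir: str, n_keep: int) -> None:
    """Drop per-channel statistics for the channels being removed."""
    stats = torch.load(os.path.join(input_dir, "info", "era-stats"), weights_only=False)
    # Channels from n_keep onward are dropped, matching the slicing in process().
    stats = {key: values[:n_keep] for key, values in stats.items()}
    torch.save(stats, os.path.join(output_dir, "info", "era-stats"))
```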
+ +TP_INDEX = 12 + +def process(input_directory: str, output_directory: str): + input_filepath = os.path.join(input_directory, 'era-interpolated') + files = os.listdir(input_filepath) + files.sort() + for f in range(len(files)): + if f % 100 == 0: + logging.info(f) + outfile = os.path.join(output_directory, 'era-interpolated', files[f]) + if (not os.path.exists(outfile)) or (os.path.getsize(outfile) < 26000000): + in_data = torch.load(os.path.join(input_filepath, files[f]), weights_only=False) + out_data = in_data[0:TP_INDEX,:] + torch.save(out_data, outfile) + +def edit_info(input_filepath: str, output_filepath: str): + stats = torch.load(os.path.join(input_filepath, '/info', 'era-stats'), weights_only=False) + logging.info(stats) + for k in stats.keys(): + logging.info(k, stats[k]) + stats[k] = stats[k][0:TP_INDEX] + logging.info(stats) + torch.save(stats, os.path.join(output_filepath, "/info", "era-stats")) + +def main(): + root = logging.getLogger() + root.setLevel(logging.INFO) + input_directory = sys.argv[1] + output_directory = sys.argv[2] + process(input_directory, output_directory) + #edit_info(input_directory, output_directory) + +if __name__ == "__main__": + main() diff --git a/src/hirad/input_data/test_input_data.py b/src/hirad/input_data/test_input_data.py new file mode 100644 index 0000000..d8d0f6e --- /dev/null +++ b/src/hirad/input_data/test_input_data.py @@ -0,0 +1,92 @@ +import logging +import os +import sys + +import datetime +import torch +import numpy as np + + +from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t + +def load_all_data(filepath: str): + files = os.listdir(filepath) + example = torch.load(os.path.join(filepath, files[0]), weights_only=False) + dims = (len(files),) + example.shape + data = np.zeros(dims) + for f in range(100): + #for f in range(len(files)): + if f % 100 == 0: + logging.info(f) + curr = torch.load(os.path.join(filepath, files[f]), weights_only=False) + data[f,:] = curr + return data + +def count_nans(data: np.array): + nans = np.count_nonzero(np.isnan(data)) + return nans + +def make_stats(filepath: str): + data = load_all_data(filepath) + stats = {} + num_channels = data.shape[1] + stats['mean'] = np.zeros(num_channels) + stats['stdev'] = np.zeros(num_channels) + stats['minimum'] = np.zeros(num_channels) + stats['maximum'] = np.zeros(num_channels) + for k in range(num_channels): + logging.info(f'channel {k}') + stats['mean'][k] = np.mean(data[:,k,:,:]) + stats['minimum'][k] = np.min(data[:,k,:,:]) + stats['maximum'][k] = np.max(data[:,k,:,:]) + stats['stdev'][k] = np.std(data[:,k,:,:]) + return stats + +def main(): + root = logging.getLogger() + root.setLevel(logging.INFO) + input_directory = sys.argv[1] + + logging.info(f'checking input directory {input_directory}') + + missing_data = [] + corrupt_data = [] + nan_data = [] + check_for_nans = False + check_for_corrupt = False + + + + files = os.listdir(input_directory) + files.sort() + start_date = datetime.datetime.strptime(files[0],'%Y%m%d-%H%M') + next_date = datetime.datetime.strptime(files[1],'%Y%m%d-%H%M') + delta = next_date - start_date + prev_date = start_date - delta + + for f in files: + curr_date = datetime.datetime.strptime(f,'%Y%m%d-%H%M') + if curr_date - prev_date != delta: + logging.info(f'missing data: {prev_date} and {curr_date} not {delta} apart') + expected_date = prev_date + delta + while (expected_date < curr_date): + missing_data.append(datetime.datetime.strftime(expected_date, '%Y%m%d-%H%M')) + expected_date = expected_date + delta + if 
check_for_corrupt: + try: + data = torch.load(os.path.join(input_directory, f), weights_only=False) + except: + logging.info(f'corrupt data: {curr_date}') + corrupt_data.append(curr_date) + if check_for_nans or curr_date == start_date: + if count_nans(data): + logging.info(f'data nans: {curr_date}') + nan_data.append(curr_date) + prev_date = curr_date + logging.info(f'missing data size {len(missing_data)}: {missing_data}') + logging.info(f'corrupt data size {len(corrupt_data)}: {corrupt_data}') + if check_for_nans: + logging.info(f'nan data: {nan_data}') + +if __name__ == "__main__": + main() diff --git a/src/hirad/losses/loss.py b/src/hirad/losses/loss.py index fb65960..2c2123b 100644 --- a/src/hirad/losses/loss.py +++ b/src/hirad/losses/loss.py @@ -380,6 +380,7 @@ def __call__( augment_pipe: Optional[ Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] ] = None, + lead_time_label: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Calculate and return the regression loss for @@ -439,7 +440,23 @@ def __call__( y_lr = y_tot[:, img_clean.shape[1] :, :, :] zero_input = torch.zeros_like(y, device=img_clean.device) - D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels) + + if lead_time_label is not None: + D_yn = net( + zero_input, + y_lr, + force_fp32=False, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + zero_input, + y_lr, + force_fp32=False, + augment_labels=augment_labels, + ) + loss = weight * ((D_yn - y) ** 2) return loss @@ -518,6 +535,32 @@ def __init__( self.hr_mean_conditioning = hr_mean_conditioning self.y_mean = None + def get_noise_params(self, y: torch.Tensor) -> torch.Tensor: + """ + Compute the noise parameters to apply denoising score matching. + + Parameters + ---------- + y : torch.Tensor + Latent state of shape :math:`(B, *)`. Only used to determine the shape of + the noise and create tensors on the same device. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + - Noise ``n`` of shape :math:`(B, *)` to be added to the latent state. + - Noise level ``sigma`` of shape :math:`(B, 1, 1, 1)`. + - Weight ``weight`` of shape :math:`(B, 1, 1, 1)` to multiply the loss. 
+ """ + # Sample noise level + rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=y.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + # Loss weight + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + # Sample noise + n = torch.randn_like(y) * sigma + return n, sigma, weight + def __call__( self, net: torch.nn.Module, @@ -712,35 +755,34 @@ def __call__( y = y_patched y_lr = y_lr_patched - # Noise - rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - # Input + noise - latent = y + torch.randn_like(y) * sigma + # Add noise to the latent state + n, sigma, weight = self.get_noise_params(y) if lead_time_label is not None: D_yn = net( - latent, + y + n, y_lr, sigma, embedding_selector=None, - global_index=patching.global_index(batch_size, img_clean.device) - if patching is not None - else None, + global_index=( + patching.global_index(batch_size, img_clean.device) + if patching is not None + else None + ), lead_time_label=lead_time_label, augment_labels=augment_labels, ) else: D_yn = net( - latent, + y + n, y_lr, sigma, embedding_selector=None, - global_index=patching.global_index(batch_size, img_clean.device) - if patching is not None - else None, + global_index=( + patching.global_index(batch_size, img_clean.device) + if patching is not None + else None + ), augment_labels=augment_labels, ) loss = weight * ((D_yn - y) ** 2) diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index b00a477..c8fb306 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,3 +1,4 @@ +from .utils import weight_init from .layers import ( Linear, Conv2d, diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index d7e63d7..4da26b1 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -26,11 +26,11 @@ import numpy as np import nvtx import torch -import torch.cuda.amp as amp +import torch.amp as amp from einops import rearrange from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh -from hirad.utils.model_utils import weight_init +from hirad.models import weight_init _is_apex_available = False if torch.cuda.is_available(): @@ -700,7 +700,7 @@ def forward(self, x, emb): # w = AttentionOp.apply(q, k) # a = torch.einsum("nqk,nck->ncq", w, v) # Compute attention in one step - with amp.autocast(enabled=self.amp_mode): + with amp.autocast(x.device.type, enabled=self.amp_mode): attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = self.proj(attn.reshape(*x.shape)).add_(x) x = x * self.skip_scale diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index e0a447a..f1edc6b 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -69,9 +69,9 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up See Also -------- For information on model types and their usage: - :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models - :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings - :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + :class:`~models.song_unet.SongUNet`: Basic U-Net for diffusion models + :class:`~models.song_unet.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~models.song_unet.SongUNetPosLtEmbd`: U-Net with positional and lead-time 
embeddings Please refer to the documentation of these classes for details on how to call and use these models directly. @@ -152,6 +152,44 @@ def __init__( **model_kwargs, ) + @property + def use_fp16(self): + """ + bool: Whether the model uses float16 precision. + + Returns + ------- + bool + True if the model is in float16 mode, False otherwise. + """ + return self._use_fp16 + + @use_fp16.setter + def use_fp16(self, value: bool): + """ + Set whether the model should use float16 precision. + + Parameters + ---------- + value : bool + If True, moves the model to torch.float16. If False, moves to torch.float32. + + Raises + ------ + ValueError + If `value` is not a boolean. + """ + # NOTE: allow 0/1 values for older checkpoints + if not (isinstance(value, bool) or value in [0, 1]): + raise ValueError( + f"`use_fp16` must be a boolean, but got {type(value).__name__}." + ) + self._use_fp16 = value + if value: + self.to(torch.float16) + else: + self.to(torch.float32) + def forward( self, x: torch.Tensor, diff --git a/src/hirad/snapshots.sh b/src/hirad/snapshots.sh new file mode 100644 index 0000000..16a12de --- /dev/null +++ b/src/hirad/snapshots.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +#SBATCH --job-name="eval_precip" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=05:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_precip.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . 
--no-dependencies + + python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh index cf2f88f..10eb13f 100644 --- a/src/hirad/train_diffusion.sh +++ b/src/hirad/train_diffusion.sh @@ -1,25 +1,23 @@ #!/bin/bash -#SBATCH --job-name="testrun" +#SBATCH --job-name="corrdiff-second-stage" ### HARDWARE ### #SBATCH --partition=debug -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/training_diffusion_test.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/training_diffusion_test.err ### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a122 +#SBATCH -A a122 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -32,14 +30,15 @@ export MASTER_ADDR export MASTER_PORT=29500 # Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute cores per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml " \ No newline at end of file diff --git a/src/hirad/train_diffusion_test.sh b/src/hirad/train_diffusion_test.sh new file mode 100644 index 0000000..9c9831f --- /dev/null +++ b/src/hirad/train_diffusion_test.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +#SBATCH --job-name="corrdiff-test-second-stage" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/training_diffusion_test.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . 
--no-dependencies + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion_test.yaml +" \ No newline at end of file diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh index c065477..1c88a10 100644 --- a/src/hirad/train_regression.sh +++ b/src/hirad/train_regression.sh @@ -1,25 +1,22 @@ #!/bin/bash -#SBATCH --job-name="testrun" +#SBATCH --job-name="corrdiff-first-stage" ### HARDWARE ### #SBATCH --partition=debug -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression.err +#SBATCH --output=./logs/regression_full_run.log ### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a122 +#SBATCH -A a161 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -32,14 +29,18 @@ export MASTER_ADDR export MASTER_PORT=29500 # Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute cores per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS - +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate +# srun bash -c " +# . ./train_env/bin/activate +# python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +# " +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml " \ No newline at end of file diff --git a/src/hirad/train_regression_test.sh b/src/hirad/train_regression_test.sh new file mode 100644 index 0000000..6cc0468 --- /dev/null +++ b/src/hirad/train_regression_test.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +### OUTPUT ### +#SBATCH --output=./logs/training_regression_test.log + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 + +pip install -e . 
--no-dependencies +python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 12b6942..5f584ae 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -1,6 +1,8 @@ import os import time +from concurrent.futures import ThreadPoolExecutor + import psutil import hydra from omegaconf import DictConfig, OmegaConf @@ -9,22 +11,26 @@ import nvtx import torch from hydra.utils import to_absolute_path -from torch.utils.tensorboard import SummaryWriter +# from torch.utils.tensorboard import SummaryWriter from torch.nn.parallel import DistributedDataParallel +import mlflow # from torchinfo import summary from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ set_patch_shape, compute_num_accumulation_rounds, \ - is_time_for_periodic_task, handle_and_clip_gradients + is_time_for_periodic_task, handle_and_clip_gradients, \ + init_mlflow from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.utils.patching import RandomPatching2D -from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR +from hirad.utils.function_utils import get_time_from_range +from hirad.utils.inference_utils import save_results_as_torch +from hirad.utils.env_info import get_env_info, flatten_dict +from hirad.models import UNet, EDMPrecondSuperResolution from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE -from hirad.datasets import init_train_valid_datasets_from_config - -from matplotlib import pyplot as plt +from hirad.datasets import init_train_valid_datasets_from_config, get_dataset_and_sampler_inference +from hirad.inference import Generator torch._dynamo.reset() # Increase the cache size limit @@ -63,14 +69,20 @@ def main(cfg: DictConfig) -> None: DistributedManager.initialize() dist = DistributedManager() - if dist.rank==0: - writer = SummaryWriter(log_dir='tensorboard') + OmegaConf.resolve(cfg) + + if cfg.logging.method == "mlflow": + init_mlflow(cfg, dist) + if dist.world_size > 1: + torch.distributed.barrier() + elif cfg.logging.method is not None: + raise ValueError("The only available logging method is mlflow. 
To disable logging set the method to null.") + logger = PythonLogger("main") # general logger logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger - OmegaConf.resolve(cfg) dataset_cfg = OmegaConf.to_container(cfg.dataset) - if hasattr(cfg.dataset, "validation_path"): + if hasattr(cfg.dataset, "validation_path") and cfg.dataset.validation_path is not None: train_test_split = True else: train_test_split = False @@ -79,20 +91,38 @@ def main(cfg: DictConfig) -> None: fp16 = fp_optimizations == "fp16" enable_amp = fp_optimizations.startswith("amp") amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger0.info(f"Config is: {cfg}") logger0.info(f"Saving the outputs in {os.getcwd()}") checkpoint_dir = os.path.join( cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" ) if dist.rank==0 and not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) # added creating checkpoint dir + visualize_checkpoints = False + if hasattr(cfg, "generation"): + visualization_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), "visualization" + ) + if dist.rank==0 and not os.path.exists(visualization_dir): + os.makedirs(visualization_dir) # added creating checkpoint dir + visualize_checkpoints = True + if cfg.training.hp.batch_size_per_gpu == "auto" and \ + cfg.training.hp.total_batch_size == "auto": + raise ValueError("batch_size_per_gpu and total_batch_size can't be both set to 'auto'.") if cfg.training.hp.batch_size_per_gpu == "auto": cfg.training.hp.batch_size_per_gpu = ( cfg.training.hp.total_batch_size // dist.world_size ) + elif cfg.training.hp.total_batch_size == "auto": + cfg.training.hp.total_batch_size = ( + cfg.training.hp.batch_size_per_gpu * dist.world_size + ) + + cur_nimg = load_checkpoint(path=checkpoint_dir) - set_seed(dist.rank) + set_seed(dist.rank + cur_nimg) configure_cuda_for_consistent_precision() - + # Instantiate the dataset data_loader_kwargs = { "pin_memory": True, @@ -110,8 +140,10 @@ def main(cfg: DictConfig) -> None: batch_size=cfg.training.hp.batch_size_per_gpu, seed=0, train_test_split=train_test_split, + sampler_start_idx=cur_nimg, ) logger0.info(f"Training on dataset with size {len(dataset)}") + logger0.info(f"Validating on dataset with size {len(validation_dataset)}") # Parse image configuration & update model args dataset_channels = len(dataset.input_channels()) @@ -120,15 +152,32 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) if cfg.model.hr_mean_conditioning: img_in_channels += img_out_channels - + logger0.info(f"Training on dataset with grid size {img_shape[0]}x{img_shape[1]}, {img_in_channels} input channels and {img_out_channels} output channels.") + logger0.info(f"Input channels: {dataset.input_channels()}") + logger0.info(f"Output channels: {dataset.output_channels()}") if cfg.model.name == "lt_aware_ce_regression": prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader else: prob_channels = None + if visualize_checkpoints: + # Parse the inference input times + if cfg.generation.times_range and cfg.generation.times: + raise ValueError("Either times_range or times must be provided, but not both") + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + else: + times = cfg.generation.times + viz_dataset_cfg = OmegaConf.to_container(cfg.generation.dataset) + visualization_dataset, visualization_sampler = 
get_dataset_and_sampler_inference( + dataset_cfg=viz_dataset_cfg, times=times + ) + visualization_data_loader = torch.utils.data.DataLoader( + dataset=visualization_dataset, sampler=visualization_sampler, batch_size=1, pin_memory=True + ) + # Parse the patch shape - #TODO figure out patched diffusion and how to use it if ( cfg.model.name == "patched_diffusion" or cfg.model.name == "lt_aware_patched_diffusion" @@ -178,19 +227,12 @@ def main(cfg: DictConfig) -> None: if hasattr(cfg.model, "model_args"): # override defaults from config file model_args.update(OmegaConf.to_container(cfg.model.model_args)) - use_torch_compile = False - use_apex_gn = False - profile_mode = False + use_torch_compile = getattr(cfg.training.perf, "torch_compile", False) + use_apex_gn = getattr(cfg.training.perf, "use_apex_gn", False) + profile_mode = getattr(cfg.training.perf, "profile_mode", False) - if hasattr(cfg.training.perf, "torch_compile"): - use_torch_compile = cfg.training.perf.torch_compile - if hasattr(cfg.training.perf, "use_apex_gn"): - use_apex_gn = cfg.training.perf.use_apex_gn - model_args["use_apex_gn"] = use_apex_gn - - if hasattr(cfg.training.perf, "profile_mode"): - profile_mode = cfg.training.perf.profile_mode - model_args["profile_mode"] = profile_mode + model_args["use_apex_gn"] = use_apex_gn + model_args["profile_mode"] = profile_mode if enable_amp: model_args["amp_mode"] = enable_amp @@ -245,6 +287,8 @@ def main(cfg: DictConfig) -> None: # Enable distributed data parallel if applicable if dist.world_size > 1: + # if use_torch_compile: + # model = torch.compile(model) model = DistributedDataParallel( model, device_ids=[dist.local_rank], @@ -264,7 +308,6 @@ def main(cfg: DictConfig) -> None: raise FileNotFoundError( f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" ) - #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') @@ -294,12 +337,6 @@ def main(cfg: DictConfig) -> None: else: regression_net = None - # Compile the model and regression net if applicable - if use_torch_compile: - model = torch.compile(model) - if regression_net: - regression_net = torch.compile(regression_net) - # Compute the number of required gradient accumulation rounds # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size @@ -311,14 +348,17 @@ def main(cfg: DictConfig) -> None: batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - patch_num = getattr(cfg.training.hp, "patch_num", 1) - max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) - # calculate patch per iter - if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + patch_num = getattr(cfg.training.hp, "patch_num", 1) + if hasattr(cfg.training.hp, "max_patch_per_gpu"): + max_patch_per_gpu = cfg.training.hp.max_patch_per_gpu + if max_patch_per_gpu // batch_size_per_gpu < 1: + raise ValueError( + f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})." 
+ ) max_patch_num_per_iter = min( patch_num, (max_patch_per_gpu // batch_size_per_gpu) - ) # Ensure at least 1 patch per iter + ) patch_iterations = ( patch_num + max_patch_num_per_iter - 1 ) // max_patch_num_per_iter @@ -326,7 +366,7 @@ def main(cfg: DictConfig) -> None: min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) for i in range(patch_iterations) ] - print( + logger0.info( f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" ) else: @@ -395,6 +435,34 @@ def main(cfg: DictConfig) -> None: except: cur_nimg = 0 + # Compile the model and regression net if applicable + if use_torch_compile: + if dist.world_size==1: + model = torch.compile(model) + if regression_net: + regression_net = torch.compile(regression_net) + + # init the generator for inference to visualize checkpoint results + if visualize_checkpoints: + generator = Generator( + net_reg= model if regression_net is None else regression_net, + net_res= model if regression_net is not None else None, + batch_size=int(cfg.generation.num_ensembles//dist.world_size), + ensemble_size=cfg.generation.num_ensembles, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + n_out_channels=img_out_channels, + inference_mode="all" if regression_net is not None else "regression", + dist=dist, + ) + if use_patching: + generator.initialize_patching(img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) + sampler_params = cfg.generation.sampler.params if "params" in cfg.generation.sampler else {} + generator.initialize_sampler(cfg.generation.sampler.type, **sampler_params) + ############################################################################ # MAIN TRAINING LOOP # ############################################################################ @@ -500,7 +568,6 @@ def main(cfg: DictConfig) -> None: / num_accumulation_rounds / len(patch_nums_iter) ) - loss_accum += loss / num_accumulation_rounds with nvtx.annotate(f"loss backward", color="yellow"): loss.backward() @@ -514,32 +581,11 @@ def main(cfg: DictConfig) -> None: ) average_loss = (loss_sum / dist.world_size).cpu().item() - # update running mean of average loss since last periodic task - average_loss_running_mean += ( - average_loss - average_loss_running_mean - ) / n_average_loss_running_mean - n_average_loss_running_mean += 1 - - if dist.rank == 0: - writer.add_scalar("training_loss", average_loss, cur_nimg) - writer.add_scalar( - "training_loss_running_mean", - average_loss_running_mean, - cur_nimg, - ) - - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 # Update weights. 
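The loop above scales each micro-batch loss by the number of accumulation rounds (and patch iterations) before calling `backward()`, so the summed gradients match a single large batch. A stripped-down illustration of that pattern, with all names (`model`, `loss_fn`, `optimizer`, `micro_batches`) as placeholders rather than the training script's actual objects:

```python
import torch

def accumulation_step(model, loss_fn, optimizer, micro_batches):
    """One optimizer step built from several scaled micro-batch backward passes."""
    optimizer.zero_grad(set_to_none=True)
    num_rounds = len(micro_batches)
    running = 0.0
    for x, y in micro_batches:
        loss = loss_fn(model(x), y).mean() / num_rounds  # scale so gradients sum
        loss.backward()                                   # to the full-batch value
        running += loss.item()
    optimizer.step()
    return running
```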
with nvtx.annotate("update weights", color="blue"): @@ -551,8 +597,8 @@ def main(cfg: DictConfig) -> None: if cur_nimg >= lr_rampup: g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // cfg.training.hp.lr_decay_rate) current_lr = g["lr"] - if dist.rank == 0: - writer.add_scalar("learning_rate", current_lr, cur_nimg) + if dist.rank == 0 and cfg.logging.method == "mlflow": + mlflow.log_metric("learning_rate", current_lr, cur_nimg) handle_and_clip_gradients( model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold ) @@ -562,38 +608,49 @@ def main(cfg: DictConfig) -> None: cur_nimg += cfg.training.hp.total_batch_size done = cur_nimg >= cfg.training.hp.training_duration - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - if torch.cuda.is_available(): + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.4f}" ] fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" ] - torch.cuda.reset_peak_memory_stats() - logger0.info(" ".join(fields)) + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + if cfg.logging.method == "mlflow": + mlflow.log_metric("training_loss", average_loss, cur_nimg) + mlflow.log_metric( + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, + ) + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 with nvtx.annotate("validation", color="red"): # Validation @@ -690,9 +747,9 @@ def main(cfg: DictConfig) -> None: valid_loss_sum, op=torch.distributed.ReduceOp.SUM ) average_valid_loss = valid_loss_sum / dist.world_size - if dist.rank == 0: - writer.add_scalar( - 
"validation_loss", average_valid_loss, cur_nimg + if dist.rank == 0 and cfg.logging.method == "mlflow": + mlflow.log_metric( + "validation_loss", average_valid_loss, cur_nimg ) @@ -714,9 +771,101 @@ def main(cfg: DictConfig) -> None: epoch=cur_nimg, ) + # Visualize samples + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.visualization_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ) and visualize_checkpoints: + with nvtx.annotate("visualization", color="red"): + if dist.rank == 0: + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + times = visualization_dataset.time() + time_index = -1 + output_paths_list = [] + with torch.no_grad(): + for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( + iter(visualization_data_loader) + ): + time_index += 1 + logger0.info(f"starting index: {time_index}") + + # continue + if lead_time_label_viz: + lead_time_label_viz = lead_time_label_viz[0].to(dist.device).contiguous() + else: + lead_time_label_viz = None + + if use_apex_gn: + img_clean_viz = img_clean_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_viz = img_lr_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean_viz = ( + img_clean_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_viz = ( + img_lr_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + image_pred_viz, image_reg_viz = generator.generate(img_lr_viz, lead_time_label_viz) + if dist.rank == 0: + # write out data in a seperate thread so we don't hold up inferencing + output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") + output_paths_list.append(output_path) + if dist.rank==0 and not os.path.exists(output_path): + os.makedirs(output_path) + writer_threads.append( + writer_executor.submit( + save_results_as_torch, + output_path, + times[visualization_sampler[time_index]], + visualization_dataset, + image_pred_viz.cpu().numpy(), + img_clean_viz.cpu().numpy(), + img_lr_viz.cpu().numpy(), + image_reg_viz.cpu().numpy() if image_reg_viz is not None else None, + ) + ) + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() + if cfg.logging.method == "mlflow" and cfg.logging.log_images: + for output_path in output_paths_list: + mlflow.log_artifacts(output_path, + os.path.join( + 'visualization', + os.path.split(output_path)[-1])) + + + if dist.world_size > 1: + torch.distributed.barrier() # Done. 
logger0.info("Training Completed.") - if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/training/train_dummy.py b/src/hirad/training/train_dummy.py new file mode 100644 index 0000000..cc47ea9 --- /dev/null +++ b/src/hirad/training/train_dummy.py @@ -0,0 +1,14 @@ +import hydra +from omegaconf import DictConfig, OmegaConf +import json + + +@hydra.main(version_base=None, config_path="../conf", config_name="training") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + cfg = OmegaConf.to_container(cfg) + print(json.dumps(cfg, indent=2)) + # print(cfg.pretty()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index a346b16..28e6547 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -174,10 +174,15 @@ def save_checkpoint( Path(path).mkdir(parents=True, exist_ok=True) # == Saving model checkpoint == - if model: + if model is not None: if hasattr(model, "module"): # Strip out DDP layer model = model.module + + # Strip out optimization wrapper if exists + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + model = model._orig_mod + # Base name of model is meta.name unless pytorch model name = model.__class__.__name__ # Get full file path / name @@ -223,7 +228,7 @@ def save_checkpoint( def load_checkpoint( path: str, - model: torch.nn.Module, + model: torch.nn.Module = None, optimizer: Union[optimizer, None] = None, scheduler: Union[scheduler, None] = None, scaler: Union[scaler, None] = None, @@ -268,27 +273,33 @@ def load_checkpoint( ) return 0 - # == Loading model checkpoint == - if hasattr(model, "module"): - # Strip out DDP layer - model = model.module - # Base name of model is meta.name unless pytorch model - name = model.__class__.__name__ - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, - ) - if not Path(file_name).exists(): - checkpoint_logging.error( - f"Could not find valid model file {file_name}, skipping load" + if model is not None: + # == Loading model checkpoint == + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Strip out optimization wrapper if exists + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + model = model._orig_mod + checkpoint_logging.warning( + f"Model {model.__class__.__name__} is already compiled, consider loading first and then compiling." 
+ ) + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, ) - else: - # Load state dictionary - model.load_state_dict(torch.load(file_name, map_location=device)) + if not Path(file_name).exists(): + checkpoint_logging.warning( + f"Could not find valid model file {file_name}, skipping load" + ) + else: + # Load state dictionary + model.load_state_dict(torch.load(file_name, map_location=device)) - checkpoint_logging.success( - f"Loaded model state dictionary {file_name} to device {device}" - ) + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) # == Loading training checkpoint == checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") diff --git a/src/hirad/utils/env_info.py b/src/hirad/utils/env_info.py new file mode 100644 index 0000000..f1cfe64 --- /dev/null +++ b/src/hirad/utils/env_info.py @@ -0,0 +1,166 @@ +""" +Environment Introspection Module +================================ + +This module provides functionality to introspect the Python environment, listing all non-standard +library modules along with their versions and Git information if they are part of a Git repository. +It helps in understanding the environment setup by detailing the modules in use, their versions, and +relevant Git metadata. +""" + +import platform +import os +import sys +from types import ModuleType +from typing import Any, Optional +from git import Repo, InvalidGitRepositoryError + + +def get_module_version(module: ModuleType) -> Optional[str]: + """ + Retrieve the version of a module if available. + + This function attempts to get the version of a module by accessing its ``__version__`` + attribute. It checks if the version is a string to ensure correctness. + + :param module: The module whose version is to be retrieved. + :type module: ModuleType + :return: The version string if available and valid, otherwise None. + :rtype: Optional[str] + """ + version = getattr(module, "__version__", None) + if isinstance(version, str): + return version + return None + + +def get_git_info(path: str) -> Optional[dict[str, Any]]: + """ + Collect basic Git metadata for a given repository path. + + This function checks if the given path is part of a Git repository and collects metadata such as + the commit SHA, modified files, untracked files, and remote URLs. + + :param path: The path to check for Git repository metadata. + :type path: str + :return: A dictionary containing Git metadata if the path is a Git repository, otherwise None. + :rtype: Optional[Dict[str, Any]] + """ + try: + repo = Repo(path, search_parent_directories=True) + diff = "" + for diff_item in repo.index.diff(None, create_patch=True): + a_path = diff_item.a_blob.abspath if diff_item.a_blob else "" + b_path = diff_item.b_blob.abspath if diff_item.b_blob else "" + diff_content = diff_item.diff + if isinstance(diff_content, bytes): + diff_content = diff_content.decode("utf-8") + elif diff_content is None: + diff_content = "" + diff += f"--- a{a_path}\n+++ b{b_path}\n{diff_content}\n\n" + git_info = { + "sha1": repo.head.commit.hexsha, + "diff": diff, + "untracked_files": sorted(repo.untracked_files), + "remotes": [r.url for r in repo.remotes], + } + return git_info + except InvalidGitRepositoryError: + return None + + +def get_module_git_info(module: ModuleType) -> Optional[dict[str, Any]]: + """ + Get Git information for a module if its directory is a Git repository. 
+ + This function determines the directory of the module and checks if it is part of a Git + repository. If it is, it collects and returns the Git metadata. + + :param module: The module to check for Git information. + :type module: ModuleType + :return: A dictionary containing Git metadata if the module is in a Git repository, otherwise + None. + :rtype: Optional[Dict[str, Any]] + """ + module_path = getattr(module, "__file__", None) + if module_path is None or not os.path.isabs(module_path): + return None + module_dir = os.path.dirname(module_path) + return get_git_info(module_dir) + + +def get_env_info(flatten: bool = True, exclude_prefixes: list[str] = None) -> tuple[dict[str, dict[str, Any]], str]: + """ + List all non-standard library modules with their versions and Git information if available. + + This function iterates over all loaded modules in the Python environment, filtering out + built-in modules. It collects version and Git information for each remaining module. + + :return: A tuple containing two elements: + - A dictionary mapping module names to their metadata + - A string containing concatenated Git diffs from all modules + :rtype: tuple[dict[str, dict[str, Any]], str] + :rtype: Dict[str, Dict[str, Any]] + """ + env_info = {} + diffs: list[str] = [] + + exclude_prefixes = exclude_prefixes or [] + + for name, module in sys.modules.copy().items(): + if name in sys.builtin_module_names or name.endswith(('.version','._version')): + continue + + if any(name == prefix or name.startswith(prefix + ".") for prefix in exclude_prefixes): + continue + + version = get_module_version(module) + git_info = get_module_git_info(module) + if version is None and git_info is None: + continue + + module_info: dict[str, Any] = {"version": version} + if git_info is not None: + module_diff = git_info.pop("diff") + if module_diff and module_diff not in diffs: + diffs.append(module_diff) + module_info["git"] = git_info + + env_info[name] = module_info + + env_info["python"] = {"version": platform.python_version()} + + diffs_str = "\n".join(diffs) + + if flatten: + return flatten_dict(env_info), diffs_str + + return env_info, diffs_str + + +def flatten_dict( + d: dict[str, Any], parent_key: str = "", sep: str = "." +) -> dict[str, Any]: + """ + Flatten a nested dictionary. + + This function recursively traverses a nested dictionary and flattens it into a single-level + dictionary with keys formed by concatenating the nested keys using a separator. + + :param d: The dictionary to flatten. + :type d: Dict[str, Any] + :param parent_key: The base key to use for concatenation. + :type parent_key: str + :param sep: The separator to use for concatenating keys. + :type sep: str + :return: A flattened dictionary. 
+ :rtype: Dict[str, Any] + """ + items = {} + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.update(flatten_dict(v, new_key, sep=sep)) + else: + items[new_key] = v + return items diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py index 347457c..6534782 100644 --- a/src/hirad/utils/function_utils.py +++ b/src/hirad/utils/function_utils.py @@ -17,19 +17,8 @@ """Miscellaneous utility classes and functions.""" -import contextlib -import ctypes import datetime -import fnmatch -import importlib -import inspect -import os -import re -import shutil -import sys -import types -import warnings -from typing import Any, Iterator, List, Tuple, Union +from typing import Iterator import cftime import numpy as np @@ -38,25 +27,6 @@ # ruff: noqa: E722 PERF203 S110 E713 S324 -class EasyDict(dict): # pragma: no cover - """ - Convenience class that behaves like a dict but allows access with the attribute - syntax. - """ - - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError: - raise AttributeError(name) - - def __setattr__(self, name: str, value: Any) -> None: - self[name] = value - - def __delattr__(self, name: str) -> None: - del self[name] - - class StackedRandomGenerator: # pragma: no cover """ Wrapper for torch.Generator that allows specifying a different random seed @@ -96,32 +66,8 @@ def randint(self, *args, size, **kwargs): ) -def parse_int_list(s): # pragma: no cover - """ - Parse a comma separated list of numbers or ranges and return a list of ints. - Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] - """ - if isinstance(s, list): - return s - ranges = [] - range_re = re.compile(r"^(\d+)-(\d+)$") - for p in s.split(","): - m = range_re.match(p) - if m: - ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1)) - else: - ranges.append(int(p)) - return ranges - - # Small util functions # ------------------------------------------------------------------------------------- -def convert_datetime_to_cftime( - time: datetime.datetime, cls=cftime.DatetimeGregorian -) -> cftime.DatetimeGregorian: - """Convert a Python datetime object to a cftime DatetimeGregorian object.""" - return cls(time.year, time.month, time.day, time.hour, time.minute, time.second) - def time_range( start_time: datetime.datetime, @@ -135,417 +81,30 @@ def time_range( yield t t += step +def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): + """Generates a list of times within a given range. 
-def format_time(seconds: Union[int, float]) -> str: # pragma: no cover - """Convert the seconds to human readable string with days, hours, minutes and seconds.""" - s = int(np.rint(seconds)) - - if s < 60: - return "{0}s".format(s) - elif s < 60 * 60: - return "{0}m {1:02}s".format(s // 60, s % 60) - elif s < 24 * 60 * 60: - return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) - else: - return "{0}d {1:02}h {2:02}m".format( - s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 - ) - - -def format_time_brief(seconds: Union[int, float]) -> str: # pragma: no cover - """Convert the seconds to human readable string with days, hours, minutes and seconds.""" - s = int(np.rint(seconds)) - - if s < 60: - return "{0}s".format(s) - elif s < 60 * 60: - return "{0}m {1:02}s".format(s // 60, s % 60) - elif s < 24 * 60 * 60: - return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) - else: - return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) - - -def tuple_product(t: Tuple) -> Any: # pragma: no cover - """Calculate the product of the tuple elements.""" - result = 1 - - for v in t: - result *= v - - return result - - -_str_to_ctype = { - "uint8": ctypes.c_ubyte, - "uint16": ctypes.c_uint16, - "uint32": ctypes.c_uint32, - "uint64": ctypes.c_uint64, - "int8": ctypes.c_byte, - "int16": ctypes.c_int16, - "int32": ctypes.c_int32, - "int64": ctypes.c_int64, - "float32": ctypes.c_float, - "float64": ctypes.c_double, -} - - -def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: # pragma: no cover - """ - Given a type name string (or an object having a __name__ attribute), return - matching Numpy and ctypes types that have the same size in bytes. - """ - type_str = None - - if isinstance(type_obj, str): - type_str = type_obj - elif hasattr(type_obj, "__name__"): - type_str = type_obj.__name__ - elif hasattr(type_obj, "name"): - type_str = type_obj.name - else: - raise RuntimeError("Cannot infer type name from input") - - if type_str not in _str_to_ctype.keys(): - raise ValueError("Unknown type name: " + type_str) - - my_dtype = np.dtype(type_str) - my_ctype = _str_to_ctype[type_str] - - if my_dtype.itemsize != ctypes.sizeof(my_ctype): - raise ValueError( - "Numpy and ctypes types for '{}' have different sizes!".format(type_str) - ) - - return my_dtype, my_ctype - - -# Functionality to import modules/objects by name, and call functions by name -# ------------------------------------------------------------------------------------- - - -def get_module_from_obj_name( - obj_name: str, -) -> Tuple[types.ModuleType, str]: # pragma: no cover - """ - Searches for the underlying module behind the name to some python object. - Returns the module and the object name (original name with module part removed). - """ - - # allow convenience shorthands, substitute them by full names - obj_name = re.sub("^np.", "numpy.", obj_name) - obj_name = re.sub("^tf.", "tensorflow.", obj_name) - - # list alternatives for (module_name, local_obj_name) - parts = obj_name.split(".") - name_pairs = [ - (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) - ] - - # try each alternative in turn - for module_name, local_obj_name in name_pairs: - try: - module = importlib.import_module(module_name) # may raise ImportError - get_obj_from_module(module, local_obj_name) # may raise AttributeError - return module, local_obj_name - except: - pass - - # maybe some of the modules themselves contain errors? 
- for module_name, _local_obj_name in name_pairs: - try: - importlib.import_module(module_name) # may raise ImportError - except ImportError: - if not str(sys.exc_info()[1]).startswith( - "No module named '" + module_name + "'" - ): - raise - - # maybe the requested attribute is missing? - for module_name, local_obj_name in name_pairs: - try: - module = importlib.import_module(module_name) # may raise ImportError - get_obj_from_module(module, local_obj_name) # may raise AttributeError - except ImportError: - pass - - # we are out of luck, but we have no idea why - raise ImportError(obj_name) - - -def get_obj_from_module( - module: types.ModuleType, obj_name: str -) -> Any: # pragma: no cover - """ - Traverses the object name and returns the last (rightmost) python object. - """ - if obj_name == "": - return module - obj = module - for part in obj_name.split("."): - obj = getattr(obj, part) - return obj - - -def get_obj_by_name(name: str) -> Any: # pragma: no cover - """ - Finds the python object with the given name. - """ - module, obj_name = get_module_from_obj_name(name) - return get_obj_from_module(module, obj_name) - - -def call_func_by_name( - *args, func_name: str = None, **kwargs -) -> Any: # pragma: no cover - """ - Finds the python object with the given name and calls it as a function. - """ - if func_name is None: - raise ValueError("func_name must be specified") - func_obj = get_obj_by_name(func_name) - if not callable(func_obj): - raise ValueError(func_name + " is not callable") - return func_obj(*args, **kwargs) - - -def construct_class_by_name( - *args, class_name: str = None, **kwargs -) -> Any: # pragma: no cover - """ - Finds the python class with the given name and constructs it with the given - arguments. - """ - return call_func_by_name(*args, func_name=class_name, **kwargs) - - -def get_module_dir_by_obj_name(obj_name: str) -> str: # pragma: no cover - """ - Get the directory path of the module containing the given object name. - """ - module, _ = get_module_from_obj_name(obj_name) - return os.path.dirname(inspect.getfile(module)) - - -def is_top_level_function(obj: Any) -> bool: # pragma: no cover - """ - Determine whether the given object is a top-level function, i.e., defined at module - scope using 'def'. - """ - return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ - - -def get_top_level_function_name(obj: Any) -> str: # pragma: no cover - """ - Return the fully-qualified name of a top-level function. - """ - if not is_top_level_function(obj): - raise ValueError("Object is not a top-level function") - module = obj.__module__ - if module == "__main__": - module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] - return module + "." + obj.__name__ - - -# File system helpers -# ------------------------------------------------------------------------------------------ - - -def list_dir_recursively_with_ignore( - dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False -) -> List[Tuple[str, str]]: # pragma: no cover - """ - List all files recursively in a given directory while ignoring given file and - directory names. Returns list of tuples containing both absolute and relative paths. 
- """ - if not os.path.isdir(dir_path): - raise RuntimeError(f"Directory does not exist: {dir_path}") - base_name = os.path.basename(os.path.normpath(dir_path)) - - if ignores is None: - ignores = [] - - result = [] - - for root, dirs, files in os.walk(dir_path, topdown=True): - for ignore_ in ignores: - dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] - - # dirs need to be edited in-place - for d in dirs_to_remove: - dirs.remove(d) - - files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] - - absolute_paths = [os.path.join(root, f) for f in files] - relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] - - if add_base_to_relative: - relative_paths = [os.path.join(base_name, p) for p in relative_paths] - - if len(absolute_paths) != len(relative_paths): - raise ValueError("Number of absolute and relative paths do not match") - result += zip(absolute_paths, relative_paths) + Args: + times_range: A list containing start time, end time, and optional interval (hours). + time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). - return result - - -def copy_files_and_create_dirs( - files: List[Tuple[str, str]] -) -> None: # pragma: no cover - """ - Takes in a list of tuples of (src, dst) paths and copies files. - Will create all necessary directories. + Returns: + A list of times within the specified range. """ - for file in files: - target_dir_name = os.path.dirname(file[1]) - - # will create all intermediate-level directories - if not os.path.exists(target_dir_name): - os.makedirs(target_dir_name) - - shutil.copyfile(file[0], file[1]) - -# ---------------------------------------------------------------------------- -# Cached construction of constant tensors. Avoids CPU=>GPU copy when the -# same constant is used multiple times. - -_constant_cache = dict() - - -def constant( - value, shape=None, dtype=None, device=None, memory_format=None -): # pragma: no cover - """Cached construction of constant tensors""" - value = np.asarray(value) - if shape is not None: - shape = tuple(shape) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") - if memory_format is None: - memory_format = torch.contiguous_format - - key = ( - value.shape, - value.dtype, - value.tobytes(), - shape, - dtype, - device, - memory_format, + start_time = datetime.datetime.strptime(times_range[0], time_format) + end_time = datetime.datetime.strptime(times_range[1], time_format) + interval = ( + datetime.timedelta(hours=times_range[2]) + if len(times_range) > 2 + else datetime.timedelta(hours=1) ) - tensor = _constant_cache.get(key, None) - if tensor is None: - tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) - if shape is not None: - tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) - tensor = tensor.contiguous(memory_format=memory_format) - _constant_cache[key] = tensor - return tensor - - -# ---------------------------------------------------------------------------- -# Replace NaN/Inf with specified numerical values. 
- -try: - nan_to_num = torch.nan_to_num # 1.8.0a0 -except AttributeError: - - def nan_to_num( - input, nan=0.0, posinf=None, neginf=None, *, out=None - ): # pylint: disable=redefined-builtin # pragma: no cover - """Replace NaN/Inf with specified numerical values""" - if not isinstance(input, torch.Tensor): - raise TypeError("input should be a Tensor") - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - if nan != 0: - raise ValueError("nan_to_num only supports nan=0") - return torch.clamp( - input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out - ) - - -# ---------------------------------------------------------------------------- -# Symbolic assert. - -try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access -except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 - -# ---------------------------------------------------------------------------- -# Context manager to temporarily suppress known warnings in torch.jit.trace(). -# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 - -@contextlib.contextmanager -def suppress_tracer_warnings(): # pragma: no cover - """ - Context manager to temporarily suppress known warnings in torch.jit.trace(). - Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 - """ - flt = ("ignore", None, torch.jit.TracerWarning, None, 0) - warnings.filters.insert(0, flt) - yield - warnings.filters.remove(flt) - - -# ---------------------------------------------------------------------------- -# Assert that the shape of a tensor matches the given list of integers. -# None indicates that the size of a dimension is allowed to vary. -# Performs symbolic assertion when used in torch.jit.trace(). - - -def assert_shape(tensor, ref_shape): # pragma: no cover - """ - Assert that the shape of a tensor matches the given list of integers. - None indicates that the size of a dimension is allowed to vary. - Performs symbolic assertion when used in torch.jit.trace(). - """ - if tensor.ndim != len(ref_shape): - raise AssertionError( - f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" - ) - for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): - if ref_size is None: - pass - elif isinstance(ref_size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert( - torch.equal(torch.as_tensor(size), ref_size), - f"Wrong size for dimension {idx}", - ) - elif isinstance(size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert( - torch.equal(size, torch.as_tensor(ref_size)), - f"Wrong size for dimension {idx}: expected {ref_size}", - ) - elif size != ref_size: - raise AssertionError( - f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" - ) - - -# ---------------------------------------------------------------------------- -# Function decorator that calls torch.autograd.profiler.record_function(). 
- - -def profiled_function(fn): # pragma: no cover - """Function decorator that calls torch.autograd.profiler.record_function().""" - - def decorator(*args, **kwargs): - with torch.autograd.profiler.record_function(fn.__name__): - return fn(*args, **kwargs) - - decorator.__name__ = fn.__name__ - return decorator + times = [ + t.strftime(time_format) + for t in time_range(start_time, end_time, interval, inclusive=True) + ] + return times # ---------------------------------------------------------------------------- @@ -574,6 +133,8 @@ class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover window_size : float, default=0.5 Fraction of dataset to use as window for shuffling. Must be between 0 and 1. A larger window means more thorough shuffling but slower iteration. + start_idx : int, default=0 + The initial index to use for the sampler. This is used for resuming training. """ def __init__( @@ -584,6 +145,7 @@ def __init__( shuffle: bool = True, seed: int = 0, window_size: float = 0.5, + start_idx: int = 0, ): if not len(dataset) > 0: raise ValueError("Dataset must contain at least one item") @@ -600,6 +162,7 @@ def __init__( self.shuffle = shuffle self.seed = seed self.window_size = window_size + self.start_idx = start_idx def __iter__(self) -> Iterator[int]: order = np.arange(len(self.dataset)) @@ -610,7 +173,7 @@ def __iter__(self) -> Iterator[int]: rnd.shuffle(order) window = int(np.rint(order.size * self.window_size)) - idx = 0 + idx = self.start_idx while True: i = idx % order.size if idx % self.num_replicas == self.rank: @@ -619,180 +182,3 @@ def __iter__(self) -> Iterator[int]: j = (i - rnd.randint(window)) % order.size order[i], order[j] = order[j], order[i] idx += 1 - - -# ---------------------------------------------------------------------------- -# Utilities for operating with torch.nn.Module parameters and buffers. - - -def params_and_buffers(module): # pragma: no cover - """Get parameters and buffers of a nn.Module""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - return list(module.parameters()) + list(module.buffers()) - - -def named_params_and_buffers(module): # pragma: no cover - """Get named parameters and buffers of a nn.Module""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - return list(module.named_parameters()) + list(module.named_buffers()) - - -@torch.no_grad() -def copy_params_and_buffers( - src_module, dst_module, require_all=False -): # pragma: no cover - """Copy parameters and buffers from a source module to target module""" - if not isinstance(src_module, torch.nn.Module): - raise TypeError("src_module must be a torch.nn.Module instance") - if not isinstance(dst_module, torch.nn.Module): - raise TypeError("dst_module must be a torch.nn.Module instance") - src_tensors = dict(named_params_and_buffers(src_module)) - for name, tensor in named_params_and_buffers(dst_module): - if not ((name in src_tensors) or (not require_all)): - raise ValueError(f"Missing source tensor for {name}") - if name in src_tensors: - tensor.copy_(src_tensors[name]) - - -# ---------------------------------------------------------------------------- -# Context manager for easily enabling/disabling DistributedDataParallel -# synchronization. - - -@contextlib.contextmanager -def ddp_sync(module, sync): # pragma: no cover - """ - Context manager for easily enabling/disabling DistributedDataParallel - synchronization. 
- """ - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): - yield - else: - with module.no_sync(): - yield - - -# ---------------------------------------------------------------------------- -# Check DistributedDataParallel consistency across processes. - - -def check_ddp_consistency(module, ignore_regex=None): # pragma: no cover - """Check DistributedDataParallel consistency across processes.""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - for name, tensor in named_params_and_buffers(module): - fullname = type(module).__name__ + "." + name - if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): - continue - tensor = tensor.detach() - if tensor.is_floating_point(): - tensor = nan_to_num(tensor) - other = tensor.clone() - torch.distributed.broadcast(tensor=other, src=0) - if not (tensor == other).all(): - raise RuntimeError(f"DDP consistency check failed for {fullname}") - - -# ---------------------------------------------------------------------------- -# Print summary table of module hierarchy. - - -def print_module_summary( - module, inputs, max_nesting=3, skip_redundant=True -): # pragma: no cover - """Print summary table of module hierarchy.""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - if isinstance(module, torch.jit.ScriptModule): - raise TypeError("module must not be a torch.jit.ScriptModule instance") - if not isinstance(inputs, (tuple, list)): - raise TypeError("inputs must be a tuple or list") - - # Register hooks. - entries = [] - nesting = [0] - - def pre_hook(_mod, _inputs): - nesting[0] += 1 - - def post_hook(mod, _inputs, outputs): - nesting[0] -= 1 - if nesting[0] <= max_nesting: - outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] - outputs = [t for t in outputs if isinstance(t, torch.Tensor)] - entries.append(EasyDict(mod=mod, outputs=outputs)) - - hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] - hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] - - # Run module. - outputs = module(*inputs) - for hook in hooks: - hook.remove() - - # Identify unique outputs, parameters, and buffers. - tensors_seen = set() - for e in entries: - e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] - e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] - e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] - tensors_seen |= { - id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs - } - - # Filter out redundant entries. - if skip_redundant: - entries = [ - e - for e in entries - if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) - ] - - # Construct table. 
- rows = [ - [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] - ] - rows += [["---"] * len(rows[0])] - param_total = 0 - buffer_total = 0 - submodule_names = {mod: name for name, mod in module.named_modules()} - for e in entries: - name = "" if e.mod is module else submodule_names[e.mod] - param_size = sum(t.numel() for t in e.unique_params) - buffer_size = sum(t.numel() for t in e.unique_buffers) - output_shapes = [str(list(t.shape)) for t in e.outputs] - output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] - rows += [ - [ - name + (":0" if len(e.outputs) >= 2 else ""), - str(param_size) if param_size else "-", - str(buffer_size) if buffer_size else "-", - (output_shapes + ["-"])[0], - (output_dtypes + ["-"])[0], - ] - ] - for idx in range(1, len(e.outputs)): - rows += [ - [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] - ] - param_total += param_size - buffer_total += buffer_size - rows += [["---"] * len(rows[0])] - rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] - - # Print table. - widths = [max(len(cell) for cell in column) for column in zip(*rows)] - for row in rows: - print( - " ".join( - cell + " " * (width - len(cell)) for cell, width in zip(row, widths) - ) - ) - return outputs - - -# ---------------------------------------------------------------------------- diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 8665536..0cdfd98 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -14,18 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime from typing import Optional +import os +import logging -import cftime import nvtx +import numpy as np import torch import tqdm +from matplotlib import pyplot as plt +import cartopy.crs as ccrs -from .function_utils import StackedRandomGenerator, time_range - -from .stochastic_sampler import stochastic_sampler -from .deterministic_sampler import deterministic_sampler +from .function_utils import StackedRandomGenerator +from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra, crps ############################################################################ # CorrDiff Generation Utilities # @@ -84,7 +85,7 @@ def regression_step( if lead_time_label is not None: x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label) else: - x = net(x=x_hat[0:1], img_lr=img_lr) + x = net(x=x_hat[0:1], img_lr=img_lr, force_fp32=False) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -201,119 +202,165 @@ def diffusion_step( return torch.cat(all_images) -def generate(): - pass - ############################################################################ -# CorrDiff writer utilities # +# Visualization Utilities # ############################################################################ -class NetCDFWriter: - """NetCDF Writer""" - - def __init__( - self, f, lat, lon, input_channels, output_channels, has_lead_time=False - ): - self._f = f - self.has_lead_time = has_lead_time - # create unlimited dimensions - f.createDimension("time") - f.createDimension("ensemble") - - if lat.shape != lon.shape: - raise ValueError("lat and lon must have the same shape") - ny, nx = lat.shape - - # create lat/lon grid - f.createDimension("x", nx) - f.createDimension("y", ny) - - v = f.createVariable("lat", "f", dimensions=("y", "x")) - # NOTE rethink this for 
datasets whose samples don't have constant lat-lon. - v[:] = lat - v.standard_name = "latitude" - v.units = "degrees_north" - - v = f.createVariable("lon", "f", dimensions=("y", "x")) - v[:] = lon - v.standard_name = "longitude" - v.units = "degrees_east" - - # create time dimension - if has_lead_time: - v = f.createVariable("time", "str", ("time")) +def save_results_as_torch(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + os.makedirs(output_path, exist_ok=True) + target = np.flip(dataset.denormalize_output(image_hr)[0,::].squeeze(),1) + prediction_ensemble = np.flip(dataset.denormalize_output(image_pred).squeeze(),-2) + baseline = np.flip(dataset.denormalize_input(image_lr)[0,::].squeeze(),1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred)[0,::].squeeze(),1) + torch.save(mean_pred, os.path.join(output_path, f'{time_step}-regression-prediction')) + torch.save(target, os.path.join(output_path, f'{time_step}-target')) + torch.save(prediction_ensemble, os.path.join(output_path, f'{time_step}-predictions')) + torch.save(baseline, os.path.join(output_path, f'{time_step}-baseline')) + +@DeprecationWarning +def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + + os.makedirs(output_path, exist_ok=True) + + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred),-2) #.reshape(len(output_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + + # Plot power spectra + freqs = {} + power = {} + for idx, channel in enumerate(output_channels): + channel_dir = channel.name + "_" + channel.level if channel.level else channel.name + output_path_channel = os.path.join(output_path, channel_dir) + if not os.path.exists(output_path_channel): + os.makedirs(output_path_channel) + input_channel_idx = input_channels.index(channel) + + if channel.name=="tp": + target[idx,::] = transform_channel(target[idx,:,:]) + prediction[:,idx,::] = transform_channel(prediction[:,idx,:,:]) + baseline[input_channel_idx,:,:] = transform_channel(baseline[input_channel_idx,::]) + if mean_pred is not None: + mean_pred[idx,::] = transform_channel(mean_pred[idx,::]) + + if mean_pred is not None: + vmin, vmax = calculate_bounds(target[idx,:,:], + prediction[:,idx,:,:], + baseline[input_channel_idx,:,:], + mean_pred[idx,:,:]) else: - v = f.createVariable("time", "i8", ("time")) - v.calendar = "standard" - v.units = "hours since 1990-01-01 00:00:00" - - self.truth_group = f.createGroup("truth") - self.prediction_group = f.createGroup("prediction") - self.input_group = f.createGroup("input") - - for variable in output_channels: - name = variable.name + variable.level - self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x")) - self.prediction_group.createVariable( - name, "f", dimensions=("ensemble", "time", "y", "x") - ) - - # setup input data in netCDF - - for variable in input_channels: - name = variable.name + variable.level - self.input_group.createVariable(name, "f", dimensions=("time", "y", "x")) - - def write_input(self, 
channel_name, time_index, val): - """Write input data to NetCDF file.""" - self.input_group[channel_name][time_index] = val - - def write_truth(self, channel_name, time_index, val): - """Write ground truth data to NetCDF file.""" - self.truth_group[channel_name][time_index] = val - - def write_prediction(self, channel_name, time_index, ensemble_index, val): - """Write prediction data to NetCDF file.""" - self.prediction_group[channel_name][ensemble_index, time_index] = val - - def write_time(self, time_index, time): - """Write time information to NetCDF file.""" - if self.has_lead_time: - self._f["time"][time_index] = time + vmin, vmax = calculate_bounds(target[idx,:,:], + prediction[:,idx,:,:], + baseline[input_channel_idx,:,:]) + _plot_projection(longitudes, latitudes, target[idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-target.jpg'), + vmin=vmin, vmax=vmax) + if prediction.shape[0] > 1: + for member_idx in range(prediction.shape[0]): + _plot_projection(longitudes, latitudes, prediction[member_idx,idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction_{member_idx}.jpg'), + vmin=vmin, vmax=vmax) else: - time_v = self._f["time"] - self._f["time"][time_index] = cftime.date2num( - time, time_v.units, time_v.calendar - ) - - -############################################################################ -# CorrDiff time utilities # -############################################################################ - - -def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): - """Generates a list of times within a given range. - - Args: - times_range: A list containing start time, end time, and optional interval (hours). - time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). - - Returns: - A list of times within the specified range. 
- """ - - start_time = datetime.datetime.strptime(times_range[0], time_format) - end_time = datetime.datetime.strptime(times_range[1], time_format) - interval = ( - datetime.timedelta(hours=times_range[2]) - if len(times_range) > 2 - else datetime.timedelta(hours=1) - ) - - times = [ - t.strftime(time_format) - for t in time_range(start_time, end_time, interval, inclusive=True) - ] - return times + _plot_projection(longitudes, latitudes, + prediction[0,idx,:,:], os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction.jpg'), + vmin=vmin, vmax=vmax) + _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-input.jpg'), + vmin=vmin, vmax=vmax) + if mean_pred is not None: + _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-mean_prediction.jpg'), + vmin=vmin, vmax=vmax) + + _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-baseline-error.jpg')) + if prediction.shape[0] > 1: + for member_idx in range(prediction.shape[0]): + _, prediction_errors = compute_mae(prediction[member_idx,idx,:,:], target[idx,:,:]) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction_{member_idx}-error.jpg')) + else: + _, prediction_errors = compute_mae(prediction[0,idx,:,:], target[idx,:,:]) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction-error.jpg')) + if mean_pred is not None: + _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) + plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) + + b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + p_freq, p_power = average_power_spectrum(prediction[-1,idx,:,:].squeeze(), 2.0) + freqs['prediction'] = p_freq + power['prediction'] = p_power + if mean_pred is not None: + mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0) + freqs['mean_prediction'] = mp_freq + power['mean_prediction'] = mp_power + plot_power_spectra(freqs, power, channel.name, os.path.join(output_path_channel, f'{time_step}-{channel.name}-spectra.jpg')) + +def transform_channel(channel_array, channel_name="tp"): + # precip_array = np.clip(precip_array, 0, None) + # precip_array = np.where(precip_array == 0, 1e-6, precip_array) + # epsilon = 1e-2 + # precip_array = precip_array + epsilon + # precip_array = np.log10(precip_array) + # log_min, log_max = precip_array.min(), precip_array.max() + # precip_array = (precip_array-log_min)/(log_max-log_min) + if channel_name == "tp": + channel_array = np.clip(channel_array, 0, None) + channel_array = (np.power(channel_array,0.25)-1)/0.25 + elif channel_name == "2t": + channel_array = 
channel_array - 273.15 + # precip_array = np.sqrt(precip_array) + return channel_array + +@DeprecationWarning +def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + + """Plot observed or interpolated data in a scatter plot.""" + # TODO: Refactor this somehow, it's not really generalizing well across variables. + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, orientation="horizontal") + plt.savefig(filename) + plt.close('all') + +def calculate_bounds(*arrays: np.ndarray) -> tuple[float]: + """Calculate consistent bounds across all arrays""" + valid_arrays = [arr for arr in arrays if arr is not None] + if not valid_arrays: + return 0, 1 + + # hanndle if there are masked arrays with invalid values (e.g. NaNs) + all_values = [] + for arr in valid_arrays: + if hasattr(arr, 'compressed'): # Masked array + compressed = arr.compressed() + if len(compressed) > 0: + all_values.extend(compressed) + elif hasattr(arr, 'flatten'): # Regular numpy array + all_values.extend(arr.flatten()) + else: + all_values.append(arr) + + if not all_values: + return 0, 1 + + vmin = min(all_values) + vmax = max(all_values) + return vmin, vmax diff --git a/src/hirad/utils/patching.py b/src/hirad/utils/patching.py index 6f4bc4d..bd537af 100644 --- a/src/hirad/utils/patching.py +++ b/src/hirad/utils/patching.py @@ -591,14 +591,26 @@ def image_batching( ) # (padding_left,padding_right,padding_top,padding_bottom) input_padded = image_padding(input) patch_num = patch_num_x * patch_num_y + + # Cast to float for unfold + if input.dtype == torch.int32: + input_padded = input_padded.view(torch.float32) + elif input.dtype == torch.int64: + input_padded = input_padded.view(torch.float64) + x_unfold = torch.nn.functional.unfold( - input=input_padded.view(_cast_type(input_padded)), # Cast to float + input=input_padded, kernel_size=(patch_shape_y, patch_shape_x), stride=( patch_shape_y - overlap_pix - boundary_pix, patch_shape_x - overlap_pix - boundary_pix, ), - ).to(input_padded.dtype) + ) + + # Cast back to original dtype + if input.dtype in [torch.int32, torch.int64]: + x_unfold = x_unfold.view(input.dtype) + x_unfold = rearrange( x_unfold, "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", @@ -608,16 +620,7 @@ def image_batching( nb_p_w=patch_num_x, ) if input_interp is not None: - input_interp_repeated = rearrange( - torch.repeat_interleave( - input=input_interp, - repeats=patch_num, - dim=0, - output_size=x_unfold.shape[0], - ), - "(b p) c h w -> (p b) c h w", - p=patch_num, - ) + input_interp_repeated = input_interp.repeat(patch_num, 1, 1, 1) return torch.cat((x_unfold, input_interp_repeated), dim=1) else: return x_unfold @@ -722,6 +725,13 @@ def image_fuse( nb_p_h=patch_num_y, nb_p_w=patch_num_x, ) + + # Cast to float for fold + if input.dtype == torch.int32: + x = x.view(torch.float32) + elif input.dtype == torch.int64: + x = x.view(torch.float64) + # Stitch patches together (by summing over overlapping patches) x_folded = torch.nn.functional.fold( input=x, @@ -733,6 +743,10 @@ def image_fuse( ), ) + # Cast back to original dtype + if input.dtype in [torch.int32, torch.int64]: + x_folded = x_folded.view(input.dtype) + # Remove padding x_no_padding = x_folded[ ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x @@ -743,25 
+757,3 @@ def image_fuse( # Normalize by overlap count return x_no_padding / overlap_count_no_padding - - -def _cast_type(input: Tensor) -> torch.dtype: - """Return float type based on input tensor type. - - Parameters - ---------- - input : Tensor - Input tensor to determine float type from - - Returns - ------- - torch.dtype - Float type corresponding to input tensor type for int32/64, - otherwise returns original dtype - """ - if input.dtype == torch.int32: - return torch.float32 - elif input.dtype == torch.int64: - return torch.float64 - else: - return input.dtype diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index 218d6f1..fd6e8cc 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -16,9 +16,13 @@ import torch import numpy as np -from omegaconf import ListConfig import warnings +import mlflow +from omegaconf import DictConfig, OmegaConf +import os +from hirad.distributed import DistributedManager +from hirad.utils.env_info import get_env_info, flatten_dict def set_patch_shape(img_shape, patch_shape): img_shape_y, img_shape_x = img_shape @@ -100,11 +104,6 @@ def handle_and_clip_gradients(model, grad_clip_threshold=None): torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) -def parse_model_args(args): - """Convert ListConfig values in args to tuples.""" - return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} - - def is_time_for_periodic_task( cur_nimg, freq, done, batch_size, rank, rank_0_only=False ): @@ -115,3 +114,47 @@ def is_time_for_periodic_task( return True else: return cur_nimg % freq < batch_size + + +def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: + if dist.rank==0: + print("Started activating initial mlflow run") + if cfg.logging.uri is not None: + mlflow.set_tracking_uri(cfg.logging.uri) + mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) + run_id = None + if os.path.isfile('run_id.txt'): + with open('run_id.txt','r') as f: + run_id = f.read() + if dist.world_size<=4: + mlflow.system_metrics.set_system_metrics_node_id("node-0") + if run_id: + mlflow.start_run(run_id=run_id, log_system_metrics=False if dist.world_size>4 else True) + else: + mlflow.start_run(run_name=cfg.logging.run_name, log_system_metrics=False if dist.world_size>4 else True) + if run_id is None: + run = mlflow.active_run() + with open("run_id.txt", 'w') as f: + f.write(run.info.run_id) + # log environment info if run is not continuing from previous checkpoint + mlflow.log_params(flatten_dict(OmegaConf.to_object(cfg))) + python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) + mlflow.log_dict(python_environment, "environment.json") + if git_diff: + mlflow.log_text(git_diff, "git_diff.txt") + mlflow.log_dict(cfg, "config.json") + + if dist.world_size > 4: + torch.distributed.barrier() + + if (dist.rank!=0 and dist._local_rank==0) or (dist.rank==1 and dist.world_size>4): + print("Started actvating sub mlflow run.") + if cfg.logging.uri is not None: + mlflow.set_tracking_uri(cfg.logging.uri) + mlflow.system_metrics.set_system_metrics_node_id(f"node-{(dist.rank//4)}" + if dist.rank!=1 + else "node-0") + mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) + with open("run_id.txt", 'r') as f: + run_id = f.read() + mlflow.start_run(run_id=run_id, log_system_metrics=True)
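
The checkpoint changes above unwrap `torch._dynamo.eval_frame.OptimizedModule` and warn when a compiled model is handed to `load_checkpoint`; the intended order is to load weights into the plain module first and compile afterwards. A minimal sketch of that order follows — the `torch.nn.Linear` stand-in and the checkpoint path are illustrative, and any keyword other than `path`, `model`, and `device` is an assumption, not taken from the diff.

```python
import torch
from hirad.utils.checkpoint import load_checkpoint

model = torch.nn.Linear(4, 4)  # stand-in for the regression/diffusion net
# Restore weights into the uncompiled module so the file name derived from
# the model class matches what was written at save time.
load_checkpoint(path="outputs/regression/checkpoints", model=model, device="cuda")
# Compile only after loading; save_checkpoint will still strip the
# OptimizedModule wrapper if a compiled model is passed to it later.
model = torch.compile(model)
```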
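
`flatten_dict` in `env_info.py` is what turns the nested Hydra config and the environment snapshot from `get_env_info` into dot-separated keys before they are logged to MLflow. A small illustration of the flattening, with made-up version values:

```python
from hirad.utils.env_info import flatten_dict

nested = {
    "torch": {"version": "2.6.0", "git": {"sha1": "abc123"}},
    "python": {"version": "3.11.7"},
}
flat = flatten_dict(nested)
# {'torch.version': '2.6.0', 'torch.git.sha1': 'abc123', 'python.version': '3.11.7'}
print(flat)
```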
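
`get_time_from_range`, now kept in `function_utils.py`, expands a `[start, end, interval_hours]` triple into explicit timestamps, with the interval defaulting to 1 hour when the third element is omitted. A quick sketch with illustrative dates:

```python
from hirad.utils.function_utils import get_time_from_range

times = get_time_from_range(
    ["2016-01-01T00:00:00", "2016-01-01T18:00:00", 6],
    time_format="%Y-%m-%dT%H:%M:%S",
)
# ['2016-01-01T00:00:00', '2016-01-01T06:00:00',
#  '2016-01-01T12:00:00', '2016-01-01T18:00:00']  (end point included)
print(times)
```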
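
The new `start_idx` argument lets `InfiniteSampler` resume its deterministic index sequence from a checkpointed global sample count instead of restarting at zero. A hedged sketch: the toy dataset and `samples_seen` value are illustrative, and the `rank`/`num_replicas` constructor arguments are assumed to exist because the sampler's `__iter__` references `self.rank` and `self.num_replicas`.

```python
import torch
from hirad.utils.function_utils import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.randn(1024, 3))  # stand-in dataset
samples_seen = 512  # e.g. restored from a training checkpoint

sampler = InfiniteSampler(
    dataset=dataset,
    rank=0,              # assumed constructor args, mirroring the
    num_replicas=1,      # self.rank / self.num_replicas used in __iter__
    shuffle=True,
    seed=0,
    start_idx=samples_seen,  # global sample counter across all ranks
)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4)
```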
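
`calculate_bounds` computes one shared `(vmin, vmax)` across target, prediction and baseline fields so their plots use a common colour scale, ignoring `None` inputs and masked values. A small example with toy arrays:

```python
import numpy as np
from hirad.utils.inference_utils import calculate_bounds

target = np.array([[0.1, 0.4], [0.9, 0.2]])
prediction = np.ma.masked_invalid([[0.05, np.nan], [1.2, 0.3]])  # NaN is masked out

vmin, vmax = calculate_bounds(target, prediction, None)  # None entries are skipped
print(vmin, vmax)  # 0.05 1.2
```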