From a430e8c560fa0e9d267ed07a3dc0020589ccd818 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 10 Apr 2025 14:52:53 +0200 Subject: [PATCH 001/189] add basic example for regridding using scipy --- read_datasets.py | 159 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 read_datasets.py diff --git a/read_datasets.py b/read_datasets.py new file mode 100644 index 0000000..4ae11e4 --- /dev/null +++ b/read_datasets.py @@ -0,0 +1,159 @@ +from anemoi.datasets import open_dataset +import matplotlib.pyplot as plt +import cartopy.crs as ccrs +import numpy as np +from scipy.interpolate import griddata + +COSMO_PATH = '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr' # path in Balfrin +ERA_PATH = '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' # path in Balfrin + +# trim edge removes boundary +cosmo = open_dataset(COSMO_PATH, select="2t", trim_edge=20) # open cosmo, select only 2m-temperature +start_date = cosmo.metadata()['start_date'] # get start and end date of cosmo +end_date = cosmo.metadata()['end_date'] +era = open_dataset(ERA_PATH, select="2t", start=start_date, end=end_date) # load era5 2m-temperature in the time-range of cosmo + + +# get indeces of era5 data that is in the bounding rectangle of cosmo data - this is just for plotting +min_lat_cosmo = min(cosmo.latitudes) +max_lat_cosmo = max(cosmo.latitudes) +min_lon_cosmo = min(cosmo.longitudes) +max_lon_cosmo = max(cosmo.longitudes) +box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) +box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) +indeces = np.where(box_lon*box_lat) + + +#### Approach 1 ######################################################### +#### Scipy Interpolate ################################################## + +grid = np.column_stack((era.longitudes, era.latitudes)) # stack lon-lat columns of era5 points +values = np.array(era[0,0,0,:]) # get era grid 2m-temperature values on the first avaialble date-time + +interp_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes)) # stack lon-lat column of cosmo points + +values_int = griddata(grid,values,interp_grid,method='linear') # interpolate era5 to cosmo grid using scipy griddata linear + + +################ plotting ################################################ + +# plot era original +fig = plt.figure() +fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +p = ax.scatter(x=era.longitudes[indeces], y=era.latitudes[indeces], c=era[0, 0, 0, :][indeces]) +ax.coastlines() +ax.gridlines(draw_labels=True) +plt.colorbar(p, label="K", orientation="horizontal") +plt.savefig("temperature-2m-era.jpg") + +# plot cosmo original +fig = plt.figure() +fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=cosmo[0, 0, 0, :]) +ax.coastlines() +ax.gridlines(draw_labels=True) +plt.colorbar(p, label="K", orientation="horizontal") +plt.savefig("temperature-2m-cosmo.jpg") + +#plot inerpolated era5 +fig = plt.figure() +fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=values_int) +ax.coastlines() +ax.gridlines(draw_labels=True) +plt.colorbar(p, label="K", orientation="horizontal") +plt.savefig("temperature-2m-era-downscaled.jpg") + + + + + + + + + + + + + +# cosmo = xr.open_zarr(COSMO_PATH) +# print(cosmo.attrs['data_request']) +# era = xr.open_zarr(ERA_PATH) +# print(cosmo[0,0,0,0]) +# print(era) + + + + +# min_lat_cosmo = min(cosmo.latitudes) +# max_lat_cosmo = max(cosmo.latitudes) +# min_lon_cosmo = min(cosmo.longitudes) +# max_lon_cosmo = max(cosmo.longitudes) +# box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) +# box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) +# indeces = np.where(box_lon*box_lat) + +# ds = xr.tutorial.open_dataset( +# "air_temperature" +# ) # use xr.tutorial.load_dataset() for xarray Date: Thu, 10 Apr 2025 14:56:53 +0200 Subject: [PATCH 002/189] rename and delete comments --- read_datasets.py => interpolate_basic.py | 96 +----------------------- 1 file changed, 1 insertion(+), 95 deletions(-) rename read_datasets.py => interpolate_basic.py (50%) diff --git a/read_datasets.py b/interpolate_basic.py similarity index 50% rename from read_datasets.py rename to interpolate_basic.py index 4ae11e4..d17ab5a 100644 --- a/read_datasets.py +++ b/interpolate_basic.py @@ -62,98 +62,4 @@ ax.coastlines() ax.gridlines(draw_labels=True) plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-era-downscaled.jpg") - - - - - - - - - - - - - -# cosmo = xr.open_zarr(COSMO_PATH) -# print(cosmo.attrs['data_request']) -# era = xr.open_zarr(ERA_PATH) -# print(cosmo[0,0,0,0]) -# print(era) - - - - -# min_lat_cosmo = min(cosmo.latitudes) -# max_lat_cosmo = max(cosmo.latitudes) -# min_lon_cosmo = min(cosmo.longitudes) -# max_lon_cosmo = max(cosmo.longitudes) -# box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) -# box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) -# indeces = np.where(box_lon*box_lat) - -# ds = xr.tutorial.open_dataset( -# "air_temperature" -# ) # use xr.tutorial.load_dataset() for xarray Date: Thu, 10 Apr 2025 15:07:11 +0200 Subject: [PATCH 003/189] start a branch for transfering corrdiff from nvidia-modulus --- interpolate_basic.py | 65 -------------------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 interpolate_basic.py diff --git a/interpolate_basic.py b/interpolate_basic.py deleted file mode 100644 index d17ab5a..0000000 --- a/interpolate_basic.py +++ /dev/null @@ -1,65 +0,0 @@ -from anemoi.datasets import open_dataset -import matplotlib.pyplot as plt -import cartopy.crs as ccrs -import numpy as np -from scipy.interpolate import griddata - -COSMO_PATH = '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr' # path in Balfrin -ERA_PATH = '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' # path in Balfrin - -# trim edge removes boundary -cosmo = open_dataset(COSMO_PATH, select="2t", trim_edge=20) # open cosmo, select only 2m-temperature -start_date = cosmo.metadata()['start_date'] # get start and end date of cosmo -end_date = cosmo.metadata()['end_date'] -era = open_dataset(ERA_PATH, select="2t", start=start_date, end=end_date) # load era5 2m-temperature in the time-range of cosmo - - -# get indeces of era5 data that is in the bounding rectangle of cosmo data - this is just for plotting -min_lat_cosmo = min(cosmo.latitudes) -max_lat_cosmo = max(cosmo.latitudes) -min_lon_cosmo = min(cosmo.longitudes) -max_lon_cosmo = max(cosmo.longitudes) -box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) -box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) -indeces = np.where(box_lon*box_lat) - - -#### Approach 1 ######################################################### -#### Scipy Interpolate ################################################## - -grid = np.column_stack((era.longitudes, era.latitudes)) # stack lon-lat columns of era5 points -values = np.array(era[0,0,0,:]) # get era grid 2m-temperature values on the first avaialble date-time - -interp_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes)) # stack lon-lat column of cosmo points - -values_int = griddata(grid,values,interp_grid,method='linear') # interpolate era5 to cosmo grid using scipy griddata linear - - -################ plotting ################################################ - -# plot era original -fig = plt.figure() -fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -p = ax.scatter(x=era.longitudes[indeces], y=era.latitudes[indeces], c=era[0, 0, 0, :][indeces]) -ax.coastlines() -ax.gridlines(draw_labels=True) -plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-era.jpg") - -# plot cosmo original -fig = plt.figure() -fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=cosmo[0, 0, 0, :]) -ax.coastlines() -ax.gridlines(draw_labels=True) -plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-cosmo.jpg") - -#plot inerpolated era5 -fig = plt.figure() -fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=values_int) -ax.coastlines() -ax.gridlines(draw_labels=True) -plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-era-downscaled.jpg") \ No newline at end of file From 99d55dffe2893523a9a0c795a68604c20e81ef3e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 11 Apr 2025 13:43:37 +0200 Subject: [PATCH 004/189] add model and loss scripts --- src/losses/loss.py | 914 +++++++++++++++++++++ src/models/layers.py | 567 ++++++++++++++ src/models/preconditioning copy.py | 1176 ++++++++++++++++++++++++++++ src/models/preconditioning.py | 1176 ++++++++++++++++++++++++++++ src/models/song_unet.py | 906 +++++++++++++++++++++ src/models/unet.py | 267 +++++++ src/models/utils.py | 66 ++ 7 files changed, 5072 insertions(+) create mode 100644 src/losses/loss.py create mode 100644 src/models/layers.py create mode 100644 src/models/preconditioning copy.py create mode 100644 src/models/preconditioning.py create mode 100644 src/models/song_unet.py create mode 100644 src/models/unet.py create mode 100644 src/models/utils.py diff --git a/src/losses/loss.py b/src/losses/loss.py new file mode 100644 index 0000000..18dde13 --- /dev/null +++ b/src/losses/loss.py @@ -0,0 +1,914 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Loss functions used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +import random +from typing import Callable, Optional, Union + +import numpy as np +import torch + + +class VPLoss: + """ + Loss function corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + beta_d: float, optional + Coefficient for the diffusion process, by default 19.9. + beta_min: float, optional + Minimum bound, by defaults 0.1. + epsilon_t: float, optional + Small positive value, by default 1e-5. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + """ + + def __init__( + self, beta_d: float = 19.9, beta_min: float = 0.1, epsilon_t: float = 1e-5 + ): + self.beta_d = beta_d + self.beta_min = beta_min + self.epsilon_t = epsilon_t + + def __call__( + self, + net: torch.nn.Module, + images: torch.Tensor, + labels: torch.Tensor, + augment_pipe: Optional[Callable] = None, + ): + """ + Calculate and return the loss corresponding to the variance preserving (VP) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'epsilon_t' and random values. The calculated loss is weighted based on the + inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + def sigma( + self, t: Union[float, torch.Tensor] + ): # NOTE: also exists in preconditioning + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + +class VELoss: + """ + Loss function corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__(self, sigma_min: float = 0.02, sigma_max: float = 100.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance exploding (VE) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'sigma_min' and 'sigma_max' and random values. The calculated loss is weighted + based on the inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class EDMLoss: + """ + Loss function proposed in the EDM paper. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, images, condition=None, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + if condition is not None: + D_yn = net( + y + n, + sigma, + condition=condition, + class_labels=labels, + augment_labels=augment_labels, + ) + else: + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class EDMLossSR: + """ + Variation of the loss function proposed in the EDM paper for Super-Resolution. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # augment for conditional generaiton + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + n = torch.randn_like(y) * sigma + D_yn = net(y + n, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class RegressionLoss: + """ + Regression loss function for the U-Net for deterministic predictions. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss for the U-Net for deterministic predictions. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = ( + 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + ) + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + + return loss + + +class ResLoss: + """ + Mixture loss function for denoising score matching. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + regression_net, + img_shape_x, + img_shape_y, + patch_shape_x, + patch_shape_y, + patch_num, + P_mean: float = 0.0, + P_std: float = 1.2, + sigma_data: float = 0.5, + hr_mean_conditioning: bool = False, + ): + self.unet = regression_net + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + self.img_shape_x = img_shape_x + self.img_shape_y = img_shape_y + self.patch_shape_x = patch_shape_x + self.patch_shape_y = patch_shape_y + self.patch_num = patch_num + self.hr_mean_conditioning = hr_mean_conditioning + + def __call__( + self, + net, + img_clean, + img_lr, + labels=None, + lead_time_label=None, + augment_pipe=None, + ): + """ + Calculate and return the loss for denoising score matching. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # augment for conditional generaiton + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + y_lr_res = y_lr + + # global index + b = y.shape[0] + Nx = torch.arange(self.img_shape_x).int() + Ny = torch.arange(self.img_shape_y).int() + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ + None, + ].expand(b, -1, -1, -1) + + # form residual + if lead_time_label is not None: + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + sigma, + labels, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + sigma, + labels, + augment_labels=augment_labels, + ) + + y = y - y_mean + + if self.hr_mean_conditioning: + y_lr = torch.cat((y_mean, y_lr), dim=1).contiguous() + global_index = None + # patchified training + # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 + if ( + self.img_shape_x != self.patch_shape_x + or self.img_shape_y != self.patch_shape_y + ): + c_in = y_lr.shape[1] + c_out = y.shape[1] + rnd_normal = torch.randn( + [img_clean.shape[0] * self.patch_num, 1, 1, 1], device=img_clean.device + ) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / ( + sigma * self.sigma_data + ) ** 2 + + # global interpolation + input_interp = torch.nn.functional.interpolate( + img_lr, + (self.patch_shape_y, self.patch_shape_x), + mode="bilinear", + ) + + # patch generation from a single sample (not from random samples due to memory consumption of regression) + y_new = torch.zeros( + b * self.patch_num, + c_out, + self.patch_shape_y, + self.patch_shape_x, + device=img_clean.device, + ) + y_lr_new = torch.zeros( + b * self.patch_num, + c_in + input_interp.shape[1], + self.patch_shape_y, + self.patch_shape_x, + device=img_clean.device, + ) + global_index = torch.zeros( + b * self.patch_num, + 2, + self.patch_shape_y, + self.patch_shape_x, + dtype=torch.int, + device=img_clean.device, + ) + for i in range(self.patch_num): + rnd_x = random.randint(0, self.img_shape_x - self.patch_shape_x) + rnd_y = random.randint(0, self.img_shape_y - self.patch_shape_y) + y_new[b * i : b * (i + 1),] = y[ + :, + :, + rnd_y : rnd_y + self.patch_shape_y, + rnd_x : rnd_x + self.patch_shape_x, + ] + global_index[b * i : b * (i + 1),] = grid[ + :, + :, + rnd_y : rnd_y + self.patch_shape_y, + rnd_x : rnd_x + self.patch_shape_x, + ] + y_lr_new[b * i : b * (i + 1),] = torch.cat( + ( + y_lr[ + :, + :, + rnd_y : rnd_y + self.patch_shape_y, + rnd_x : rnd_x + self.patch_shape_x, + ], + input_interp, + ), + 1, + ) + y = y_new + y_lr = y_lr_new + latent = y + torch.randn_like(y) * sigma + + if lead_time_label is not None: + D_yn = net( + latent, + y_lr, + sigma, + labels, + global_index=global_index, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + latent, + y_lr, + sigma, + labels, + global_index=global_index, + augment_labels=augment_labels, + ) + loss = weight * ((D_yn - y) ** 2) + + return loss + + +class VELoss_dfsr: + """ + Loss function for dfsr model, modified from class VELoss. + + Parameters + ---------- + beta_start : float + Noise level at the initial step of the forward diffusion process, by default 0.0001. + beta_end : float + Noise level at the Final step of the forward diffusion process, by default 0.02. + num_diffusion_timesteps : int + Total number of forward/backward diffusion steps, by default 1000. + + + Note: + ----- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + beta_start: float = 0.0001, + beta_end: float = 0.02, + num_diffusion_timesteps: int = 1000, + ): + # scheduler for diffusion: + self.beta_schedule = "linear" + self.beta_start = beta_start + self.beta_end = beta_end + self.num_diffusion_timesteps = num_diffusion_timesteps + betas = self.get_beta_schedule( + beta_schedule=self.beta_schedule, + beta_start=self.beta_start, + beta_end=self.beta_end, + num_diffusion_timesteps=self.num_diffusion_timesteps, + ) + self.betas = torch.from_numpy(betas).float() + self.num_timesteps = betas.shape[0] + + def get_beta_schedule( + self, beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps + ): + """ + Compute the variance scheduling parameters {beta(0), ..., beta(t), ..., beta(T)} + based on the VP formulation. + + beta_schedule: str + Method to construct the sequence of beta(t)'s. + beta_start: float + Noise level at the initial step of the forward diffusion process, e.g., beta(0) + beta_end: float + Noise level at the final step of the forward diffusion process, e.g., beta(T) + num_diffusion_timesteps: int + Total number of forward/backward diffusion steps + """ + + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "sigmoid": + betas = np.linspace(-6, 6, num_diffusion_timesteps) + betas = sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(beta_schedule) + if betas.shape != (num_diffusion_timesteps,): + raise ValueError( + f"Expected betas to have shape ({num_diffusion_timesteps},), " + f"but got {betas.shape}" + ) + return betas + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance preserving + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the noise samples added + to the t-th step of the diffusion process. + The noise level is determined by 'beta_t' based on the given parameters 'beta_start', + 'beta_end' and the current diffusion timestep t. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input fluid flow data samples to the neural network. + + labels: torch.Tensor + Ground truth labels for the input fluid flow data samples. Not required for dfsr. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + t = torch.randint( + low=0, high=self.num_timesteps, size=(images.size(0) // 2 + 1,) + ).to(images.device) + t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[: images.size(0)] + e = torch.randn_like(images) + b = self.betas.to(images.device) + a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) + x = images * a.sqrt() + e * (1.0 - a).sqrt() + + output = net(x, t, labels) + loss = (e - output).square() + + return loss + + +class RegressionLossCE: + """ + A regression loss function for the GEFS-HRRR model with probability channels, adapted + from RegressionLoss. In this version, probability channels are evaluated using + CrossEntropyLoss instead of MSELoss. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + prob_channels: list, optional + A index list of output probability channels. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + P_mean: float = -1.2, + P_std: float = 1.2, + sigma_data: float = 0.5, + prob_channels: list = [4, 5, 6, 7, 8], + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + self.entropy = torch.nn.CrossEntropyLoss(reduction="none") + self.prob_channels = prob_channels + + def __call__( + self, + net, + img_clean, + img_lr, + lead_time_label=None, + labels=None, + augment_pipe=None, + ): + """ + Calculate and return the loss for the U-Net for deterministic predictions. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + lead_time_label: torch.Tensor + Lead time labels for input batches. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + all_channels = list(range(img_clean.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = ( + 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + ) + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + input = torch.zeros_like(y, device=img_clean.device) + + if lead_time_label is not None: + D_yn = net( + input, + y_lr, + sigma, + labels, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + input, + y_lr, + sigma, + labels, + augment_labels=augment_labels, + ) + loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2) + loss2 = ( + weight + * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[ + :, None + ] + ) + loss = torch.cat((loss1, loss2), dim=1) + return loss diff --git a/src/models/layers.py b/src/models/layers.py new file mode 100644 index 0000000..1fb3b17 --- /dev/null +++ b/src/models/layers.py @@ -0,0 +1,567 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architecture layers used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from typing import Any, Dict, List + +import numpy as np +import torch +from einops import rearrange +from torch.nn.functional import silu + +from physicsnemo.models.diffusion import weight_init + + +class Linear(torch.nn.Module): + """ + A fully connected (dense) layer implementation. The layer's weights and biases can + be initialized using custom initialization strategies like "kaiming_normal", + and can be further scaled by factors `init_weight` and `init_bias`. + + Parameters + ---------- + in_features : int + Size of each input sample. + out_features : int + Size of each output sample. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an additive + bias. By default True. + init_mode : str, optional (default="kaiming_normal") + The mode/type of initialization to use for weights and biases. Supported modes + are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + By default "kaiming_normal". + init_weight : float, optional + A scaling factor to multiply with the initialized weights. By default 1. + init_bias : float, optional + A scaling factor to multiply with the initialized biases. By default 0. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + init_mode: str = "kaiming_normal", + init_weight: int = 1, + init_bias: int = 0, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter( + weight_init([out_features, in_features], **init_kwargs) * init_weight + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) + if bias + else None + ) + + def forward(self, x): + x = x @ self.weight.to(x.dtype).t() + if self.bias is not None: + x = x.add_(self.bias.to(x.dtype)) + return x + + +class Conv2d(torch.nn.Module): + """ + A custom 2D convolutional layer implementation with support for up-sampling, + down-sampling, and custom weight and bias initializations. The layer's weights + and biases canbe initialized using custom initialization strategies like + "kaiming_normal", and can be further scaled by factors `init_weight` and + `init_bias`. + + Parameters + ---------- + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels produced by the convolution. + kernel : int + Size of the convolving kernel. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an + additive bias. By default True. + up : bool, optional + Whether to perform up-sampling. By default False. + down : bool, optional + Whether to perform down-sampling. By default False. + resample_filter : List[int], optional + Filter to be used for resampling. By default [1, 1]. + fused_resample : bool, optional + If True, performs fused up-sampling and convolution or fused down-sampling + and convolution. By default False. + init_mode : str, optional (default="kaiming_normal") + init_mode : str, optional (default="kaiming_normal") + The mode/type of initialization to use for weights and biases. Supported modes + are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + By default "kaiming_normal". + init_weight : float, optional + A scaling factor to multiply with the initialized weights. By default 1.0. + init_bias : float, optional + A scaling factor to multiply with the initialized biases. By default 0.0. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + bias: bool = True, + up: bool = False, + down: bool = False, + resample_filter: List[int] = [1, 1], + fused_resample: bool = False, + init_mode: str = "kaiming_normal", + init_weight: float = 1.0, + init_bias: float = 0.0, + ): + if up and down: + raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + self.fused_resample = fused_resample + init_kwargs = dict( + mode=init_mode, + fan_in=in_channels * kernel * kernel, + fan_out=out_channels * kernel * kernel, + ) + self.weight = ( + torch.nn.Parameter( + weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) + * init_weight + ) + if kernel + else None + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) + if kernel and bias + else None + ) + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() + self.register_buffer("resample_filter", f if up or down else None) + + def forward(self, x): + w = self.weight.to(x.dtype) if self.weight is not None else None + b = self.bias.to(x.dtype) if self.bias is not None else None + f = ( + self.resample_filter.to(x.dtype) + if self.resample_filter is not None + else None + ) + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.fused_resample and self.up and w is not None: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=max(f_pad - w_pad, 0), + ) + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.fused_resample and self.down and w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) + else: + if self.up: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if self.down: + x = torch.nn.functional.conv2d( + x, + f.tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None: + x = x.add_(b.reshape(1, -1, 1, 1)) + return x + + +class GroupNorm(torch.nn.Module): + """ + A custom Group Normalization layer implementation. + + Group Normalization (GN) divides the channels of the input tensor into groups and + normalizes the features within each group independently. It does not require the + batch size as in Batch Normalization, making itsuitable for batch sizes of any size + or even for batch-free scenarios. + + Parameters + ---------- + num_channels : int + Number of channels in the input tensor. + num_groups : int, optional + Desired number of groups to divide the input channels, by default 32. + This might be adjusted based on the `min_channels_per_group`. + min_channels_per_group : int, optional + Minimum channels required per group. This ensures that no group has fewer + channels than this number. By default 4. + eps : float, optional + A small number added to the variance to prevent division by zero, by default + 1e-5. + + Notes + ----- + If `num_channels` is not divisible by `num_groups`, the actual number of groups + might be adjusted to satisfy the `min_channels_per_group` condition. + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + min_channels_per_group: int = 4, + eps: float = 1e-5, + ): + super().__init__() + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + + def forward(self, x): + if self.training: + # Use default torch implementation of GroupNorm for training + # This does not support channels last memory format + x = torch.nn.functional.group_norm( + x, + num_groups=self.num_groups, + weight=self.weight.to(x.dtype), + bias=self.bias.to(x.dtype), + eps=self.eps, + ) + else: + # Use custom GroupNorm implementation that supports channels last + # memory layout for inference + dtype = x.dtype + x = x.float() + x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) + + mean = x.mean(dim=[2, 3, 4], keepdim=True) + var = x.var(dim=[2, 3, 4], keepdim=True) + + x = (x - mean) * (var + self.eps).rsqrt() + x = rearrange(x, "b g c h w -> b (g c) h w") + + weight = rearrange(self.weight, "c -> 1 c 1 1") + bias = rearrange(self.bias, "c -> 1 c 1 1") + x = x * weight + bias + + x = x.type(dtype) + return x + + +class AttentionOp(torch.autograd.Function): + """ + Attention weight computation, i.e., softmax(Q^T * K). + Performs all computation using FP32, but uses the original datatype for + inputs/outputs/gradients to conserve memory. + """ + + @staticmethod + def forward(ctx, q, k): + w = ( + torch.einsum( + "ncq,nck->nqk", + q.to(torch.float32), + (k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32), + ) + .softmax(dim=2) + .to(q.dtype) + ) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + db = torch._softmax_backward_data( + grad_output=dw.to(torch.float32), + output=w.to(torch.float32), + dim=2, + input_dtype=torch.float32, + ) + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( + q.dtype + ) / np.sqrt(k.shape[1]) + dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( + k.dtype + ) / np.sqrt(k.shape[1]) + return dq, dk + + +class UNetBlock(torch.nn.Module): + """ + Unified U-Net block with optional up/downsampling and self-attention. Represents + the union of all features employed by the DDPM++, NCSN++, and ADM architectures. + + Parameters: + ----------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + emb_channels : int + Number of embedding channels. + up : bool, optional + If True, applies upsampling in the forward pass. By default False. + down : bool, optional + If True, applies downsampling in the forward pass. By default False. + attention : bool, optional + If True, enables the self-attention mechanism in the block. By default False. + num_heads : int, optional + Number of attention heads. If None, defaults to `out_channels // 64`. + channels_per_head : int, optional + Number of channels per attention head. By default 64. + dropout : float, optional + Dropout probability. By default 0.0. + skip_scale : float, optional + Scale factor applied to skip connections. By default 1.0. + eps : float, optional + Epsilon value used for normalization layers. By default 1e-5. + resample_filter : List[int], optional + Filter for resampling layers. By default [1, 1]. + resample_proj : bool, optional + If True, resampling projection is enabled. By default False. + adaptive_scale : bool, optional + If True, uses adaptive scaling in the forward pass. By default True. + init : dict, optional + Initialization parameters for convolutional and linear layers. + init_zero : dict, optional + Initialization parameters with zero weights for certain layers. By default + {'init_weight': 0}. + init_attn : dict, optional + Initialization parameters specific to attention mechanism layers. + Defaults to 'init' if not provided. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + up: bool = False, + down: bool = False, + attention: bool = False, + num_heads: int = None, + channels_per_head: int = 64, + dropout: float = 0.0, + skip_scale: float = 1.0, + eps: float = 1e-5, + resample_filter: List[int] = [1, 1], + resample_proj: bool = False, + adaptive_scale: bool = True, + init: Dict[str, Any] = dict(), + init_zero: Dict[str, Any] = dict(init_weight=0), + init_attn: Any = None, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.emb_channels = emb_channels + self.num_heads = ( + 0 + if not attention + else num_heads + if num_heads is not None + else out_channels // channels_per_head + ) + self.dropout = dropout + self.skip_scale = skip_scale + self.adaptive_scale = adaptive_scale + + self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.conv0 = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=3, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + self.affine = Linear( + in_features=emb_channels, + out_features=out_channels * (2 if adaptive_scale else 1), + **init, + ) + self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + self.conv1 = Conv2d( + in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + ) + + self.skip = None + if out_channels != in_channels or up or down: + kernel = 1 if resample_proj or out_channels != in_channels else 0 + self.skip = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=kernel, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + + if self.num_heads: + self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.qkv = Conv2d( + in_channels=out_channels, + out_channels=out_channels * 3, + kernel=1, + **(init_attn if init_attn is not None else init), + ) + self.proj = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=1, + **init_zero, + ) + + def forward(self, x, emb): + torch.cuda.nvtx.range_push("UNetBlock") + orig = x + x = self.conv0(silu(self.norm0(x))) + + params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = silu(self.norm1(x.add_(params))) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) + ) + x = x.add_(self.skip(orig) if self.skip is not None else orig) + x = x * self.skip_scale + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(2) + ) + w = AttentionOp.apply(q, k) + a = torch.einsum("nqk,nck->ncq", w, v) + x = self.proj(a.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + torch.cuda.nvtx.range_pop() + return x + + +class PositionalEmbedding(torch.nn.Module): + """ + A module for generating positional embeddings based on timesteps. + This embedding technique is employed in the DDPM++ and ADM architectures. + + Parameters: + ----------- + num_channels : int + Number of channels for the embedding. + max_positions : int, optional + Maximum number of positions for the embeddings, by default 10000. + endpoint : bool, optional + If True, the embedding considers the endpoint. By default False. + + """ + + def __init__( + self, num_channels: int, max_positions: int = 10000, endpoint: bool = False + ): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange( + start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device + ) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + + +class FourierEmbedding(torch.nn.Module): + """ + Generates Fourier embeddings for timesteps, primarily used in the NCSN++ + architecture. + + This class generates embeddings by first multiplying input tensor `x` and + internally stored random frequencies, and then concatenating the cosine and sine of + the resultant. + + Parameters: + ----------- + num_channels : int + The number of channels in the embedding. The final embedding size will be + 2 * num_channels because of concatenation of cosine and sine results. + scale : int, optional + A scale factor applied to the random frequencies, controlling their range + and thereby the frequency of oscillations in the embedding space. By default 16. + """ + + def __init__(self, num_channels: int, scale: int = 16): + super().__init__() + self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + + def forward(self, x): + x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/src/models/preconditioning copy.py b/src/models/preconditioning copy.py new file mode 100644 index 0000000..52a1660 --- /dev/null +++ b/src/models/preconditioning copy.py @@ -0,0 +1,1176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preconditioning schemes used in the paper"Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +import importlib +import warnings +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import nvtx +import torch + +from physicsnemo.models.diffusion import ( + DhariwalUNet, # noqa: F401 for globals + SongUNet, # noqa: F401 for globals +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +network_module = importlib.import_module("physicsnemo.models.diffusion") + + +@dataclass +class VPPrecondMetaData(ModelMetaData): + """VPPrecond meta data""" + + name: str = "VPPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VPPrecond(Module): + """ + Preconditioning corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + beta_d : float + Extent of the noise level schedule, by default 19.9. + beta_min : float + Initial slope of the noise level schedule, by default 0.1. + M : int + Original number of timesteps in the DDPM formulation, by default 1000. + epsilon_t : float + Minimum t-value used during training, by default 1e-5. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + epsilon_t: float = 1e-5, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VPPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t: Union[float, torch.Tensor]): + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma: Union[float, torch.Tensor]): + """ + Compute the inverse of the sigma function for a given sigma. + + This function effectively calculates t from a given sigma(t) based on the + parameters `beta_d` and `beta_min`. + + Parameters + ---------- + sigma : Union[float, torch.Tensor] + The sigma(t) value or set of sigma(t) values for which to compute the + inverse. + + Returns + ------- + torch.Tensor + The computed t value(s) corresponding to the provided sigma(t). + """ + sigma = torch.as_tensor(sigma) + return ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class VEPrecondMetaData(ModelMetaData): + """VEPrecond meta data""" + + name: str = "VEPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VEPrecond(Module): + """ + Preconditioning corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VEPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class iDDPMPrecondMetaData(ModelMetaData): + """iDDPMPrecond meta data""" + + name: str = "iDDPMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class iDDPMPrecond(Module): + """ + Preconditioning corresponding to the improved DDPM (iDDPM) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + C_1 : float + Timestep adjustment at low noise levels., by default 0.001. + C_2 : float + Timestep adjustment at high noise levels., by default 0.008. + M: int + Original number of timesteps in the DDPM formulation, by default 1000. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion + probabilistic models. In International Conference on Machine Learning + (pp. 8162-8171). PMLR. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + C_1=0.001, + C_2=0.008, + M=1000, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__(meta=iDDPMPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels * 2, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) + / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) + - 1 + ).sqrt() + self.register_buffer("u", u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = ( + self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + ) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + """ + Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. + + Parameters + ---------- + j : Union[int, torch.Tensor] + The timestep or set of timesteps for which to compute alpha_bar(j). + + Returns + ------- + torch.Tensor + The computed alpha_bar(j) value(s). + """ + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + """ + Round the provided sigma value(s) to the nearest value(s) in a + pre-defined set `u`. + + Parameters + ---------- + sigma : Union[float, list, torch.Tensor] + The sigma value(s) to round. + return_index : bool, optional + Whether to return the index/indices of the rounded value(s) in `u` instead + of the rounded value(s) themselves, by default False. + + Returns + ------- + torch.Tensor + The rounded sigma value(s) or their index/indices in `u`, depending on the + value of `return_index`. + """ + sigma = torch.as_tensor(sigma) + index = torch.cdist( + sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), + self.u.reshape(1, -1, 1), + ).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + + +@dataclass +class EDMPrecondMetaData(ModelMetaData): + """EDMPrecond meta data""" + + name: str = "EDMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecond(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels (for both input and output). If your model + requires a different number of input or output chanels, + override this by passing either of the optional + img_in_channels or img_out_channels args + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + img_in_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the input + This is useful in the case of additional (conditional) channels + img_out_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the output + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + img_in_channels=None, + img_out_channels=None, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondMetaData) + self.img_resolution = img_resolution + if img_in_channels is not None: + img_in_channels = img_in_channels + else: + img_in_channels = img_channels + if img_out_channels is not None: + img_out_channels = img_out_channels + else: + img_out_channels = img_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + + F_x = self.model( + arg.to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondSRMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels # TODO: this is not used, remove it + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.scale_cond_input = scale_cond_input + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) # TODO needs better handling + self.scaling_fn = self._get_scaling_fn() + + def _get_scaling_fn(self): + if self.scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + return self._legacy_scaling_fn + else: + return self._scaling_fn + + @staticmethod + def _scaling_fn(x, img_lr, c_in): + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + + @staticmethod + def _legacy_scaling_fn(x, img_lr, c_in): + return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + + @nvtx.annotate(message="EDMPrecondSR", color="orange") + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + # Concatenate input channels + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + if img_lr is None: + arg = c_in * x + else: + arg = self.scaling_fn(x, img_lr, c_in) + arg = arg.to(dtype) + + F_x = self.model( + arg, + c_noise.flatten(), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + See EDMPrecond.round_sigma + """ + return EDMPrecond.round_sigma(sigma) + + +class VEPrecond_dfsr(torch.nn.Module): + """ + Preconditioning for dfsr model, modified from class VEPrecond, where the input + argument 'sigma' in forward propagation function is used to receive the timestep + of the backward diffusion process. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=self.img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + # print("sigma: ", sigma) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + return F_x + + +class VEPrecond_dfsr_cond(torch.nn.Module): + """ + Preconditioning for dfsr model with physics-informed conditioning input, modified + from class VEPrecond, where the input argument 'sigma' in forward propagation function + is used to receive the timestep of the backward diffusion process. The gradient of PDE + residual with respect to the vorticity in the governing Navier-Stokes equation is computed + as the physics-informed conditioning variable and is combined with the backward diffusion + timestep before being sent to the underlying model for noise prediction. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: + [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity + flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=model_kwargs["model_channels"] * 2, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + # modules to embed residual loss + self.conv_in = torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ) + self.emb_conv = torch.nn.Sequential( + torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=1, + stride=1, + padding=0, + ), + torch.nn.GELU(), + torch.nn.Conv2d( + model_kwargs["model_channels"], + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ), + ) + self.dataset_mean = dataset_mean + self.dataset_scale = dataset_scale + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma + + # Compute physics-informed conditioning information using vorticity residual + dx = ( + self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) + / self.dataset_scale + ) + x = self.conv_in(x) + cond_emb = self.emb_conv(dx) + x = torch.cat((x, cond_emb), dim=1) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + return F_x + + def voriticity_residual(self, w, re=1000.0, dt=1 / 32): + """ + Compute the gradient of PDE residual with respect to a given vorticity w using the + spectrum method. + + Parameters + ---------- + w: torch.Tensor + The fluid flow data sample (vorticity). + re: float + The value of Reynolds number used in the governing Navier-Stokes equation. + dt: float + Time step used to compute the time-derivative of vorticity included in the governing + Navier-Stokes equation. + + Returns + ------- + torch.Tensor + The computed vorticity gradient. + """ + + # w [b t h w] + w = w.clone() + w.requires_grad_(True) + nx = w.size(2) + device = w.device + + w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) + # Wavenumbers in y-direction + k_max = nx // 2 + N = nx + k_x = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(N, 1) + .repeat(1, N) + .reshape(1, 1, N, N) + ) + k_y = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(1, N) + .repeat(N, 1) + .reshape(1, 1, N, N) + ) + # Negative Laplacian in Fourier space + lap = k_x**2 + k_y**2 + lap[..., 0, 0] = 1.0 + psi_h = w_h / lap + + u_h = 1j * k_y * psi_h + v_h = -1j * k_x * psi_h + wx_h = 1j * k_x * w_h + wy_h = 1j * k_y * w_h + wlap_h = -lap * w_h + + u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) + v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) + wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) + wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) + wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) + advection = u * wx + v * wy + + wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) + + # establish forcing term + x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) + x = x[0:-1] + X, Y = torch.meshgrid(x, x) + f = -4 * torch.cos(4 * Y) + + residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f + residual_loss = (residual**2).mean() + dw = torch.autograd.grad(residual_loss, w)[0] + + return dw diff --git a/src/models/preconditioning.py b/src/models/preconditioning.py new file mode 100644 index 0000000..52a1660 --- /dev/null +++ b/src/models/preconditioning.py @@ -0,0 +1,1176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preconditioning schemes used in the paper"Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +import importlib +import warnings +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import nvtx +import torch + +from physicsnemo.models.diffusion import ( + DhariwalUNet, # noqa: F401 for globals + SongUNet, # noqa: F401 for globals +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +network_module = importlib.import_module("physicsnemo.models.diffusion") + + +@dataclass +class VPPrecondMetaData(ModelMetaData): + """VPPrecond meta data""" + + name: str = "VPPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VPPrecond(Module): + """ + Preconditioning corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + beta_d : float + Extent of the noise level schedule, by default 19.9. + beta_min : float + Initial slope of the noise level schedule, by default 0.1. + M : int + Original number of timesteps in the DDPM formulation, by default 1000. + epsilon_t : float + Minimum t-value used during training, by default 1e-5. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + epsilon_t: float = 1e-5, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VPPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t: Union[float, torch.Tensor]): + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma: Union[float, torch.Tensor]): + """ + Compute the inverse of the sigma function for a given sigma. + + This function effectively calculates t from a given sigma(t) based on the + parameters `beta_d` and `beta_min`. + + Parameters + ---------- + sigma : Union[float, torch.Tensor] + The sigma(t) value or set of sigma(t) values for which to compute the + inverse. + + Returns + ------- + torch.Tensor + The computed t value(s) corresponding to the provided sigma(t). + """ + sigma = torch.as_tensor(sigma) + return ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class VEPrecondMetaData(ModelMetaData): + """VEPrecond meta data""" + + name: str = "VEPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VEPrecond(Module): + """ + Preconditioning corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VEPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class iDDPMPrecondMetaData(ModelMetaData): + """iDDPMPrecond meta data""" + + name: str = "iDDPMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class iDDPMPrecond(Module): + """ + Preconditioning corresponding to the improved DDPM (iDDPM) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + C_1 : float + Timestep adjustment at low noise levels., by default 0.001. + C_2 : float + Timestep adjustment at high noise levels., by default 0.008. + M: int + Original number of timesteps in the DDPM formulation, by default 1000. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion + probabilistic models. In International Conference on Machine Learning + (pp. 8162-8171). PMLR. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + C_1=0.001, + C_2=0.008, + M=1000, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__(meta=iDDPMPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels * 2, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) + / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) + - 1 + ).sqrt() + self.register_buffer("u", u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = ( + self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + ) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + """ + Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. + + Parameters + ---------- + j : Union[int, torch.Tensor] + The timestep or set of timesteps for which to compute alpha_bar(j). + + Returns + ------- + torch.Tensor + The computed alpha_bar(j) value(s). + """ + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + """ + Round the provided sigma value(s) to the nearest value(s) in a + pre-defined set `u`. + + Parameters + ---------- + sigma : Union[float, list, torch.Tensor] + The sigma value(s) to round. + return_index : bool, optional + Whether to return the index/indices of the rounded value(s) in `u` instead + of the rounded value(s) themselves, by default False. + + Returns + ------- + torch.Tensor + The rounded sigma value(s) or their index/indices in `u`, depending on the + value of `return_index`. + """ + sigma = torch.as_tensor(sigma) + index = torch.cdist( + sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), + self.u.reshape(1, -1, 1), + ).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + + +@dataclass +class EDMPrecondMetaData(ModelMetaData): + """EDMPrecond meta data""" + + name: str = "EDMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecond(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels (for both input and output). If your model + requires a different number of input or output chanels, + override this by passing either of the optional + img_in_channels or img_out_channels args + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + img_in_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the input + This is useful in the case of additional (conditional) channels + img_out_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the output + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + img_in_channels=None, + img_out_channels=None, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondMetaData) + self.img_resolution = img_resolution + if img_in_channels is not None: + img_in_channels = img_in_channels + else: + img_in_channels = img_channels + if img_out_channels is not None: + img_out_channels = img_out_channels + else: + img_out_channels = img_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + + F_x = self.model( + arg.to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondSRMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels # TODO: this is not used, remove it + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.scale_cond_input = scale_cond_input + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) # TODO needs better handling + self.scaling_fn = self._get_scaling_fn() + + def _get_scaling_fn(self): + if self.scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + return self._legacy_scaling_fn + else: + return self._scaling_fn + + @staticmethod + def _scaling_fn(x, img_lr, c_in): + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + + @staticmethod + def _legacy_scaling_fn(x, img_lr, c_in): + return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + + @nvtx.annotate(message="EDMPrecondSR", color="orange") + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + # Concatenate input channels + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + if img_lr is None: + arg = c_in * x + else: + arg = self.scaling_fn(x, img_lr, c_in) + arg = arg.to(dtype) + + F_x = self.model( + arg, + c_noise.flatten(), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + See EDMPrecond.round_sigma + """ + return EDMPrecond.round_sigma(sigma) + + +class VEPrecond_dfsr(torch.nn.Module): + """ + Preconditioning for dfsr model, modified from class VEPrecond, where the input + argument 'sigma' in forward propagation function is used to receive the timestep + of the backward diffusion process. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=self.img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + # print("sigma: ", sigma) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + return F_x + + +class VEPrecond_dfsr_cond(torch.nn.Module): + """ + Preconditioning for dfsr model with physics-informed conditioning input, modified + from class VEPrecond, where the input argument 'sigma' in forward propagation function + is used to receive the timestep of the backward diffusion process. The gradient of PDE + residual with respect to the vorticity in the governing Navier-Stokes equation is computed + as the physics-informed conditioning variable and is combined with the backward diffusion + timestep before being sent to the underlying model for noise prediction. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: + [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity + flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=model_kwargs["model_channels"] * 2, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + # modules to embed residual loss + self.conv_in = torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ) + self.emb_conv = torch.nn.Sequential( + torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=1, + stride=1, + padding=0, + ), + torch.nn.GELU(), + torch.nn.Conv2d( + model_kwargs["model_channels"], + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ), + ) + self.dataset_mean = dataset_mean + self.dataset_scale = dataset_scale + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma + + # Compute physics-informed conditioning information using vorticity residual + dx = ( + self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) + / self.dataset_scale + ) + x = self.conv_in(x) + cond_emb = self.emb_conv(dx) + x = torch.cat((x, cond_emb), dim=1) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + return F_x + + def voriticity_residual(self, w, re=1000.0, dt=1 / 32): + """ + Compute the gradient of PDE residual with respect to a given vorticity w using the + spectrum method. + + Parameters + ---------- + w: torch.Tensor + The fluid flow data sample (vorticity). + re: float + The value of Reynolds number used in the governing Navier-Stokes equation. + dt: float + Time step used to compute the time-derivative of vorticity included in the governing + Navier-Stokes equation. + + Returns + ------- + torch.Tensor + The computed vorticity gradient. + """ + + # w [b t h w] + w = w.clone() + w.requires_grad_(True) + nx = w.size(2) + device = w.device + + w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) + # Wavenumbers in y-direction + k_max = nx // 2 + N = nx + k_x = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(N, 1) + .repeat(1, N) + .reshape(1, 1, N, N) + ) + k_y = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(1, N) + .repeat(N, 1) + .reshape(1, 1, N, N) + ) + # Negative Laplacian in Fourier space + lap = k_x**2 + k_y**2 + lap[..., 0, 0] = 1.0 + psi_h = w_h / lap + + u_h = 1j * k_y * psi_h + v_h = -1j * k_x * psi_h + wx_h = 1j * k_x * w_h + wy_h = 1j * k_y * w_h + wlap_h = -lap * w_h + + u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) + v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) + wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) + wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) + wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) + advection = u * wx + v * wy + + wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) + + # establish forcing term + x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) + x = x[0:-1] + X, Y = torch.meshgrid(x, x) + f = -4 * torch.cos(4 * Y) + + residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f + residual_loss = (residual**2).mean() + dw = torch.autograd.grad(residual_loss, w)[0] + + return dw diff --git a/src/models/song_unet.py b/src/models/song_unet.py new file mode 100644 index 0000000..d38484b --- /dev/null +++ b/src/models/song_unet.py @@ -0,0 +1,906 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import nvtx +import torch +from torch.nn.functional import silu +from torch.utils.checkpoint import checkpoint + +from physicsnemo.models.diffusion import ( + Conv2d, + FourierEmbedding, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + + +@dataclass +class MetaData(ModelMetaData): + name: str = "SongUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class SongUNet(Module): + """ + Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with + optional self-attention, embeddings, and encoder-decoder components. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image, 1 value represents a square image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + checkpoint_level : int, optional (default=0) + How many layers should use gradient checkpointing, 0 is None + additive_pos_embed: bool = False, + Set to True to add a learned position embedding after the first conv (used in StormCast) + + + Reference + ---------- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [16], + dropout: float = 0.10, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + checkpoint_level: int = 0, + additive_pos_embed: bool = False, + ): + valid_embedding_types = ["fourier", "positional", "zero"] + if embedding_type not in valid_embedding_types: + raise ValueError( + f"Invalid embedding_type: {embedding_type}. Must be one of {valid_embedding_types}." + ) + + valid_encoder_types = ["standard", "skip", "residual"] + if encoder_type not in valid_encoder_types: + raise ValueError( + f"Invalid encoder_type: {encoder_type}. Must be one of {valid_encoder_types}." + ) + + valid_decoder_types = ["standard", "skip"] + if decoder_type not in valid_decoder_types: + raise ValueError( + f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." + ) + + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + self.embedding_type = embedding_type + emb_channels = model_channels * channel_mult_emb + self.emb_channels = emb_channels + noise_channels = model_channels * channel_mult_noise + init = dict(init_mode="xavier_uniform") + init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) + init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) + block_kwargs = dict( + emb_channels=emb_channels, + num_heads=1, + dropout=dropout, + skip_scale=np.sqrt(0.5), + eps=1e-6, + resample_filter=resample_filter, + resample_proj=True, + adaptive_scale=False, + init=init, + init_zero=init_zero, + init_attn=init_attn, + ) + + # for compatibility with older versions that took only 1 dimension + self.img_resolution = img_resolution + if isinstance(img_resolution, int): + self.img_shape_y = self.img_shape_x = img_resolution + else: + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] + + # set the threshold for checkpointing based on image resolution + self.checkpoint_threshold = (self.img_shape_y >> checkpoint_level) + 1 + + # Optional additive learned positition embed after the first conv + self.additive_pos_embed = additive_pos_embed + if self.additive_pos_embed: + self.spatial_emb = torch.nn.Parameter( + torch.randn(1, model_channels, self.img_shape_y, self.img_shape_x) + ) + torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02) + + # Mapping. + if self.embedding_type != "zero": + self.map_noise = ( + PositionalEmbedding(num_channels=noise_channels, endpoint=True) + if embedding_type == "positional" + else FourierEmbedding(num_channels=noise_channels) + ) + self.map_label = ( + Linear(in_features=label_dim, out_features=noise_channels, **init) + if label_dim + else None + ) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=noise_channels, + bias=False, + **init, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=noise_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + caux = in_channels + for level, mult in enumerate(channel_mult): + res = self.img_shape_y >> level + if level == 0: + cin = cout + cout = model_channels + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + if encoder_type == "skip": + self.enc[f"{res}x{res}_aux_down"] = Conv2d( + in_channels=caux, + out_channels=caux, + kernel=0, + down=True, + resample_filter=resample_filter, + ) + self.enc[f"{res}x{res}_aux_skip"] = Conv2d( + in_channels=caux, out_channels=cout, kernel=1, **init + ) + if encoder_type == "residual": + self.enc[f"{res}x{res}_aux_residual"] = Conv2d( + in_channels=caux, + out_channels=cout, + kernel=3, + down=True, + resample_filter=resample_filter, + fused_resample=True, + **init, + ) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = res in attn_resolutions + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + skips = [ + block.out_channels for name, block in self.enc.items() if "aux" not in name + ] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = self.img_shape_y >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = idx == num_blocks and res in attn_resolutions + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + if decoder_type == "skip" or level == 0: + if decoder_type == "skip" and level < len(channel_mult) - 1: + self.dec[f"{res}x{res}_aux_up"] = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=0, + up=True, + resample_filter=resample_filter, + ) + self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( + num_channels=cout, eps=1e-6 + ) + self.dec[f"{res}x{res}_aux_conv"] = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + @nvtx.annotate(message="SongUNet", color="blue") + def forward(self, x, noise_labels, class_labels, augment_labels=None): + if self.embedding_type != "zero": + # Mapping. + emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) + + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + with nvtx.annotate(f"SongUNet encoder: {name}", color="blue"): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: + x = block(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + # For UNetBlocks check if we should use gradient checkpointing + if isinstance(block, UNetBlock): + if x.shape[-1] > self.checkpoint_threshold: + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + else: + x = block(x) + skips.append(x) + + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + with nvtx.annotate(f"SongUNet decoder: {name}", color="blue"): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + # check for checkpointing on decoder blocks and up sampling blocks + if ( + x.shape[-1] > self.checkpoint_threshold and "_block" in name + ) or ( + x.shape[-1] > (self.checkpoint_threshold / 2) and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + return aux + + +class SongUNetPosEmbd(SongUNet): + """ + Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with + optional self-attention,embeddings, and encoder-decoder components. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image, 1 value represents a square image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.13. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + + + Reference + ---------- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [28], + dropout: float = 0.13, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + gridtype: str = "sinusoidal", + N_grid_channels: int = 4, + checkpoint_level: int = 0, + ): + super().__init__( + img_resolution, + in_channels, + out_channels, + label_dim, + augment_dim, + model_channels, + channel_mult, + channel_mult_emb, + num_blocks, + attn_resolutions, + dropout, + label_dropout, + embedding_type, + channel_mult_noise, + encoder_type, + decoder_type, + resample_filter, + checkpoint_level, + ) + + self.gridtype = gridtype + self.N_grid_channels = N_grid_channels + self.pos_embd = self._get_positional_embedding() + + @nvtx.annotate(message="SongUNet", color="blue") + def forward( + self, x, noise_labels, class_labels, global_index=None, augment_labels=None + ): + # append positional embedding to input conditioning + if self.pos_embd is not None: + selected_pos_embd = self.positional_embedding_indexing(x, global_index) + x = torch.cat((x, selected_pos_embd), dim=1) + + return super().forward(x, noise_labels, class_labels, augment_labels) + + def positional_embedding_indexing(self, x, global_index): + if global_index is None: + selected_pos_embd = ( + self.pos_embd.to(x.dtype) + .to(x.device)[None] + .expand((x.shape[0], -1, -1, -1)) + ) + else: + B = global_index.shape[0] + X = global_index.shape[2] + Y = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, X, Y) to (2, B*X*Y) + selected_pos_embd = self.pos_embd.to(x.device)[ + :, global_index[0], global_index[1] + ] # (N_pe, B*X*Y) + selected_pos_embd = ( + torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, X, Y)), + (1, 0, 2, 3), + ) + .to(x.device) + .to(x.dtype) + ) # (B, N_pe, X, Y) + return selected_pos_embd + + def _get_positional_embedding(self): + if self.N_grid_channels == 0: + return None + elif self.gridtype == "learnable": + grid = torch.nn.Parameter( + torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) + ) + elif self.gridtype == "linear": + if self.N_grid_channels != 2: + raise ValueError("N_grid_channels must be set to 2 for gridtype linear") + x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) + y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) + grid_x, grid_y = np.meshgrid(y, x) + grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: + # print('sinusuidal grid added ......') + x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) + x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) + y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) + y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) + grid_x1, grid_y1 = np.meshgrid(y1, x1) + grid_x2, grid_y2 = np.meshgrid(y2, x2) + grid = torch.squeeze( + torch.from_numpy( + np.expand_dims( + np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 + ) + ) + ) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: + if self.N_grid_channels % 4 != 0: + raise ValueError("N_grid_channels must be a factor of 4") + num_freq = self.N_grid_channels // 4 + freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) + grid_list = [] + grid_x, grid_y = np.meshgrid( + np.linspace(0, 2 * np.pi, self.img_shape_x), + np.linspace(0, 2 * np.pi, self.img_shape_y), + ) + for freq in freq_bands: + for p_fn in [np.sin, np.cos]: + grid_list.append(p_fn(grid_x * freq)) + grid_list.append(p_fn(grid_y * freq)) + grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid.requires_grad = False + elif self.gridtype == "test" and self.N_grid_channels == 2: + idx_x = torch.arange(self.img_shape_y) + idx_y = torch.arange(self.img_shape_x) + mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) + grid = torch.stack((mesh_x, mesh_y), dim=0) + else: + raise ValueError("Gridtype not supported.") + return grid + + +class SongUNetPosLtEmbd(SongUNet): + """ + This model is adapated from SongUNetPosEmbd, with the incoporatation of lead-time aware + embedding for the GEFS-HRRR model. The lead-time embedding is activated by setting the + lead_time_channels and lead_time_steps parameters. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image, 1 value represents a square image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.13. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + lead_time_channels: int, optional + Length of lead time embedding vector + lead_time_steps: int, optional + Total number of lead times + + + Reference + ---------- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [28], + dropout: float = 0.13, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + gridtype: str = "sinusoidal", + N_grid_channels: int = 4, + lead_time_channels: int = None, + lead_time_steps: int = 9, + prob_channels: List[int] = [], + checkpoint_level: int = 0, + ): + super().__init__( + img_resolution, + in_channels, + out_channels, + label_dim, + augment_dim, + model_channels, + channel_mult, + channel_mult_emb, + num_blocks, + attn_resolutions, + dropout, + label_dropout, + embedding_type, + channel_mult_noise, + encoder_type, + decoder_type, + resample_filter, + checkpoint_level, + ) + + self.gridtype = gridtype + self.N_grid_channels = N_grid_channels + self.pos_embd = self._get_positional_embedding() + self.lead_time_channels = lead_time_channels + self.lead_time_steps = lead_time_steps + self.lt_embd = self._get_lead_time_embedding() + self.prob_channels = prob_channels + if self.prob_channels: + self.scalar = torch.nn.Parameter( + torch.ones((1, len(self.prob_channels), 1, 1)) + ) + + @nvtx.annotate(message="SongUNet", color="blue") + def forward( + self, + x, + noise_labels, + class_labels, + lead_time_label=None, + global_index=None, + augment_labels=None, + ): + # append positional embedding to input conditioning + embeds = [] + if self.pos_embd is not None: + embeds.append(self.pos_embd.to(x.device)) + if self.lt_embd is not None: + embeds.append( + torch.reshape( + self.lt_embd[lead_time_label.int()], + (self.lead_time_channels, self.img_shape_y, self.img_shape_x), + ).to(x.device) + ) + if len(embeds) > 0: + embeds = torch.cat(embeds, dim=0) + selected_pos_embd = self.positional_embedding_indexing( + x, embeds, global_index + ) + x = torch.cat((x, selected_pos_embd), dim=1) + out = super().forward(x, noise_labels, class_labels, augment_labels) + # if training mode, let crossEntropyLoss do softmax. The model outputs logits. + # if eval mode, the model outputs probability + all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + if self.prob_channels and (not self.training): + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar).softmax(dim=1), + ), + dim=1, + ) + elif self.prob_channels and self.training: + out_final = torch.cat( + (out[:, scalar_channels], (out[:, self.prob_channels] * self.scalar)), + dim=1, + ) + else: + out_final = out + return out_final + + def positional_embedding_indexing(self, x, pos_embd, global_index): + if global_index is None: + selected_pos_embd = ( + pos_embd.to(x.dtype).to(x.device)[None].expand((x.shape[0], -1, -1, -1)) + ) + else: + B = global_index.shape[0] + X = global_index.shape[2] + Y = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, X, Y) to (2, B*X*Y) + selected_pos_embd = pos_embd.to(x.device)[ + :, global_index[0], global_index[1] + ] # (N_pe, B*X*Y) + selected_pos_embd = ( + torch.permute( + torch.reshape(selected_pos_embd, (pos_embd.shape[0], B, X, Y)), + (1, 0, 2, 3), + ) + .to(x.device) + .to(x.dtype) + ) # (B, N_pe, X, Y) + return selected_pos_embd + + def _get_positional_embedding(self): + if self.N_grid_channels == 0: + return None + elif self.gridtype == "learnable": + grid = torch.nn.Parameter( + torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) + ) + elif self.gridtype == "linear": + if self.N_grid_channels != 2: + raise ValueError("N_grid_channels must be set to 2 for gridtype linear") + x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) + y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) + grid_x, grid_y = np.meshgrid(y, x) + grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: + # print('sinusuidal grid added ......') + x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) + x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) + y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) + y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) + grid_x1, grid_y1 = np.meshgrid(y1, x1) + grid_x2, grid_y2 = np.meshgrid(y2, x2) + grid = torch.squeeze( + torch.from_numpy( + np.expand_dims( + np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 + ) + ) + ) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: + if self.N_grid_channels % 4 != 0: + raise ValueError("N_grid_channels must be a factor of 4") + num_freq = self.N_grid_channels // 4 + freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) + grid_list = [] + grid_x, grid_y = np.meshgrid( + np.linspace(0, 2 * np.pi, self.img_shape_x), + np.linspace(0, 2 * np.pi, self.img_shape_y), + ) + for freq in freq_bands: + for p_fn in [np.sin, np.cos]: + grid_list.append(p_fn(grid_x * freq)) + grid_list.append(p_fn(grid_y * freq)) + grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid.requires_grad = False + elif self.gridtype == "test" and self.N_grid_channels == 2: + idx_x = torch.arange(self.img_shape_y) + idx_y = torch.arange(self.img_shape_x) + mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) + grid = torch.stack((mesh_x, mesh_y), dim=0) + else: + raise ValueError("Gridtype not supported.") + return grid + + def _get_lead_time_embedding(self): + if (self.lead_time_steps is None) or (self.lead_time_channels is None): + return None + grid = torch.nn.Parameter( + torch.randn( + self.lead_time_steps, + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ) + ) + return grid diff --git a/src/models/unet.py b/src/models/unet.py new file mode 100644 index 0000000..7270606 --- /dev/null +++ b/src/models/unet.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from dataclasses import dataclass + +import torch + +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +network_module = importlib.import_module("physicsnemo.models.diffusion") + + +@dataclass +class MetaData(ModelMetaData): + name: str = "UNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class UNet(Module): # TODO a lot of redundancy, need to clean up + """ + U-Net Wrapper for CorrDiff. + + Parameters + ----------- + img_resolution : int + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. + model_type: str, optional + Class name of the underlying model, by default 'DhariwalUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + + References + ---------- + Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + **model_kwargs, + ): + super().__init__(meta=MetaData) + + self.img_channels = img_channels + + # for compatibility with older versions that took only 1 dimension + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_x = img_resolution[0] + self.img_shape_y = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): + # SR: concatenate input channels + if img_lr is not None: + x = torch.cat((x, img_lr), dim=1) + + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), # (c_in * x).to(dtype), + torch.zeros( + sigma.numel(), dtype=sigma.dtype, device=sigma.device + ), # c_noise.flatten() + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + # skip connection - for SR there's size mismatch bwtween input and output + D_x = F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +class StormCastUNet(Module): + """ + U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. + + Parameters + ----------- + img_resolution : int or List[int] + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. + model_type: str, optional + Class name of the underlying model, by default 'DhariwalUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + """ + + def __init__( + self, + img_resolution, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNet", + **model_kwargs, + ): + super().__init__(meta=MetaData("StormCastUNet")) + + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_x = img_resolution[0] + self.img_shape_y = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward(self, x, force_fp32=False, **model_kwargs): + """Run a forward pass of the StormCast regression U-Net. + + Args: + x (torch.Tensor): input to the U-Net + force_fp32 (bool, optional): force casting to fp_32 if True. Defaults to False. + + Raises: + ValueError: If input data type is a mismatch with provided options + + Returns: + D_x (torch.Tensor): Output (prediction) of the U-Net + """ + + x = x.to(torch.float32) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), + torch.zeros(x.shape[0], dtype=x.dtype, device=x.device), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = F_x.to(torch.float32) + return D_x diff --git a/src/models/utils.py b/src/models/utils.py new file mode 100644 index 0000000..e1cde9d --- /dev/null +++ b/src/models/utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. + """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') From 1f404b38d32d2e6f3a187696bf5c942d14fa1bc6 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 11 Apr 2025 17:29:49 +0200 Subject: [PATCH 005/189] add utils and change imports --- src/distributed/__init__.py | 1 + src/distributed/config.py | 247 ++++++ src/distributed/manager.py | 775 ++++++++++++++++++ src/models/__init__.py | 3 + src/models/layers.py | 2 +- src/models/preconditioning copy.py | 1176 ---------------------------- src/models/preconditioning.py | 18 +- src/models/song_unet.py | 6 +- src/models/unet.py | 8 +- src/utils/capture.py | 513 ++++++++++++ src/utils/checkpoint.py | 398 ++++++++++ src/utils/console.py | 88 +++ src/utils/deterministic_sampler.py | 231 ++++++ src/utils/function_utils.py | 775 ++++++++++++++++++ src/utils/inference_utils.py | 253 ++++++ src/utils/model_utils.py | 66 ++ src/utils/stochastic_sampler.py | 533 +++++++++++++ src/utils/train_helpers.py | 107 +++ 18 files changed, 4007 insertions(+), 1193 deletions(-) create mode 100644 src/distributed/__init__.py create mode 100644 src/distributed/config.py create mode 100644 src/distributed/manager.py create mode 100644 src/models/__init__.py delete mode 100644 src/models/preconditioning copy.py create mode 100644 src/utils/capture.py create mode 100644 src/utils/checkpoint.py create mode 100644 src/utils/console.py create mode 100644 src/utils/deterministic_sampler.py create mode 100644 src/utils/function_utils.py create mode 100644 src/utils/inference_utils.py create mode 100644 src/utils/model_utils.py create mode 100644 src/utils/stochastic_sampler.py create mode 100644 src/utils/train_helpers.py diff --git a/src/distributed/__init__.py b/src/distributed/__init__.py new file mode 100644 index 0000000..0da01f3 --- /dev/null +++ b/src/distributed/__init__.py @@ -0,0 +1 @@ +from .manager import DistributedManager \ No newline at end of file diff --git a/src/distributed/config.py b/src/distributed/config.py new file mode 100644 index 0000000..c5414b4 --- /dev/null +++ b/src/distributed/config.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Union + +from treelib import Tree + + +class ProcessGroupNode: + """ + Class to store the attributes of a distributed process group + + Attributes + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, number of processes in the process group + """ + + def __init__( + self, + name: str, + size: Optional[int] = None, + ): + """ + Constructor for the ProcessGroupNode class + + Parameters + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, size of the process group + """ + self.name = name + self.size = size + + def __str__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return "ProcessGroupNode(" f"name={self.name}, " f"size={self.size}, " + + def __repr__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return self.__str__() + + +class ProcessGroupConfig: + """ + Class to define the configuration of a model's parallel process group structure as a + tree. Each node of the tree is of type `ProcessGroupNode`. + + Once the process group config structure (i.e, the tree structure) is set, it is + sufficient to set only the sizes for each leaf process group. Then, the size of + every parent group can be automatically computed as the product reduction of the + sub-tree of that parent group node. + + Examples + -------- + >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig + >>> + >>> # Create world group that contains all processes that are part of this job + >>> world = ProcessGroupNode("world") + >>> + >>> # Create the process group config with the highest level process group + >>> config = ProcessGroupConfig(world) + >>> + >>> # Create model and data parallel sub-groups + >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction + >>> # Nodes can be added with either the name of the node or the node itself + >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world) + >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world") + >>> + >>> # Create spatial and channel parallel sub-groups + >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel") + >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel") + >>> + >>> config.leaf_groups() + ['data_parallel', 'spatial_parallel', 'channel_parallel'] + >>> + >>> # Set leaf group sizes + >>> # Note: product of all leaf-node sizes should be the world size + >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4} + >>> config.set_leaf_group_sizes(group_sizes) # Update all parent group sizes too + >>> config.get_node("model_parallel").size + 6 + """ + + def __init__(self, node: ProcessGroupNode): + """ + Constructor to the ProcessGroupConfig class + + Parameters + ---------- + node : ProcessGroupNode + Root node of the tree, typically would be 'world' + Note, it is generally recommended to set the child groups for 'world' + to 'model_parallel' and 'data_parallel' to aid with distributed + data parallel training unless there is a specific reason to choose a + different structure + """ + self.root = node + self.root_id = node.name + self.tree = Tree() + self.tree.create_node(node.name, node.name, data=node) + + def add_node(self, node: ProcessGroupNode, parent=Union[str, ProcessGroupNode]): + """ + Add a node to the process group config + + Parameters + ---------- + node : ProcessGroupNode + The new node to be added to the config + parent : Union[str, ProcessGroupNode] + Parent node of the node to be added. Should already be in the config. + If str, it is the name of the parent node. Otherwise, the parent + ProcessGroupNode itself. + """ + if isinstance(parent, ProcessGroupNode): + parent = parent.name + self.tree.create_node(node.name, node.name, data=node, parent=parent) + + def get_node(self, name: str) -> ProcessGroupNode: + """ + Method to get the node given the name of the node + + Parameters + ---------- + name : str + Name of the node to retrieve + + Returns + ------- + ProcessGroupNode + Node with the given name from the config + """ + return self.tree.get_node(name).data + + def update_parent_sizes(self, verbose: bool = False) -> int: + """ + Method to update parent node sizes after setting the sizes for each leaf node + + Parameters + ---------- + verbose : bool + If True, print a message each time a parent node size was updated + + Returns + ------- + int + Size of the root node + """ + return _tree_product_reduction(self.tree, self.root_id, verbose=verbose) + + def leaf_groups(self) -> List[str]: + """ + Get a list of all leaf group names + + Returns + ------- + List[str] + List of all leaf node names + """ + return [n.identifier for n in self.tree.leaves()] + + def set_leaf_group_sizes( + self, group_sizes: Dict[str, int], update_parent_sizes: bool = True + ): + """ + Set process group sizes for all leaf groups + + Parameters + ---------- + group_sizes : Dict[str, int] + Dictionary with a mapping of each leaf group name to its size + update_parent_sizes : bool + Update all parent group sizes based on the leaf group if True + If False, only set the leaf group sizes. + """ + for id, size in group_sizes.items(): + if not self.tree.contains(id): + raise AssertionError( + f"Process group {id} is not in this process group config" + ) + node = self.tree.get_node(id) + if not node.is_leaf(): + raise AssertionError(f"Process group {id} is not a leaf group") + node.data.size = size + + if update_parent_sizes: + self.update_parent_sizes() + + +def _tree_product_reduction(tree, node_id, verbose=False): + """ + Function to traverse a tree and compute the product reduction of + the sub-tree for each node starting from `node_id` + """ + children = tree.children(node_id) + node = tree.get_node(node_id) + if not children: + if node.data.size is None: + raise AssertionError("Leaf nodes should have a valid size set") + return node.data.size + + product = 1 + + for child in children: + product *= _tree_product_reduction(tree, child.identifier) + + if node.data.size != product: + if verbose: + print( + "Updating size of node " + f"{node.data.name} from {node.data.size} to {product}" + ) + node.data.size = product + + return product diff --git a/src/distributed/manager.py b/src/distributed/manager.py new file mode 100644 index 0000000..facb466 --- /dev/null +++ b/src/distributed/manager.py @@ -0,0 +1,775 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import os +import queue +import warnings +from typing import Optional, Tuple +from warnings import warn + +import numpy as np +import torch +import torch.distributed as dist + +from src.distributed.config import ProcessGroupConfig, ProcessGroupNode + +warnings.simplefilter("default", DeprecationWarning) + + +class UndefinedGroupError(Exception): + """Exception for querying an undefined process group using the PhysicsNeMo DistributedManager""" + + def __init__(self, name: str): + """ + + Parameters + ---------- + name : str + Name of the process group being queried. + + """ + message = ( + f"Cannot query process group '{name}' before it is explicitly created." + ) + super().__init__(message) + + +class UninitializedDistributedManagerWarning(Warning): + """Warning to indicate usage of an uninitialized DistributedManager""" + + def __init__(self): + message = ( + "A DistributedManager object is being instantiated before " + + "this singleton class has been initialized. Instantiating a manager before " + + "initialization can lead to unexpected results where processes fail " + + "to communicate. Initialize the distributed manager via " + + "DistributedManager.initialize() before instantiating." + ) + super().__init__(message) + + +class DistributedManager(object): + """Distributed Manager for setting up distributed training environment. + + This is a singleton that creates a persistance class instance for storing parallel + environment information through out the life time of the program. This should be + used to help set up Distributed Data Parallel and parallel datapipes. + + Note + ---- + One should call `DistributedManager.initialize()` prior to constructing a manager + object + + Example + ------- + >>> DistributedManager.initialize() + >>> manager = DistributedManager() + >>> manager.rank + 0 + >>> manager.world_size + 1 + """ + + _shared_state = {} + + def __new__(cls): + obj = super(DistributedManager, cls).__new__(cls) + obj.__dict__ = cls._shared_state + + # Set the defaults + if not hasattr(obj, "_rank"): + obj._rank = 0 + if not hasattr(obj, "_world_size"): + obj._world_size = 1 + if not hasattr(obj, "_local_rank"): + obj._local_rank = 0 + if not hasattr(obj, "_distributed"): + obj._distributed = False + if not hasattr(obj, "_device"): + obj._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if not hasattr(obj, "_cuda"): + obj._cuda = torch.cuda.is_available() + if not hasattr(obj, "_broadcast_buffers"): + obj._broadcast_buffers = False + if not hasattr(obj, "_find_unused_parameters"): + obj._find_unused_parameters = False + if not hasattr(obj, "_initialization_method"): + obj._initialization_method = "None" + if not hasattr(obj, "_groups"): + obj._groups = {} + if not hasattr(obj, "_group_ranks"): + obj._group_ranks = {} + if not hasattr(obj, "_group_names"): + obj._group_names = {} + if not hasattr(obj, "_is_initialized"): + obj._is_initialized = False + if not hasattr(obj, "_global_mesh"): + obj._global_mesh = None # Lazy initialized right when it's first needed + if not hasattr(obj, "_mesh_dims"): + obj._mesh_dims = {} # Dictionary mapping axis names to sizes + + return obj + + def __init__(self): + if not self._is_initialized: + raise UninitializedDistributedManagerWarning() + super().__init__() + + @property + def rank(self): + """Process rank""" + return self._rank + + @property + def local_rank(self): + """Process rank on local machine""" + return self._local_rank + + @property + def world_size(self): + """Number of processes in distributed environment""" + return self._world_size + + @property + def device(self): + """Process device""" + return self._device + + @property + def distributed(self): + """Distributed environment""" + return self._distributed + + @property + def cuda(self): + """If cuda is available""" + return self._cuda + + @property + def mesh_dims(self): + """Mesh Dimensions as dictionary (axis name : size)""" + return self._mesh_dims + + @property + def group_names(self): + """ + Returns a list of all named process groups created + """ + return self._groups.keys() + + @property + def global_mesh(self): + """ + Returns the global mesh. If it's not initialized, it will be created when this is called. + """ + if self._global_mesh is None: + # Fully flat mesh (1D) by default: + self.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("world",)) + + return self._global_mesh + + def mesh_names(self): + """ + Return mesh axis names + """ + return self._mesh_dims.keys() + + def mesh_sizes(self): + """ + Return mesh axis sizes + """ + return self._mesh_dims.values() + + def group(self, name=None): + """ + Returns a process group with the given name + If name is None, group is also None indicating the default process group + If named group does not exist, UndefinedGroupError exception is raised + """ + if name in self._groups.keys(): + return self._groups[name] + elif name is None: + return None + else: + raise UndefinedGroupError(name) + + def mesh(self, name=None): + """ + Return a device_mesh with the given name. + Does not initialize. If the mesh is not created + already, will raise and error + + Parameters + ---------- + name : str, optional + Name of desired mesh, by default None + """ + + if name in self._global_mesh.axis_names: + return self._global_mesh[name] + elif name is None: + return self._global_mesh + else: + raise UndefinedGroupError(f"Mesh axis {name} not defined") + + def group_size(self, name=None): + """ + Returns the size of named process group + """ + if name is None: + return self._world_size + group = self.group(name) + return dist.get_world_size(group=group) + + def group_rank(self, name=None): + """ + Returns the rank in named process group + """ + if name is None: + return self._rank + group = self.group(name) + return dist.get_rank(group=group) + + def group_name(self, group=None): + """ + Returns the name of process group + """ + if group is None: + return None + return self._group_names[group] + + @property + def broadcast_buffers(self): + """broadcast_buffers in PyTorch DDP""" + return self._broadcast_buffers + + @broadcast_buffers.setter + def broadcast_buffers(self, broadcast: bool): + """Setter for broadcast_buffers""" + self._broadcast_buffers = broadcast + + @property + def find_unused_parameters(self): + """find_unused_parameters in PyTorch DDP""" + return self._find_unused_parameters + + @find_unused_parameters.setter + def find_unused_parameters(self, find_params: bool): + """Setter for find_unused_parameters""" + if find_params: + warn( + "Setting `find_unused_parameters` in DDP to true, " + "use only if necessary." + ) + self._find_unused_parameters = find_params + + def __str__(self): + output = ( + f"Initialized process {self.rank} of {self.world_size} using " + f"method '{self._initialization_method}'. Device set to {str(self.device)}" + ) + return output + + @classmethod + def is_initialized(cls) -> bool: + """If manager singleton has been initialized""" + return cls._shared_state.get("_is_initialized", False) + + @staticmethod + def get_available_backend(): + """Get communication backend""" + if torch.cuda.is_available() and torch.distributed.is_nccl_available(): + return "nccl" + else: + return "gloo" + + @staticmethod + def initialize_env(): + """Setup method using generic initialization""" + rank = int(os.environ.get("RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + if "LOCAL_RANK" in os.environ: + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + local_rank = int(local_rank) + else: + local_rank = rank % torch.cuda.device_count() + + else: + local_rank = rank % torch.cuda.device_count() + + # Read env variables + addr = os.environ.get("MASTER_ADDR") + port = os.environ.get("MASTER_PORT") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + ) + + @staticmethod + def initialize_open_mpi(addr, port): + """Setup method using OpenMPI initialization""" + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")) + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="openmpi", + ) + + @staticmethod + def initialize_slurm(port): + """Setup method using SLURM initialization""" + rank = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NPROCS")) + local_rank = int(os.environ.get("SLURM_LOCALID")) + addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="slurm", + ) + + @staticmethod + def initialize(): + """ + Initialize distributed manager + + Current supported initialization methods are: + `ENV`: PyTorch environment variable initialization + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + `SLURM`: Initialization on SLURM systems. + Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and + `SLURM_LAUNCH_NODE_IPADDR` environment variables. + `OPENMPI`: Initialization for OpenMPI launchers. + Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and + `OMPI_COMM_WORLD_LOCAL_RANK` environment variables. + + Initialization by default is done using the first valid method in the order + listed above. Initialization method can also be explicitly controlled using the + `PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it + to one of the options above. + """ + if DistributedManager.is_initialized(): + warn("Distributed manager is already intialized") + return + + addr = os.getenv("MASTER_ADDR", "localhost") + port = os.getenv("MASTER_PORT", "12355") + # https://pytorch.org/docs/master/notes/cuda.html#id5 + # was changed in version 2.2 + if torch.__version__ < (2, 2): + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + else: + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + initialization_method = os.getenv( + "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD" + ) + if initialization_method is None: + try: + DistributedManager.initialize_env() + except TypeError: + if "SLURM_PROCID" in os.environ: + DistributedManager.initialize_slurm(port) + elif "OMPI_COMM_WORLD_RANK" in os.environ: + DistributedManager.initialize_open_mpi(addr, port) + else: + warn( + "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job" + ) + DistributedManager._shared_state["_is_initialized"] = True + elif initialization_method == "ENV": + DistributedManager.initialize_env() + elif initialization_method == "SLURM": + DistributedManager.initialize_slurm(port) + elif initialization_method == "OPENMPI": + DistributedManager.initialize_open_mpi(addr, port) + else: + raise RuntimeError( + "Unknown initialization method " + f"{initialization_method}. " + "Supported values for " + "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD are " + "ENV, SLURM and OPENMPI" + ) + + # Set per rank numpy random seed for data sampling + np.random.seed(seed=DistributedManager().rank) + + def initialize_mesh( + self, mesh_shape: Tuple[int, ...], mesh_dim_names: Tuple[str, ...] + ) -> dist.DeviceMesh: + """ + Initialize a global device mesh over the entire distributed job. + + Creates a multi-dimensional mesh of processes that can be used for distributed + operations. The mesh shape must multiply to equal the total world size, with + one dimension optionally being flexible (-1). + + Parameters + ---------- + mesh_shape : Tuple[int, ...] + Tuple of ints describing the size of each mesh dimension. Product must equal + world_size. One dimension can be -1 to be automatically calculated. + + mesh_dim_names : Tuple[str, ...] + Names for each mesh dimension. Must match length of mesh_shape. + + Returns + ------- + torch.distributed.DeviceMesh + The initialized device mesh + + Raises + ------ + RuntimeError + If mesh dimensions are invalid or don't match world size + AssertionError + If distributed environment is not available + """ + + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + # Assert basic properties: + if len(mesh_shape) == 0: + raise RuntimeError( + "Device Mesh requires at least one mesh dimension in `mesh_shape`" + ) + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names must have the same length, but found " + f"{len(mesh_shape)} and {len(mesh_dim_names)} respectively." + ) + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError("Mesh dimension names must be unique") + + # Check against the total mesh shape vs. world size: + total_mesh_shape = np.prod(mesh_shape) + + # Allow one shape to be -1 + if -1 in mesh_shape: + residual_shape = int(self.world_size / (-1 * total_mesh_shape)) + + # Replace -1 with the computed size: + mesh_shape = [residual_shape if m == -1 else m for m in mesh_shape] + # Recompute total shape: + total_mesh_shape = np.prod(mesh_shape) + + if total_mesh_shape != self.world_size: + raise RuntimeError( + "Device Mesh num elements must equal world size of " + f"{total_mesh_shape} but was configured by user with " + f"global size of {self.world_size}." + ) + + # Actually create the mesh: + self._global_mesh = dist.init_device_mesh( + "cuda" if self.cuda else "cpu", + mesh_shape, + mesh_dim_names=mesh_dim_names, + ) + + # Finally, upon success, cache the mesh dimensions: + self._mesh_dims = {key: val for key, val in zip(mesh_dim_names, mesh_shape)} + + return self._global_mesh + + @staticmethod + def setup( + rank=0, + world_size=1, + local_rank=None, + addr="localhost", + port="12355", + backend="nccl", + method="env", + ): + """Set up PyTorch distributed process group and update manager attributes""" + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + + DistributedManager._shared_state["_is_initialized"] = True + manager = DistributedManager() + + manager._distributed = torch.distributed.is_available() + if manager._distributed: + # Update rank and world_size if using distributed + manager._rank = rank + manager._world_size = world_size + if local_rank is None: + manager._local_rank = rank % torch.cuda.device_count() + else: + manager._local_rank = local_rank + + manager._device = torch.device( + f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu" + ) + + if manager._distributed: + # Setup distributed process group + try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + device_id=manager.device, + ) + except TypeError: + # device_id only introduced in PyTorch 2.3 + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) + + if torch.cuda.is_available(): + # Set device for this process and empty cache to optimize memory usage + torch.cuda.set_device(manager.device) + torch.cuda.device(manager.device) + torch.cuda.empty_cache() + + manager._initialization_method = method + + @staticmethod + def create_process_subgroup( + name: str, size: int, group_name: Optional[str] = None, verbose: bool = False + ): # pragma: no cover + """ + Create a process subgroup of a parent process group. This must be a collective + call by all processes participating in this application. + + Parameters + ---------- + name : str + Name of the process subgroup to be created. + + size : int + Size of the process subgroup to be created. This must be an integer factor of + the parent group's size. + + group_name : Optional[str] + Name of the parent process group, optional. If None, the default process group + will be used. Default None. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if name in manager._groups: + raise AssertionError(f"Group with name {name} already exists") + + # Get parent group's params + group = manager._groups[group_name] if group_name else None + group_size = dist.get_world_size(group=group) + num_groups = manager.world_size // group_size + + # Get number of sub-groups per parent group + if group_size % size != 0: + raise AssertionError( + f"Cannot divide group size {group_size} evenly into subgroups of" + f" size {size}" + ) + num_subgroups = group_size // size + + # Create all the sub-groups + # Note: all ranks in the job need to create all sub-groups in + # the same order even if a rank is not part of a sub-group + manager._group_ranks[name] = [] + for g in range(num_groups): + for i in range(num_subgroups): + # Get global ranks that are part of this sub-group + start = i * size + end = start + size + if group_name: + ranks = manager._group_ranks[group_name][g][start:end] + else: + ranks = list(range(start, end)) + # Create sub-group and keep track of ranks + tmp_group = dist.new_group(ranks=ranks) + manager._group_ranks[name].append(ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[name] = tmp_group + manager._group_names[tmp_group] = name + + if verbose and manager.rank == 0: + print(f"Process group '{name}':") + for grp in manager._group_ranks[name]: + print(" ", grp) + + @staticmethod + def create_orthogonal_process_group( + orthogonal_group_name: str, group_name: str, verbose: bool = False + ): # pragma: no cover + """ + Create a process group that is orthogonal to the specified process group. + + Parameters + ---------- + orthogonal_group_name : str + Name of the orthogonal process group to be created. + + group_name : str + Name of the existing process group. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if group_name not in manager._groups: + raise ValueError(f"Group with name {group_name} does not exist") + if orthogonal_group_name in manager._groups: + raise ValueError(f"Group with name {orthogonal_group_name} already exists") + + group_ranks = manager._group_ranks[group_name] + orthogonal_ranks = [list(i) for i in zip(*group_ranks)] + + for ranks in orthogonal_ranks: + tmp_group = dist.new_group(ranks=ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[orthogonal_group_name] = tmp_group + manager._group_names[tmp_group] = orthogonal_group_name + + manager._group_ranks[orthogonal_group_name] = orthogonal_ranks + + if verbose and manager.rank == 0: + print(f"Process group '{orthogonal_group_name}':") + for grp in manager._group_ranks[orthogonal_group_name]: + print(" ", grp) + + @staticmethod + def create_group_from_node( + node: ProcessGroupNode, + parent: Optional[str] = None, + verbose: bool = False, + ): # pragma: no cover + if node.size is None: + raise AssertionError( + "Cannot create groups from a ProcessGroupNode that is not fully" + " populated. Ensure that config.set_leaf_group_sizes is called first" + " with `update_parent_sizes = True`" + ) + + DistributedManager.create_process_subgroup( + node.name, node.size, group_name=parent, verbose=verbose + ) + # Create orthogonal process group + orthogonal_group = f"__orthogonal_to_{node.name}" + DistributedManager.create_orthogonal_process_group( + orthogonal_group, node.name, verbose=verbose + ) + return orthogonal_group + + @staticmethod + def create_groups_from_config( + config: ProcessGroupConfig, verbose: bool = False + ): # pragma: no cover + + warnings.warn( + "DistributedManager.create_groups_from_config is no longer the most simple " + "way to organize process groups. Please switch to DeviceMesh, " + "and DistributedManager.initialize_mesh", + category=DeprecationWarning, + stacklevel=2, + ) + + # Traverse process group tree in breadth first order + # to create nested process groups + q = queue.Queue() + q.put(config.root_id) + DistributedManager.create_group_from_node(config.root) + + while not q.empty(): + node_id = q.get() + if verbose: + print(f"Node ID: {node_id}") + + children = config.tree.children(node_id) + if verbose: + print(f" Children: {children}") + + parent_group = node_id + for child in children: + # Create child group and replace parent group by orthogonal group so + # that each child forms an independent block of processes + parent_group = DistributedManager.create_group_from_node( + child.data, + parent=parent_group, + ) + + # Add child ids to the queue + q.put(child.identifier) + + @atexit.register + @staticmethod + def cleanup(): + """Clean up distributed group and singleton""" + # Destroying group.WORLD is enough for all process groups to get destroyed + if ( + "_is_initialized" in DistributedManager._shared_state + and DistributedManager._shared_state["_is_initialized"] + and "_distributed" in DistributedManager._shared_state + and DistributedManager._shared_state["_distributed"] + ): + if torch.cuda.is_available(): + dist.barrier(device_ids=[DistributedManager().local_rank]) + else: + dist.barrier() + dist.destroy_process_group() + DistributedManager._shared_state = {} diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..6b790ae --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,3 @@ +from .unet import UNet +from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding \ No newline at end of file diff --git a/src/models/layers.py b/src/models/layers.py index 1fb3b17..d5a1ab2 100644 --- a/src/models/layers.py +++ b/src/models/layers.py @@ -26,7 +26,7 @@ from einops import rearrange from torch.nn.functional import silu -from physicsnemo.models.diffusion import weight_init +from src.utils.model_utils import weight_init class Linear(torch.nn.Module): diff --git a/src/models/preconditioning copy.py b/src/models/preconditioning copy.py deleted file mode 100644 index 52a1660..0000000 --- a/src/models/preconditioning copy.py +++ /dev/null @@ -1,1176 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Preconditioning schemes used in the paper"Elucidating the Design Space of -Diffusion-Based Generative Models". -""" - -import importlib -import warnings -from dataclasses import dataclass -from typing import List, Union - -import numpy as np -import nvtx -import torch - -from physicsnemo.models.diffusion import ( - DhariwalUNet, # noqa: F401 for globals - SongUNet, # noqa: F401 for globals -) -from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module - -network_module = importlib.import_module("physicsnemo.models.diffusion") - - -@dataclass -class VPPrecondMetaData(ModelMetaData): - """VPPrecond meta data""" - - name: str = "VPPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class VPPrecond(Module): - """ - Preconditioning corresponding to the variance preserving (VP) formulation. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - beta_d : float - Extent of the noise level schedule, by default 19.9. - beta_min : float - Initial slope of the noise level schedule, by default 0.1. - M : int - Original number of timesteps in the DDPM formulation, by default 1000. - epsilon_t : float - Minimum t-value used during training, by default 1e-5. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - beta_d: float = 19.9, - beta_min: float = 0.1, - M: int = 1000, - epsilon_t: float = 1e-5, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__(meta=VPPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.beta_d = beta_d - self.beta_min = beta_min - self.M = M - self.epsilon_t = epsilon_t - self.sigma_min = float(self.sigma(epsilon_t)) - self.sigma_max = float(self.sigma(1)) - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = 1 - c_out = -sigma - c_in = 1 / (sigma**2 + 1).sqrt() - c_noise = (self.M - 1) * self.sigma_inv(sigma) - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - def sigma(self, t: Union[float, torch.Tensor]): - """ - Compute the sigma(t) value for a given t based on the VP formulation. - - The function calculates the noise level schedule for the diffusion process based - on the given parameters `beta_d` and `beta_min`. - - Parameters - ---------- - t : Union[float, torch.Tensor] - The timestep or set of timesteps for which to compute sigma(t). - - Returns - ------- - torch.Tensor - The computed sigma(t) value(s). - """ - t = torch.as_tensor(t) - return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() - - def sigma_inv(self, sigma: Union[float, torch.Tensor]): - """ - Compute the inverse of the sigma function for a given sigma. - - This function effectively calculates t from a given sigma(t) based on the - parameters `beta_d` and `beta_min`. - - Parameters - ---------- - sigma : Union[float, torch.Tensor] - The sigma(t) value or set of sigma(t) values for which to compute the - inverse. - - Returns - ------- - torch.Tensor - The computed t value(s) corresponding to the provided sigma(t). - """ - sigma = torch.as_tensor(sigma) - return ( - (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() - - self.beta_min - ) / self.beta_d - - def round_sigma(self, sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - - Parameters - ---------- - sigma : Union[float list, torch.Tensor] - The sigma value(s) to convert. - - Returns - ------- - torch.Tensor - The tensor representation of the provided sigma value(s). - """ - return torch.as_tensor(sigma) - - -@dataclass -class VEPrecondMetaData(ModelMetaData): - """VEPrecond meta data""" - - name: str = "VEPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class VEPrecond(Module): - """ - Preconditioning corresponding to the variance exploding (VE) formulation. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.02. - sigma_max : float - Maximum supported noise level, by default 100.0. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - sigma_min: float = 0.02, - sigma_max: float = 100.0, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__(meta=VEPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = 1 - c_out = sigma - c_in = 1 - c_noise = (0.5 * sigma).log() - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - def round_sigma(self, sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - - Parameters - ---------- - sigma : Union[float list, torch.Tensor] - The sigma value(s) to convert. - - Returns - ------- - torch.Tensor - The tensor representation of the provided sigma value(s). - """ - return torch.as_tensor(sigma) - - -@dataclass -class iDDPMPrecondMetaData(ModelMetaData): - """iDDPMPrecond meta data""" - - name: str = "iDDPMPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class iDDPMPrecond(Module): - """ - Preconditioning corresponding to the improved DDPM (iDDPM) formulation. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - C_1 : float - Timestep adjustment at low noise levels., by default 0.001. - C_2 : float - Timestep adjustment at high noise levels., by default 0.008. - M: int - Original number of timesteps in the DDPM formulation, by default 1000. - model_type :str - Class name of the underlying model, by default "DhariwalUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion - probabilistic models. In International Conference on Machine Learning - (pp. 8162-8171). PMLR. - """ - - def __init__( - self, - img_resolution, - img_channels, - label_dim=0, - use_fp16=False, - C_1=0.001, - C_2=0.008, - M=1000, - model_type="DhariwalUNet", - **model_kwargs, - ): - super().__init__(meta=iDDPMPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.C_1 = C_1 - self.C_2 = C_2 - self.M = M - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels * 2, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - u = torch.zeros(M + 1) - for j in range(M, 0, -1): # M, ..., 1 - u[j - 1] = ( - (u[j] ** 2 + 1) - / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - - 1 - ).sqrt() - self.register_buffer("u", u) - self.sigma_min = float(u[M - 1]) - self.sigma_max = float(u[0]) - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = 1 - c_out = -sigma - c_in = 1 / (sigma**2 + 1).sqrt() - c_noise = ( - self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) - ) - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) - return D_x - - def alpha_bar(self, j): - """ - Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. - - Parameters - ---------- - j : Union[int, torch.Tensor] - The timestep or set of timesteps for which to compute alpha_bar(j). - - Returns - ------- - torch.Tensor - The computed alpha_bar(j) value(s). - """ - j = torch.as_tensor(j) - return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 - - def round_sigma(self, sigma, return_index=False): - """ - Round the provided sigma value(s) to the nearest value(s) in a - pre-defined set `u`. - - Parameters - ---------- - sigma : Union[float, list, torch.Tensor] - The sigma value(s) to round. - return_index : bool, optional - Whether to return the index/indices of the rounded value(s) in `u` instead - of the rounded value(s) themselves, by default False. - - Returns - ------- - torch.Tensor - The rounded sigma value(s) or their index/indices in `u`, depending on the - value of `return_index`. - """ - sigma = torch.as_tensor(sigma) - index = torch.cdist( - sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), - self.u.reshape(1, -1, 1), - ).argmin(2) - result = index if return_index else self.u[index.flatten()].to(sigma.dtype) - return result.reshape(sigma.shape).to(sigma.device) - - -@dataclass -class EDMPrecondMetaData(ModelMetaData): - """EDMPrecond meta data""" - - name: str = "EDMPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class EDMPrecond(Module): - """ - Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels (for both input and output). If your model - requires a different number of input or output chanels, - override this by passing either of the optional - img_in_channels or img_out_channels args - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.0. - sigma_max : float - Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "DhariwalUNet". - img_in_channels: int - Optional setting for when number of input channels =/= number of output - channels. If set, will override img_channels for the input - This is useful in the case of additional (conditional) channels - img_out_channels: int - Optional setting for when number of input channels =/= number of output - channels. If set, will override img_channels for the output - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the - design space of diffusion-based generative models. Advances in Neural Information - Processing Systems, 35, pp.26565-26577. - """ - - def __init__( - self, - img_resolution, - img_channels, - label_dim=0, - use_fp16=False, - sigma_min=0.0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="DhariwalUNet", - img_in_channels=None, - img_out_channels=None, - **model_kwargs, - ): - super().__init__(meta=EDMPrecondMetaData) - self.img_resolution = img_resolution - if img_in_channels is not None: - img_in_channels = img_in_channels - else: - img_in_channels = img_channels - if img_out_channels is not None: - img_out_channels = img_out_channels - else: - img_out_channels = img_channels - - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data - - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_in_channels, - out_channels=img_out_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward( - self, - x, - sigma, - condition=None, - class_labels=None, - force_fp32=False, - **model_kwargs, - ): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() - c_noise = sigma.log() / 4 - - arg = c_in * x - - if condition is not None: - arg = torch.cat([arg, condition], dim=1) - - F_x = self.model( - arg.to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - - Parameters - ---------- - sigma : Union[float list, torch.Tensor] - The sigma value(s) to convert. - - Returns - ------- - torch.Tensor - The tensor representation of the provided sigma value(s). - """ - return torch.as_tensor(sigma) - - -@dataclass -class EDMPrecondSRMetaData(ModelMetaData): - """EDMPrecondSR meta data""" - - name: str = "EDMPrecondSR" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class EDMPrecondSR(Module): - """ - Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) for super-resolution tasks - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - img_in_channels : int - Number of input color channels. - img_out_channels : int - Number of output color channels. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.0. - sigma_max : float - Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "SongUNetPosEmbd". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - References: - - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the - design space of diffusion-based generative models. Advances in Neural Information - Processing Systems, 35, pp.26565-26577. - - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., - Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. - Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. - arXiv preprint arXiv:2309.15214. - """ - - def __init__( - self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, - sigma_min=0.0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - scale_cond_input=True, - **model_kwargs, - ): - super().__init__(meta=EDMPrecondSRMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels # TODO: this is not used, remove it - self.img_in_channels = img_in_channels - self.img_out_channels = img_out_channels - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.scale_cond_input = scale_cond_input - - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_in_channels + img_out_channels, - out_channels=img_out_channels, - **model_kwargs, - ) # TODO needs better handling - self.scaling_fn = self._get_scaling_fn() - - def _get_scaling_fn(self): - if self.scale_cond_input: - warnings.warn( - "scale_cond_input=True does not properly scale the conditional input. " - "(see https://github.com/NVIDIA/modulus/issues/229). " - "This setup will be deprecated. " - "Please set scale_cond_input=False.", - DeprecationWarning, - ) - return self._legacy_scaling_fn - else: - return self._scaling_fn - - @staticmethod - def _scaling_fn(x, img_lr, c_in): - return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) - - @staticmethod - def _legacy_scaling_fn(x, img_lr, c_in): - return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) - - @nvtx.annotate(message="EDMPrecondSR", color="orange") - def forward( - self, - x, - img_lr, - sigma, - force_fp32=False, - **model_kwargs, - ): - # Concatenate input channels - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() - c_noise = sigma.log() / 4 - - if img_lr is None: - arg = c_in * x - else: - arg = self.scaling_fn(x, img_lr, c_in) - arg = arg.to(dtype) - - F_x = self.model( - arg, - c_noise.flatten(), - class_labels=None, - **model_kwargs, - ) - - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - See EDMPrecond.round_sigma - """ - return EDMPrecond.round_sigma(sigma) - - -class VEPrecond_dfsr(torch.nn.Module): - """ - Preconditioning for dfsr model, modified from class VEPrecond, where the input - argument 'sigma' in forward propagation function is used to receive the timestep - of the backward diffusion process. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.02. - sigma_max : float - Maximum supported noise level, by default 100.0. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. - Advances in neural information processing systems. 2020;33:6840-51. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - sigma_min: float = 0.02, - sigma_max: float = 100.0, - dataset_mean: float = 5.85e-05, - dataset_scale: float = 4.79, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__() - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.model = globals()[model_type]( - img_resolution=img_resolution, - in_channels=self.img_channels, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - # print("sigma: ", sigma) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_in = 1 - c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - - if F_x.dtype != dtype: - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - return F_x - - -class VEPrecond_dfsr_cond(torch.nn.Module): - """ - Preconditioning for dfsr model with physics-informed conditioning input, modified - from class VEPrecond, where the input argument 'sigma' in forward propagation function - is used to receive the timestep of the backward diffusion process. The gradient of PDE - residual with respect to the vorticity in the governing Navier-Stokes equation is computed - as the physics-informed conditioning variable and is combined with the backward diffusion - timestep before being sent to the underlying model for noise prediction. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.02. - sigma_max : float - Maximum supported noise level, by default 100.0. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: - [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity - flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - sigma_min: float = 0.02, - sigma_max: float = 100.0, - dataset_mean: float = 5.85e-05, - dataset_scale: float = 4.79, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__() - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.model = globals()[model_type]( - img_resolution=img_resolution, - in_channels=model_kwargs["model_channels"] * 2, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - # modules to embed residual loss - self.conv_in = torch.nn.Conv2d( - img_channels, - model_kwargs["model_channels"], - kernel_size=3, - stride=1, - padding=1, - padding_mode="circular", - ) - self.emb_conv = torch.nn.Sequential( - torch.nn.Conv2d( - img_channels, - model_kwargs["model_channels"], - kernel_size=1, - stride=1, - padding=0, - ), - torch.nn.GELU(), - torch.nn.Conv2d( - model_kwargs["model_channels"], - model_kwargs["model_channels"], - kernel_size=3, - stride=1, - padding=1, - padding_mode="circular", - ), - ) - self.dataset_mean = dataset_mean - self.dataset_scale = dataset_scale - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_in = 1 - c_noise = sigma - - # Compute physics-informed conditioning information using vorticity residual - dx = ( - self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) - / self.dataset_scale - ) - x = self.conv_in(x) - cond_emb = self.emb_conv(dx) - x = torch.cat((x, cond_emb), dim=1) - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - - if F_x.dtype != dtype: - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - return F_x - - def voriticity_residual(self, w, re=1000.0, dt=1 / 32): - """ - Compute the gradient of PDE residual with respect to a given vorticity w using the - spectrum method. - - Parameters - ---------- - w: torch.Tensor - The fluid flow data sample (vorticity). - re: float - The value of Reynolds number used in the governing Navier-Stokes equation. - dt: float - Time step used to compute the time-derivative of vorticity included in the governing - Navier-Stokes equation. - - Returns - ------- - torch.Tensor - The computed vorticity gradient. - """ - - # w [b t h w] - w = w.clone() - w.requires_grad_(True) - nx = w.size(2) - device = w.device - - w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) - # Wavenumbers in y-direction - k_max = nx // 2 - N = nx - k_x = ( - torch.cat( - ( - torch.arange(start=0, end=k_max, step=1, device=device), - torch.arange(start=-k_max, end=0, step=1, device=device), - ), - 0, - ) - .reshape(N, 1) - .repeat(1, N) - .reshape(1, 1, N, N) - ) - k_y = ( - torch.cat( - ( - torch.arange(start=0, end=k_max, step=1, device=device), - torch.arange(start=-k_max, end=0, step=1, device=device), - ), - 0, - ) - .reshape(1, N) - .repeat(N, 1) - .reshape(1, 1, N, N) - ) - # Negative Laplacian in Fourier space - lap = k_x**2 + k_y**2 - lap[..., 0, 0] = 1.0 - psi_h = w_h / lap - - u_h = 1j * k_y * psi_h - v_h = -1j * k_x * psi_h - wx_h = 1j * k_x * w_h - wy_h = 1j * k_y * w_h - wlap_h = -lap * w_h - - u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) - v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) - wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) - wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) - wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) - advection = u * wx + v * wy - - wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) - - # establish forcing term - x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) - x = x[0:-1] - X, Y = torch.meshgrid(x, x) - f = -4 * torch.cos(4 * Y) - - residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f - residual_loss = (residual**2).mean() - dw = torch.autograd.grad(residual_loss, w)[0] - - return dw diff --git a/src/models/preconditioning.py b/src/models/preconditioning.py index 52a1660..7b621e2 100644 --- a/src/models/preconditioning.py +++ b/src/models/preconditioning.py @@ -27,13 +27,13 @@ import numpy as np import nvtx import torch +import torch.nn as nn -from physicsnemo.models.diffusion import ( +from src.models import ( DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module network_module = importlib.import_module("physicsnemo.models.diffusion") @@ -58,7 +58,7 @@ class VPPrecondMetaData(ModelMetaData): auto_grad: bool = False -class VPPrecond(Module): +class VPPrecond(nn.Module): """ Preconditioning corresponding to the variance preserving (VP) formulation. @@ -241,7 +241,7 @@ class VEPrecondMetaData(ModelMetaData): auto_grad: bool = False -class VEPrecond(Module): +class VEPrecond(nn.Module): """ Preconditioning corresponding to the variance exploding (VE) formulation. @@ -370,7 +370,7 @@ class iDDPMPrecondMetaData(ModelMetaData): auto_grad: bool = False -class iDDPMPrecond(Module): +class iDDPMPrecond(nn.Module): """ Preconditioning corresponding to the improved DDPM (iDDPM) formulation. @@ -544,7 +544,7 @@ class EDMPrecondMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecond(Module): +class EDMPrecond(nn.Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM) @@ -713,7 +713,7 @@ class EDMPrecondSRMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecondSR(Module): +class EDMPrecondSR(nn.Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM) for super-resolution tasks @@ -861,7 +861,7 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]): return EDMPrecond.round_sigma(sigma) -class VEPrecond_dfsr(torch.nn.Module): +class VEPrecond_dfsr(nn.Module): """ Preconditioning for dfsr model, modified from class VEPrecond, where the input argument 'sigma' in forward propagation function is used to receive the timestep @@ -953,7 +953,7 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) return F_x -class VEPrecond_dfsr_cond(torch.nn.Module): +class VEPrecond_dfsr_cond(nn.Module): """ Preconditioning for dfsr model with physics-informed conditioning input, modified from class VEPrecond, where the input argument 'sigma' in forward propagation function diff --git a/src/models/song_unet.py b/src/models/song_unet.py index d38484b..68adbda 100644 --- a/src/models/song_unet.py +++ b/src/models/song_unet.py @@ -27,8 +27,9 @@ import torch from torch.nn.functional import silu from torch.utils.checkpoint import checkpoint +import torch.nn as nn -from physicsnemo.models.diffusion import ( +from src.models import ( Conv2d, FourierEmbedding, GroupNorm, @@ -37,7 +38,6 @@ UNetBlock, ) from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module @dataclass @@ -58,7 +58,7 @@ class MetaData(ModelMetaData): auto_grad: bool = False -class SongUNet(Module): +class SongUNet(nn.Module): """ Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with optional self-attention, embeddings, and encoder-decoder components. diff --git a/src/models/unet.py b/src/models/unet.py index 7270606..db8e4f8 100644 --- a/src/models/unet.py +++ b/src/models/unet.py @@ -18,11 +18,11 @@ from dataclasses import dataclass import torch +import torch.nn as nn from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module -network_module = importlib.import_module("physicsnemo.models.diffusion") +network_module = importlib.import_module("src.models") @dataclass @@ -43,7 +43,7 @@ class MetaData(ModelMetaData): auto_grad: bool = False -class UNet(Module): # TODO a lot of redundancy, need to clean up +class UNet(nn.Module): # TODO a lot of redundancy, need to clean up """ U-Net Wrapper for CorrDiff. @@ -166,7 +166,7 @@ def round_sigma(self, sigma): return torch.as_tensor(sigma) -class StormCastUNet(Module): +class StormCastUNet(nn.Module): """ U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. diff --git a/src/utils/capture.py b/src/utils/capture.py new file mode 100644 index 0000000..50057f9 --- /dev/null +++ b/src/utils/capture.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import os +import time +from contextlib import nullcontext +from logging import Logger +from typing import Any, Callable, Dict, NewType, Optional, Union + +import torch + +from src.distributed import DistributedManager + +float16 = NewType("float16", torch.float16) +bfloat16 = NewType("bfloat16", torch.bfloat16) +optim = NewType("optim", torch.optim) + + +class _StaticCapture(object): + """Base class for StaticCapture decorator. + + This class should not be used, rather StaticCaptureTraining and StaticCaptureEvaluate + should be used instead for training and evaluation functions. + """ + + # Grad scaler and checkpoint class variables use for checkpoint saving and loading + # Since an instance of Static capture does not exist for checkpoint functions + # one must use class functions to access state dicts + _amp_scalers = {} + _amp_scaler_checkpoints = {} + _logger = logging.getLogger("capture") + + def __new__(cls, *args, **kwargs): + obj = super(_StaticCapture, cls).__new__(cls) + obj.amp_scalers = cls._amp_scalers + obj.amp_scaler_checkpoints = cls._amp_scaler_checkpoints + obj.logger = cls._logger + return obj + + def __init__( + self, + model: "physicsnemo.Module", + optim: Optional[optim] = None, + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_autocast: bool = True, + use_gradscaler: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + gradient_clip_norm: Optional[float] = None, + label: Optional[str] = None, + ): + self.logger = logger if logger else self.logger + # Checkpoint label (used for gradscaler) + self.label = label if label else f"scaler_{len(self.amp_scalers.keys())}" + + # DDP fix + if not isinstance(model, physicsnemo.models.Module) and hasattr( + model, "module" + ): + model = model.module + + if not isinstance(model, physicsnemo.models.Module): + self.logger.error("Model not a PhysicsNeMo Module!") + raise ValueError("Model not a PhysicsNeMo Module!") + if compile: + model = torch.compile(model) + + self.model = model + + self.optim = optim + self.eval = False + self.no_grad = False + self.gradient_clip_norm = gradient_clip_norm + + # Set up toggles for optimizations + if not (amp_type == torch.float16 or amp_type == torch.bfloat16): + raise ValueError("AMP type must be torch.float16 or torch.bfloat16") + # CUDA device + if "cuda" in str(self.model.device): + # CUDA graphs + if use_graphs and not self.model.meta.cuda_graphs: + self.logger.warning( + f"Model {model.meta.name} does not support CUDA graphs, turning off" + ) + use_graphs = False + self.cuda_graphs_enabled = use_graphs + + # AMP GPU + if not self.model.meta.amp_gpu: + self.logger.warning( + f"Model {model.meta.name} does not support AMP on GPUs, turning off" + ) + use_autocast = False + use_gradscaler = False + self.use_gradscaler = use_gradscaler + self.use_autocast = use_autocast + + self.amp_device = "cuda" + # Check if bfloat16 is suppored on the GPU + if amp_type == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + self.logger.warning( + "Current CUDA device does not support bfloat16, falling back to float16" + ) + amp_type = torch.float16 + self.amp_dtype = amp_type + # Gradient Scaler + scaler_enabled = self.use_gradscaler and amp_type == torch.float16 + self.scaler = self._init_amp_scaler(scaler_enabled, self.logger) + + self.replay_stream = torch.cuda.Stream(self.model.device) + # CPU device + else: + self.cuda_graphs_enabled = False + # AMP CPU + if use_autocast and not self.model.meta.amp_cpu: + self.logger.warning( + f"Model {model.meta.name} does not support AMP on CPUs, turning off" + ) + use_autocast = False + + self.use_autocast = use_autocast + self.amp_device = "cpu" + # Only float16 is supported on CPUs + # https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior + if amp_type == torch.float16 and use_autocast: + self.logger.warning( + "torch.float16 not supported for CPU AMP, switching to torch.bfloat16" + ) + amp_type = torch.bfloat16 + self.amp_dtype = torch.bfloat16 + # Gradient Scaler (not enabled) + self.scaler = self._init_amp_scaler(False, self.logger) + self.replay_stream = None + + if self.cuda_graphs_enabled: + self.graph = torch.cuda.CUDAGraph() + + self.output = None + self.iteration = 0 + self.cuda_graph_warmup = cuda_graph_warmup # Default for DDP = 11 + + def __call__(self, fn: Callable) -> Callable: + self.function = fn + + @functools.wraps(fn) + def decorated(*args: Any, **kwds: Any) -> Any: + """Training step decorator function""" + + with torch.no_grad() if self.no_grad else nullcontext(): + if self.cuda_graphs_enabled: + self._cuda_graph_forward(*args, **kwds) + else: + self._zero_grads() + self.output = self._amp_forward(*args, **kwds) + + if not self.eval: + # Update model parameters + self.scaler.step(self.optim) + self.scaler.update() + + return self.output + + return decorated + + def _cuda_graph_forward(self, *args: Any, **kwargs: Any) -> Any: + """Forward training step with CUDA graphs + + Returns + ------- + Any + Output of neural network forward + """ + # Graph warm up + if self.iteration < self.cuda_graph_warmup: + self.replay_stream.wait_stream(torch.cuda.current_stream()) + self._zero_grads() + with torch.cuda.stream(self.replay_stream): + output = self._amp_forward(*args, **kwargs) + self.output = output.detach() + torch.cuda.current_stream().wait_stream(self.replay_stream) + # CUDA Graphs + else: + # Graph record + if self.iteration == self.cuda_graph_warmup: + self.logger.warning(f"Recording graph of '{self.function.__name__}'") + self._zero_grads() + torch.cuda.synchronize() + if DistributedManager().distributed: + torch.distributed.barrier() + # TODO: temporary workaround till this issue is fixed: + # https://github.com/pytorch/pytorch/pull/104487#issuecomment-1638665876 + delay = os.environ.get("PHYSICSNEMO_CUDA_GRAPH_CAPTURE_DELAY", "10") + time.sleep(int(delay)) + with torch.cuda.graph(self.graph): + output = self._amp_forward(*args, **kwargs) + self.output = output.detach() + # Graph replay + self.graph.replay() + + self.iteration += 1 + return self.output + + def _zero_grads(self): + """Zero gradients + + Default to `set_to_none` since this will in general have lower memory + footprint, and can modestly improve performance. + + Note + ---- + Zeroing gradients can potentially cause an invalid CUDA memory access in another + graph. However if your graph involves gradients, you much set your gradients to none. + If there is already a graph recorded that includes these gradients, this will error. + Use the `NoGrad` version of capture to avoid this issue for inferencers / validators. + """ + # Skip zeroing if no grad is being used + if self.no_grad: + return + + try: + self.optim.zero_grad(set_to_none=True) + except Exception: + if self.optim: + self.optim.zero_grad() + # For apex optim support and eval mode (need to reset model grads) + self.model.zero_grad(set_to_none=True) + + def _amp_forward(self, *args, **kwargs) -> Any: + """Compute loss and gradients (if training) with AMP + + Returns + ------- + Any + Output of neural network forward + """ + with torch.autocast( + self.amp_device, enabled=self.use_autocast, dtype=self.amp_dtype + ): + output = self.function(*args, **kwargs) + + if not self.eval: + # In training mode output should be the loss + self.scaler.scale(output).backward() + if self.gradient_clip_norm is not None: + self.scaler.unscale_(self.optim) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.gradient_clip_norm + ) + + return output + + def _init_amp_scaler( + self, scaler_enabled: bool, logger: Logger + ) -> torch.cuda.amp.GradScaler: + # Create gradient scaler + scaler = torch.cuda.amp.GradScaler(enabled=scaler_enabled) + # Store scaler in class variable + self.amp_scalers[self.label] = scaler + logging.debug(f"Created gradient scaler {self.label}") + + # If our checkpoint dictionary has weights for this scaler lets load + if self.label in self.amp_scaler_checkpoints: + try: + scaler.load_state_dict(self.amp_scaler_checkpoints[self.label]) + del self.amp_scaler_checkpoints[self.label] + self.logger.info(f"Loaded grad scaler state dictionary {self.label}.") + except Exception as e: + self.logger.error( + f"Failed to load grad scaler {self.label} state dict from saved " + + "checkpoints. Did you switch the ordering of declared static captures?" + ) + raise ValueError(e) + return scaler + + @classmethod + def state_dict(cls) -> Dict[str, Any]: + """Class method for accsessing the StaticCapture state dictionary. + Use this in a training checkpoint function. + + Returns + ------- + Dict[str, Any] + Dictionary of states to save for file + """ + scaler_states = {} + for key, value in cls._amp_scalers.items(): + scaler_states[key] = value.state_dict() + + return scaler_states + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any]) -> None: + """Class method for loading a StaticCapture state dictionary. + Use this in a training checkpoint function. + + Returns + ------- + Dict[str, Any] + Dictionary of states to save for file + """ + for key, value in state_dict.items(): + # If scaler has been created already load the weights + if key in cls._amp_scalers: + try: + cls._amp_scalers[key].load_state_dict(value) + cls._logger.info(f"Loaded grad scaler state dictionary {key}.") + except Exception as e: + cls._logger.error( + f"Failed to load grad scaler state dict with id {key}." + + " Something went wrong!" + ) + raise ValueError(e) + # Otherwise store in checkpoints for later use + else: + cls._amp_scaler_checkpoints[key] = value + + @classmethod + def reset_state(cls): + cls._amp_scalers = {} + cls._amp_scaler_checkpoints = {} + + +class StaticCaptureTraining(_StaticCapture): + """A performance optimization decorator for PyTorch training functions. + + This class should be initialized as a decorator on a function that computes the + forward pass of the neural network and loss function. The user should only call the + defind training step function. This will apply optimizations including: AMP and + Cuda Graphs. + + Parameters + ---------- + model : physicsnemo.models.Module + PhysicsNeMo Model + optim : torch.optim + Optimizer + logger : Optional[Logger], optional + PhysicsNeMo Launch Logger, by default None + use_graphs : bool, optional + Toggle CUDA graphs if supported by model, by default True + use_amp : bool, optional + Toggle AMP if supported by mode, by default True + cuda_graph_warmup : int, optional + Number of warmup steps for cuda graphs, by default 11 + amp_type : Union[float16, bfloat16], optional + Auto casting type for AMP, by default torch.float16 + gradient_clip_norm : Optional[float], optional + Threshold for gradient clipping + label : Optional[str], optional + Static capture checkpoint label, by default None + + Raises + ------ + ValueError + If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. + + Example + ------- + >>> # Create model + >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) + >>> input = torch.rand(8, 2) + >>> output = torch.rand(8, 2) + >>> # Create optimizer + >>> optim = torch.optim.Adam(model.parameters(), lr=0.001) + >>> # Create training step function with optimization wrapper + >>> @StaticCaptureTraining(model=model, optim=optim) + ... def training_step(model, invar, outvar): + ... predvar = model(invar) + ... loss = torch.sum(torch.pow(predvar - outvar, 2)) + ... return loss + ... + >>> # Sample training loop + >>> for i in range(3): + ... loss = training_step(model, input, output) + ... + + Note + ---- + Static captures must be checkpointed when training using the `state_dict()` if AMP + is being used with gradient scaler. By default, this requires static captures to be + instantiated in the same order as when they were checkpointed. The label parameter + can be used to relax/circumvent this ordering requirement. + + Note + ---- + Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA + memory access errors on some systems. Prioritize capturing training graphs when this + occurs. + """ + + def __init__( + self, + model: "physicsnemo.Module", + optim: torch.optim, + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_amp: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + gradient_clip_norm: Optional[float] = None, + label: Optional[str] = None, + ): + super().__init__( + model, + optim, + logger, + use_graphs, + use_amp, + use_amp, + compile, + cuda_graph_warmup, + amp_type, + gradient_clip_norm, + label, + ) + + +class StaticCaptureEvaluateNoGrad(_StaticCapture): + + """An performance optimization decorator for PyTorch no grad evaluation. + + This class should be initialized as a decorator on a function that computes run the + forward pass of the model that does not require gradient calculations. This is the + recommended method to use for inference and validation methods. + + Parameters + ---------- + model : physicsnemo.models.Module + PhysicsNeMo Model + logger : Optional[Logger], optional + PhysicsNeMo Launch Logger, by default None + use_graphs : bool, optional + Toggle CUDA graphs if supported by model, by default True + use_amp : bool, optional + Toggle AMP if supported by mode, by default True + cuda_graph_warmup : int, optional + Number of warmup steps for cuda graphs, by default 11 + amp_type : Union[float16, bfloat16], optional + Auto casting type for AMP, by default torch.float16 + label : Optional[str], optional + Static capture checkpoint label, by default None + + Raises + ------ + ValueError + If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. + + Example + ------- + >>> # Create model + >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) + >>> input = torch.rand(8, 2) + >>> # Create evaluate function with optimization wrapper + >>> @StaticCaptureEvaluateNoGrad(model=model) + ... def eval_step(model, invar): + ... predvar = model(invar) + ... return predvar + ... + >>> output = eval_step(model, input) + >>> output.size() + torch.Size([8, 2]) + + Note + ---- + Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA + memory access errors on some systems. Prioritize capturing training graphs when this + occurs. + """ + + def __init__( + self, + model: "physicsnemo.Module", + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_amp: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + label: Optional[str] = None, + ): + super().__init__( + model, + None, + logger, + use_graphs, + use_amp, + compile, + False, + cuda_graph_warmup, + amp_type, + None, + label, + ) + self.eval = True # No optimizer/scaler calls + self.no_grad = True # No grad context and no grad zeroing diff --git a/src/utils/checkpoint.py b/src/utils/checkpoint.py new file mode 100644 index 0000000..8ec70fa --- /dev/null +++ b/src/utils/checkpoint.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import re +from pathlib import Path +from typing import Any, Dict, List, NewType, Optional, Union + +import torch +from torch.cuda.amp import GradScaler +from torch.optim.lr_scheduler import _LRScheduler + +from src.distributed import DistributedManager +from src.utils.console import PythonLogger +from src.utils.capture import _StaticCapture + +optimizer = NewType("optimizer", torch.optim) +scheduler = NewType("scheduler", _LRScheduler) +scaler = NewType("scaler", GradScaler) + +checkpoint_logging = PythonLogger("checkpoint") + + +def _get_checkpoint_filename( + path: str, + base_name: str = "checkpoint", + index: Union[int, None] = None, + saving: bool = False, + model_type: str = "mdlus", +) -> str: + """Gets the file name /path of checkpoint + + This function has three different ways of providing a checkout filename: + - If supplied an index this will return the checkpoint name using that index. + - If index is None and saving is false, this will get the checkpoint with the + largest index (latest save). + - If index is None and saving is true, it will return the next valid index file name + which is calculated by indexing the largest checkpoint index found by one. + + Parameters + ---------- + path : str + Path to checkpoints + base_name: str, optional + Base file name, by default checkpoint + index : Union[int, None], optional + Checkpoint index, by default None + saving : bool, optional + Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for PhysicsNeMo models and "pt" for PyTorch models + + + Returns + ------- + str + Checkpoint file name + """ + # Get model parallel rank so all processes in the first model parallel group + # can save their checkpoint. In the case without model parallelism, + # model_parallel_rank should be the same as the process rank itself and + # only rank 0 saves + if not DistributedManager.is_initialized(): + checkpoint_logging.warning( + "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors" + ) + DistributedManager.initialize() + manager = DistributedManager() + model_parallel_rank = ( + manager.group_rank("model_parallel") + if "model_parallel" in manager.group_names + else 0 + ) + + # Input file name + checkpoint_filename = str( + Path(path).resolve() / f"{base_name}.{model_parallel_rank}" + ) + + # File extension for PhysicsNeMo models or PyTorch models + file_extension = ".mdlus" if model_type == "mdlus" else ".pt" + + # If epoch is provided load that file + if index is not None: + checkpoint_filename = checkpoint_filename + f".{index}" + checkpoint_filename += file_extension + # Otherwise try loading the latest epoch or rolling checkpoint + else: + file_names = [ + Path(fname).name + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ) + ] + + if len(file_names) > 0: + # If checkpoint from a null index save exists load that + # This is the most likely line to error since it will fail with + # invalid checkpoint names + file_idx = [ + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) + for fname in file_names + ] + file_idx.sort() + # If we are saving index by 1 to get the next free file name + if saving: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" + else: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" + checkpoint_filename += file_extension + else: + checkpoint_filename += ".0" + file_extension + + return checkpoint_filename + + +def _unique_model_names( + models: List[torch.nn.Module], +) -> Dict[str, torch.nn.Module]: + """Util to clean model names and index if repeat names, will also strip DDP wrappers + if they exist. + + Parameters + ---------- + model : List[torch.nn.Module] + List of models to generate names for + + Returns + ------- + Dict[str, torch.nn.Module] + Dictionary of model names and respective modules + """ + # Loop through provided models and set up base names + model_dict = {} + for model0 in models: + if hasattr(model0, "module"): + # Strip out DDP layer + model0 = model0.module + # Base name of model is meta.name unless pytorch model + base_name = model0.__class__.__name__ + if isinstance(model0, physicsnemo.models.Module): + base_name = model0.meta.name + # If we have multiple models of the same name, introduce another index + if base_name in model_dict: + model_dict[base_name].append(model0) + else: + model_dict[base_name] = [model0] + + # Set up unique model names if needed + output_dict = {} + for key, model in model_dict.items(): + if len(model) > 1: + for i, model0 in enumerate(model): + output_dict[key + str(i)] = model0 + else: + output_dict[key] = model[0] + + return output_dict + + +def save_checkpoint( + path: str, + models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Training checkpoint saving utility + + This will save a training checkpoint in the provided path following the file naming + convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint + method in PhysicsNeMo core can then be used to read this file. + + Parameters + ---------- + path : str + Path to save the training checkpoint + models : Union[torch.nn.Module, List[torch.nn.Module], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler. Will attempt to save on in static capture if none provided, by + default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none this will save the checkpoint in the next + valid index, by default None + metadata : Optional[Dict[str, Any]], optional + Additional metadata to save, by default None + """ + # Create checkpoint directory if it does not exist + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Output directory {path} does not exist, will " "attempt to create" + ) + Path(path).mkdir(parents=True, exist_ok=True) + + # == Saving model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = ( + "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" + ) + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type=model_type + ) + + # Save state dictionary + if isinstance(model, physicsnemo.models.Module): + model.save(file_name) + else: + torch.save(model.state_dict(), file_name) + checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + + # == Saving training checkpoint == + checkpoint_dict = {} + # Optimizer state dict + if optimizer: + checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() + + # Scheduler state dict + if scheduler: + checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() + + # Scheduler state dict + if scaler: + checkpoint_dict["scaler_state_dict"] = scaler.state_dict() + # Static capture is being used, save its grad scaler + if _StaticCapture._amp_scalers: + checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() + + # Output file name + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pt" + ) + if epoch: + checkpoint_dict["epoch"] = epoch + if metadata: + checkpoint_dict["metadata"] = metadata + + # Save checkpoint to memory + if bool(checkpoint_dict): + torch.save( + checkpoint_dict, + output_filename, + ) + checkpoint_logging.success(f"Saved training checkpoint: {output_filename}") + + +def load_checkpoint( + path: str, + models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata_dict: Optional[Dict[str, Any]] = {}, + device: Union[str, torch.device] = "cpu", +) -> int: + """Checkpoint loading utility + + This loader is designed to be used with the save checkpoint utility in PhysicsNeMo + Launch. Given a path, this method will try to find a checkpoint and load state + dictionaries into the provided training objects. + + Parameters + ---------- + path : str + Path to training checkpoint + models : Union[torch.nn.Module, List[torch.nn.Module], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler, by default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none is provided this will attempt to load the + checkpoint with the largest index, by default None + metadata_dict: Optional[Dict[str, Any]], optional + Dictionary to store metadata from the checkpoint, by default None + device : Union[str, torch.device], optional + Target device, by default "cpu" + + Returns + ------- + int + Loaded epoch + """ + # Check if checkpoint directory exists + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Provided checkpoint directory {path} does not exist, skipping load" + ) + return 0 + + # == Loading model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = ( + "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" + ) + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, model_type=model_type + ) + if not Path(file_name).exists(): + checkpoint_logging.error( + f"Could not find valid model file {file_name}, skipping load" + ) + continue + # Load state dictionary + if isinstance(model, physicsnemo.models.Module): + model.load(file_name) + else: + model.load_state_dict(torch.load(file_name, map_location=device)) + + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) + + # == Loading training checkpoint == + checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") + if not Path(checkpoint_filename).is_file(): + checkpoint_logging.warning( + "Could not find valid checkpoint file, skipping load" + ) + return 0 + + checkpoint_dict = torch.load(checkpoint_filename, map_location=device) + checkpoint_logging.success( + f"Loaded checkpoint file {checkpoint_filename} to device {device}" + ) + + # Optimizer state dict + if optimizer and "optimizer_state_dict" in checkpoint_dict: + optimizer.load_state_dict(checkpoint_dict["optimizer_state_dict"]) + checkpoint_logging.success("Loaded optimizer state dictionary") + + # Scheduler state dict + if scheduler and "scheduler_state_dict" in checkpoint_dict: + scheduler.load_state_dict(checkpoint_dict["scheduler_state_dict"]) + checkpoint_logging.success("Loaded scheduler state dictionary") + + # Scaler state dict + if scaler and "scaler_state_dict" in checkpoint_dict: + scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) + checkpoint_logging.success("Loaded grad scaler state dictionary") + + if "static_capture_state_dict" in checkpoint_dict: + _StaticCapture.load_state_dict(checkpoint_dict["static_capture_state_dict"]) + checkpoint_logging.success("Loaded static capture state dictionary") + + epoch = 0 + if "epoch" in checkpoint_dict: + epoch = checkpoint_dict["epoch"] + + # Update metadata if exists and the dictionary object is provided + metadata = checkpoint_dict.get("metadata", {}) + for key, value in metadata.items(): + metadata_dict[key] = value + + return epoch diff --git a/src/utils/console.py b/src/utils/console.py new file mode 100644 index 0000000..4231576 --- /dev/null +++ b/src/utils/console.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from termcolor import colored + + +class PythonLogger: + """Simple console logger for DL training + This is a WIP + """ + + def __init__(self, name: str = "launch"): + self.logger = logging.getLogger(name) + + def file_logging(self, file_name: str = "launch.log"): + """Log to file""" + if os.path.exists(file_name): + try: + os.remove(file_name) + except FileNotFoundError: + # ignore if already removed (can happen with multiple processes) + pass + formatter = logging.Formatter( + "[%(asctime)s - %(name)s - %(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + filehandler = logging.FileHandler(file_name) + filehandler.setFormatter(formatter) + filehandler.setLevel(logging.DEBUG) + self.logger.addHandler(filehandler) + + def log(self, message: str): + """Log message""" + self.logger.info(message) + + def info(self, message: str): + """Log info""" + self.logger.info(colored(message, "light_blue")) + + def success(self, message: str): + """Log success""" + self.logger.info(colored(message, "light_green")) + + def warning(self, message: str): + """Log warning""" + self.logger.warning(colored(message, "light_yellow")) + + def error(self, message: str): + """Log error""" + self.logger.error(colored(message, "light_red")) + + +class RankZeroLoggingWrapper: + """Wrapper class to only log from rank 0 process in distributed training.""" + + def __init__(self, obj, dist): + self.obj = obj + self.dist = dist + + def __getattr__(self, name): + attr = getattr(self.obj, name) + if callable(attr): + + def wrapper(*args, **kwargs): + if self.dist.rank == 0: + return attr(*args, **kwargs) + else: + return None + + return wrapper + else: + return attr diff --git a/src/utils/deterministic_sampler.py b/src/utils/deterministic_sampler.py new file mode 100644 index 0000000..4b2f32b --- /dev/null +++ b/src/utils/deterministic_sampler.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import nvtx +import torch + +from src.models import EDMPrecond + +# ruff: noqa: E731 + + +@nvtx.annotate(message="deterministic_sampler", color="red") +def deterministic_sampler( + net, + latents, + img_lr, + img_shape=None, + class_labels=None, + randn_like=torch.randn_like, + num_steps=18, + sigma_min=None, + sigma_max=None, + rho=7, + solver="heun", + discretization="edm", + schedule="linear", + scaling="none", + epsilon_s=1e-3, + C_1=0.001, + C_2=0.008, + M=1000, + alpha=1, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, +): + """ + Generalized sampler, representing the superset of all sampling methods discussed + in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + """ + + # conditioning + x_lr = img_lr + + if solver not in ["euler", "heun"]: + raise ValueError(f"Unknown solver {solver}") + if discretization not in ["vp", "ve", "iddpm", "edm"]: + raise ValueError(f"Unknown discretization {discretization}") + if schedule not in ["vp", "ve", "linear"]: + raise ValueError(f"Unknown schedule {schedule}") + if scaling not in ["vp", "none"]: + raise ValueError(f"Unknown scaling {scaling}") + + # Helper functions for VP & VE noise level schedules. + vp_sigma = ( + lambda beta_d, beta_min: lambda t: ( + np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1 + ) + ** 0.5 + ) + vp_sigma_deriv = ( + lambda beta_d, beta_min: lambda t: 0.5 + * (beta_min + beta_d * t) + * (sigma(t) + 1 / sigma(t)) + ) + vp_sigma_inv = ( + lambda beta_d, beta_min: lambda sigma: ( + (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min + ) + / beta_d + ) + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma**2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) + sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ + discretization + ] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) + sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = ( + 2 + * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) + / (epsilon_s - 1) + ) + vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == "vp": + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == "ve": + orig_t_steps = (sigma_max**2) * ( + (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) + ) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == "iddpm": + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 + ).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[ + ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) + .round() + .to(torch.int64) + ] + else: + sigma_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + + # Define noise level schedule. + if schedule == "vp": + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == "ve": + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == "vp": + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = ( + min(S_churn / num_steps, np.sqrt(2) - 1) + if S_min <= sigma(t_cur) <= S_max + else 0 + ) + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + ( + sigma(t_hat) ** 2 - sigma(t_cur) ** 2 + ).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_hat / s(t_hat), + sigma(t_hat), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to( + torch.float64 + ) + d_cur = ( + sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) + ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == "euler" or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_prime / s(t_prime), + sigma(t_prime), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net( + x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels + ).to(torch.float64) + d_prime = ( + sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) + ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ( + (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime + ) + + return x_next diff --git a/src/utils/function_utils.py b/src/utils/function_utils.py new file mode 100644 index 0000000..dcbb127 --- /dev/null +++ b/src/utils/function_utils.py @@ -0,0 +1,775 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Miscellaneous utility classes and functions.""" + +import contextlib +import ctypes +import datetime +import fnmatch +import importlib +import inspect +import os +import re +import shutil +import sys +import types +import warnings +from typing import Any, List, Tuple, Union + +import cftime +import numpy as np +import torch + +# ruff: noqa: E722 PERF203 S110 E713 S324 + + +class EasyDict(dict): # pragma: no cover + """ + Convenience class that behaves like a dict but allows access with the attribute + syntax. + """ + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class StackedRandomGenerator: # pragma: no cover + """ + Wrapper for torch.Generator that allows specifying a different random seed + for each sample in a minibatch. + """ + + def __init__(self, device, seeds): + super().__init__() + self.generators = [ + torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds + ] + + def randn(self, size, **kwargs): + if size[0] != len(self.generators): + raise ValueError( + f"Expected first dimension of size {len(self.generators)}, got {size[0]}" + ) + return torch.stack( + [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] + ) + + def randn_like(self, input): + return self.randn( + input.shape, dtype=input.dtype, layout=input.layout, device=input.device + ) + + def randint(self, *args, size, **kwargs): + if size[0] != len(self.generators): + raise ValueError( + f"Expected first dimension of size {len(self.generators)}, got {size[0]}" + ) + return torch.stack( + [ + torch.randint(*args, size=size[1:], generator=gen, **kwargs) + for gen in self.generators + ] + ) + + +def parse_int_list(s): # pragma: no cover + """ + Parse a comma separated list of numbers or ranges and return a list of ints. + Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + """ + if isinstance(s, list): + return s + ranges = [] + range_re = re.compile(r"^(\d+)-(\d+)$") + for p in s.split(","): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1)) + else: + ranges.append(int(p)) + return ranges + + +# Small util functions +# ------------------------------------------------------------------------------------- +def convert_datetime_to_cftime( + time: datetime.datetime, cls=cftime.DatetimeGregorian +) -> cftime.DatetimeGregorian: + """Convert a Python datetime object to a cftime DatetimeGregorian object.""" + return cls(time.year, time.month, time.day, time.hour, time.minute, time.second) + + +def time_range( + start_time: datetime.datetime, + end_time: datetime.datetime, + step: datetime.timedelta, + inclusive: bool = False, +): + """Like the Python `range` iterator, but with datetimes.""" + t = start_time + while (t <= end_time) if inclusive else (t < end_time): + yield t + t += step + + +def format_time(seconds: Union[int, float]) -> str: # pragma: no cover + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format( + s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 + ) + + +def format_time_brief(seconds: Union[int, float]) -> str: # pragma: no cover + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def tuple_product(t: Tuple) -> Any: # pragma: no cover + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double, +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: # pragma: no cover + """ + Given a type name string (or an object having a __name__ attribute), return + matching Numpy and ctypes types that have the same size in bytes. + """ + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + if type_str not in _str_to_ctype.keys(): + raise ValueError("Unknown type name: " + type_str) + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + if my_dtype.itemsize != ctypes.sizeof(my_ctype): + raise ValueError( + "Numpy and ctypes types for '{}' have different sizes!".format(type_str) + ) + + return my_dtype, my_ctype + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------- + + +def get_module_from_obj_name( + obj_name: str, +) -> Tuple[types.ModuleType, str]: # pragma: no cover + """ + Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed). + """ + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [ + (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) + ] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith( + "No module named '" + module_name + "'" + ): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module( + module: types.ModuleType, obj_name: str +) -> Any: # pragma: no cover + """ + Traverses the object name and returns the last (rightmost) python object. + """ + if obj_name == "": + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: # pragma: no cover + """ + Finds the python object with the given name. + """ + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name( + *args, func_name: str = None, **kwargs +) -> Any: # pragma: no cover + """ + Finds the python object with the given name and calls it as a function. + """ + if func_name is None: + raise ValueError("func_name must be specified") + func_obj = get_obj_by_name(func_name) + if not callable(func_obj): + raise ValueError(func_name + " is not callable") + return func_obj(*args, **kwargs) + + +def construct_class_by_name( + *args, class_name: str = None, **kwargs +) -> Any: # pragma: no cover + """ + Finds the python class with the given name and constructs it with the given + arguments. + """ + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: # pragma: no cover + """ + Get the directory path of the module containing the given object name. + """ + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: # pragma: no cover + """ + Determine whether the given object is a top-level function, i.e., defined at module + scope using 'def'. + """ + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: # pragma: no cover + """ + Return the fully-qualified name of a top-level function. + """ + if not is_top_level_function(obj): + raise ValueError("Object is not a top-level function") + module = obj.__module__ + if module == "__main__": + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + + +def list_dir_recursively_with_ignore( + dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False +) -> List[Tuple[str, str]]: # pragma: no cover + """ + List all files recursively in a given directory while ignoring given file and + directory names. Returns list of tuples containing both absolute and relative paths. + """ + if not os.path.isdir(dir_path): + raise RuntimeError(f"Directory does not exist: {dir_path}") + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + if len(absolute_paths) != len(relative_paths): + raise ValueError("Number of absolute and relative paths do not match") + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs( + files: List[Tuple[str, str]] +) -> None: # pragma: no cover + """ + Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories. + """ + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# ---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + + +def constant( + value, shape=None, dtype=None, device=None, memory_format=None +): # pragma: no cover + """Cached construction of constant tensors""" + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cpu") + if memory_format is None: + memory_format = torch.contiguous_format + + key = ( + value.shape, + value.dtype, + value.tobytes(), + shape, + dtype, + device, + memory_format, + ) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + + +# ---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + + def nan_to_num( + input, nan=0.0, posinf=None, neginf=None, *, out=None + ): # pylint: disable=redefined-builtin # pragma: no cover + """Replace NaN/Inf with specified numerical values""" + if not isinstance(input, torch.Tensor): + raise TypeError("input should be a Tensor") + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + if nan != 0: + raise ValueError("nan_to_num only supports nan=0") + return torch.clamp( + input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out + ) + + +# ---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +# ---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + + +@contextlib.contextmanager +def suppress_tracer_warnings(): # pragma: no cover + """ + Context manager to temporarily suppress known warnings in torch.jit.trace(). + Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + """ + flt = ("ignore", None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + + +# ---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + + +def assert_shape(tensor, ref_shape): # pragma: no cover + """ + Assert that the shape of a tensor matches the given list of integers. + None indicates that the size of a dimension is allowed to vary. + Performs symbolic assertion when used in torch.jit.trace(). + """ + if tensor.ndim != len(ref_shape): + raise AssertionError( + f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" + ) + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(torch.as_tensor(size), ref_size), + f"Wrong size for dimension {idx}", + ) + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(size, torch.as_tensor(ref_size)), + f"Wrong size for dimension {idx}: expected {ref_size}", + ) + elif size != ref_size: + raise AssertionError( + f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" + ) + + +# ---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + + +def profiled_function(fn): # pragma: no cover + """Function decorator that calls torch.autograd.profiler.record_function().""" + + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + + decorator.__name__ = fn.__name__ + return decorator + + +# ---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + + +class InfiniteSampler(torch.utils.data.Sampler): # pragma: no cover + """ + Sampler for torch.utils.data.DataLoader that loops over the dataset + indefinitely, shuffling items as it goes. + """ + + def __init__( + self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 + ): + if not len(dataset) > 0: + raise ValueError("Dataset must contain at least one item") + if not num_replicas > 0: + raise ValueError("num_replicas must be positive") + if not 0 <= rank < num_replicas: + raise ValueError("rank must be non-negative and less than num_replicas") + if not 0 <= window_size <= 1: + raise ValueError("window_size must be between 0 and 1") + super().__init__() + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + + +# ---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + + +def params_and_buffers(module): # pragma: no cover + """Get parameters and buffers of a nn.Module""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + return list(module.parameters()) + list(module.buffers()) + + +def named_params_and_buffers(module): # pragma: no cover + """Get named parameters and buffers of a nn.Module""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + return list(module.named_parameters()) + list(module.named_buffers()) + + +@torch.no_grad() +def copy_params_and_buffers( + src_module, dst_module, require_all=False +): # pragma: no cover + """Copy parameters and buffers from a source module to target module""" + if not isinstance(src_module, torch.nn.Module): + raise TypeError("src_module must be a torch.nn.Module instance") + if not isinstance(dst_module, torch.nn.Module): + raise TypeError("dst_module must be a torch.nn.Module instance") + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + if not ((name in src_tensors) or (not require_all)): + raise ValueError(f"Missing source tensor for {name}") + if name in src_tensors: + tensor.copy_(src_tensors[name]) + + +# ---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + + +@contextlib.contextmanager +def ddp_sync(module, sync): # pragma: no cover + """ + Context manager for easily enabling/disabling DistributedDataParallel + synchronization. + """ + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + + +# ---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + + +def check_ddp_consistency(module, ignore_regex=None): # pragma: no cover + """Check DistributedDataParallel consistency across processes.""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + "." + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + if not (tensor == other).all(): + raise RuntimeError(f"DDP consistency check failed for {fullname}") + + +# ---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + + +def print_module_summary( + module, inputs, max_nesting=3, skip_redundant=True +): # pragma: no cover + """Print summary table of module hierarchy.""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + if isinstance(module, torch.jit.ScriptModule): + raise TypeError("module must not be a torch.jit.ScriptModule instance") + if not isinstance(inputs, (tuple, list)): + raise TypeError("inputs must be a tuple or list") + + # Register hooks. + entries = [] + nesting = [0] + + def pre_hook(_mod, _inputs): + nesting[0] += 1 + + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= { + id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs + } + + # Filter out redundant entries. + if skip_redundant: + entries = [ + e + for e in entries + if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) + ] + + # Construct table. + rows = [ + [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] + ] + rows += [["---"] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = "" if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] + rows += [ + [ + name + (":0" if len(e.outputs) >= 2 else ""), + str(param_size) if param_size else "-", + str(buffer_size) if buffer_size else "-", + (output_shapes + ["-"])[0], + (output_dtypes + ["-"])[0], + ] + ] + for idx in range(1, len(e.outputs)): + rows += [ + [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] + ] + param_total += param_size + buffer_total += buffer_size + rows += [["---"] * len(rows[0])] + rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + for row in rows: + print( + " ".join( + cell + " " * (width - len(cell)) for cell, width in zip(row, widths) + ) + ) + return outputs + + +# ---------------------------------------------------------------------------- diff --git a/src/utils/inference_utils.py b/src/utils/inference_utils.py new file mode 100644 index 0000000..842bdd3 --- /dev/null +++ b/src/utils/inference_utils.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import cftime +import nvtx +import torch +import tqdm + +from src.utils.function_utils import StackedRandomGenerator, time_range + +############################################################################ +# CorrDiff Generation Utilities # +############################################################################ + + +def regression_step( + net: torch.nn.Module, + img_lr: torch.Tensor, + latents_shape: torch.Size, + lead_time_label: torch.Tensor = None, +) -> torch.Tensor: + """ + Given a low-res input, performs a regression step to produce ensemble mean. + This function performs the regression on a single instance and then replicates + the results across the batch dimension. + + Args: + net (torch.nn.Module): U-Net model for regression. + img_lr (torch.Tensor): Low-resolution input. + latents_shape (torch.Size): Shape of the latent representation. Typically + (batch_size, out_channels, image_shape_x, image_shape_y). + + + Returns: + torch.Tensor: Predicted output at the next time step. + """ + # Create a tensor of zeros with the given shape and move it to the appropriate device + x_hat = torch.zeros(latents_shape, dtype=torch.float64, device=net.device) + t_hat = torch.tensor(1.0, dtype=torch.float64, device=net.device) + + # Perform regression on a single batch element + with torch.inference_mode(): + if lead_time_label is not None: + x = net(x_hat[0:1], img_lr, t_hat, lead_time_label=lead_time_label) + else: + x = net(x_hat[0:1], img_lr, t_hat) + + # If the batch size is greater than 1, repeat the prediction + if x_hat.shape[0] > 1: + x = x.repeat([d if i == 0 else 1 for i, d in enumerate(x_hat.shape)]) + + return x + + +def diffusion_step( # TODO generalize the module and add defaults + net: torch.nn.Module, + sampler_fn: callable, + seed_batch_size: int, + img_shape: tuple, + img_out_channels: int, + rank_batches: list, + img_lr: torch.Tensor, + rank: int, + device: torch.device, + hr_mean: torch.Tensor = None, + lead_time_label: torch.Tensor = None, +) -> torch.Tensor: + + """ + Generate images using diffusion techniques as described in the relevant paper. + + Args: + net (torch.nn.Module): The diffusion model network. + sampler_fn (callable): Function used to sample images from the diffusion model. + seed_batch_size (int): Number of seeds per batch. + img_shape (tuple): Shape of the images, (height, width). + img_out_channels (int): Number of output channels for the image. + rank_batches (list): List of batches of seeds to process. + img_lr (torch.Tensor): Low-resolution input image. + rank (int): Rank of the current process for distributed processing. + device (torch.device): Device to perform computations. + mean_hr (torch.Tensor, optional): High-resolution mean tensor, to be used as an additional input. By default None. + + Returns: + torch.Tensor: Generated images concatenated across batches. + """ + + img_lr = img_lr.to(memory_format=torch.channels_last) + + # Handling of the high-res mean + additional_args = {} + if hr_mean is not None: + additional_args["mean_hr"] = hr_mean + if lead_time_label is not None: + additional_args["lead_time_label"] = lead_time_label + additional_args["img_shape"] = img_shape + + # Loop over batches + all_images = [] + for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(rank != 0)): + with nvtx.annotate(f"generate {len(all_images)}", color="rapids"): + batch_size = len(batch_seeds) + if batch_size == 0: + continue + + # Initialize random generator, and generate latents + rnd = StackedRandomGenerator(device, batch_seeds) + latents = rnd.randn( + [ + seed_batch_size, + img_out_channels, + img_shape[0], + img_shape[1], + ], + device=device, + ).to(memory_format=torch.channels_last) + + with torch.inference_mode(): + images = sampler_fn( + net, latents, img_lr, randn_like=rnd.randn_like, **additional_args + ) + all_images.append(images) + return torch.cat(all_images) + + +############################################################################ +# CorrDiff writer utilities # +############################################################################ + + +class NetCDFWriter: + """NetCDF Writer""" + + def __init__( + self, f, lat, lon, input_channels, output_channels, has_lead_time=False + ): + self._f = f + self.has_lead_time = has_lead_time + # create unlimited dimensions + f.createDimension("time") + f.createDimension("ensemble") + + if lat.shape != lon.shape: + raise ValueError("lat and lon must have the same shape") + ny, nx = lat.shape + + # create lat/lon grid + f.createDimension("x", nx) + f.createDimension("y", ny) + + v = f.createVariable("lat", "f", dimensions=("y", "x")) + # NOTE rethink this for datasets whose samples don't have constant lat-lon. + v[:] = lat + v.standard_name = "latitude" + v.units = "degrees_north" + + v = f.createVariable("lon", "f", dimensions=("y", "x")) + v[:] = lon + v.standard_name = "longitude" + v.units = "degrees_east" + + # create time dimension + if has_lead_time: + v = f.createVariable("time", "str", ("time")) + else: + v = f.createVariable("time", "i8", ("time")) + v.calendar = "standard" + v.units = "hours since 1990-01-01 00:00:00" + + self.truth_group = f.createGroup("truth") + self.prediction_group = f.createGroup("prediction") + self.input_group = f.createGroup("input") + + for variable in output_channels: + name = variable.name + variable.level + self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x")) + self.prediction_group.createVariable( + name, "f", dimensions=("ensemble", "time", "y", "x") + ) + + # setup input data in netCDF + + for variable in input_channels: + name = variable.name + variable.level + self.input_group.createVariable(name, "f", dimensions=("time", "y", "x")) + + def write_input(self, channel_name, time_index, val): + """Write input data to NetCDF file.""" + self.input_group[channel_name][time_index] = val + + def write_truth(self, channel_name, time_index, val): + """Write ground truth data to NetCDF file.""" + self.truth_group[channel_name][time_index] = val + + def write_prediction(self, channel_name, time_index, ensemble_index, val): + """Write prediction data to NetCDF file.""" + self.prediction_group[channel_name][ensemble_index, time_index] = val + + def write_time(self, time_index, time): + """Write time information to NetCDF file.""" + if self.has_lead_time: + self._f["time"][time_index] = time + else: + time_v = self._f["time"] + self._f["time"][time_index] = cftime.date2num( + time, time_v.units, time_v.calendar + ) + + +############################################################################ +# CorrDiff time utilities # +############################################################################ + + +def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): + """Generates a list of times within a given range. + + Args: + times_range: A list containing start time, end time, and optional interval (hours). + time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). + + Returns: + A list of times within the specified range. + """ + + start_time = datetime.datetime.strptime(times_range[0], time_format) + end_time = datetime.datetime.strptime(times_range[1], time_format) + interval = ( + datetime.timedelta(hours=times_range[2]) + if len(times_range) > 2 + else datetime.timedelta(hours=1) + ) + + times = [ + t.strftime(time_format) + for t in time_range(start_time, end_time, interval, inclusive=True) + ] + return times diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py new file mode 100644 index 0000000..e1cde9d --- /dev/null +++ b/src/utils/model_utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. + """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') diff --git a/src/utils/stochastic_sampler.py b/src/utils/stochastic_sampler.py new file mode 100644 index 0000000..ddcf9cc --- /dev/null +++ b/src/utils/stochastic_sampler.py @@ -0,0 +1,533 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Any, Callable, Optional + +import torch +from torch import Tensor + + +def image_batching( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + patch_shape_y: int, + patch_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, + input_interp: Optional[Tensor] = None, +) -> Tensor: + """ + Splits a full image into a batch of patched images. + + This function takes a full image and splits it into patches, adding padding where necessary. + It can also concatenate additional interpolated data to each patch if provided. + + Parameters + ---------- + input : Tensor + The input tensor representing the full image with shape (batch_size, channels, img_shape_x, img_shape_y). + img_shape_x : int + The width (x-dimension) of the original full image. + img_shape_y : int + The height (y-dimension) of the original full image. + patch_shape_x : int + The width (x-dimension) of each image patch. + patch_shape_y : int + The height (y-dimension) of each image patch. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + input_interp : Optional[Tensor], optional + Optional additional data to concatenate to each patch with shape (batch_size, interp_channels, patch_shape_x, patch_shape_y). + By default None. + + Returns + ------- + Tensor + A tensor containing the image patches, with shape (total_patches * batch_size, channels [+ interp_channels], patch_shape_x, patch_shape_y). + """ + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + input_padded = torch.zeros( + input.shape[0], input.shape[1], padded_shape_y, padded_shape_x + ).to(input.device) + image_padding = torch.nn.ReflectionPad2d( + (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) + input_padded = image_padding(input) + patch_num = patch_num_x * patch_num_y + if input_interp is not None: + output = torch.zeros( + patch_num * batch_size, + input.shape[1] + input_interp.shape[1], + patch_shape_y, + patch_shape_x, + ).to(input.device) + else: + output = torch.zeros( + patch_num * batch_size, input.shape[1], patch_shape_y, patch_shape_x + ).to(input.device) + for x_index in range(patch_num_x): + for y_index in range(patch_num_y): + x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) + y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) + if input_interp is not None: + output[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + ] = torch.cat( + ( + input_padded[ + :, + :, + y_start : y_start + patch_shape_y, + x_start : x_start + patch_shape_x, + ], + input_interp, + ), + dim=1, + ) + else: + output[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + ] = input_padded[ + :, + :, + y_start : y_start + patch_shape_y, + x_start : x_start + patch_shape_x, + ] + return output + + +def image_fuse( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + patch_shape_y: int, + patch_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, +) -> Tensor: + """ + Reconstructs a full image from a batch of patched images. + + This function takes a batch of image patches and reconstructs the full image + by stitching the patches together. The function accounts for overlapping and + boundary pixels, ensuring that overlapping areas are averaged. + + Parameters + ---------- + input : Tensor + The input tensor containing the image patches with shape (total_patches * batch_size, channels, patch_shape_x, patch_shape_y). + img_shape_x : int + The width (x-dimension) of the original full image. + img_shape_y : int + The height (y-dimension) of the original full image. + patch_shape_x : int + The width (x-dimension) of each image patch. + patch_shape_y : int + The height (y-dimension) of each image patch. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + + Returns + ------- + Tensor + The reconstructed full image tensor with shape (batch_size, channels, img_shape_x, img_shape_y). + + """ + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + residual_x = patch_shape_x - pad_x_right # residual pixels in the last patch + residual_y = patch_shape_y - pad_y_right # residual pixels in the last patch + output = torch.zeros( + batch_size, input.shape[1], img_shape_y, img_shape_x, device=input.device + ) + one_map = torch.ones(1, 1, input.shape[2], input.shape[3], device=input.device) + count_map = torch.zeros( + 1, 1, img_shape_y, img_shape_x, device=input.device + ) # to count the overlapping times + for x_index in range(patch_num_x): + for y_index in range(patch_num_y): + x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) + y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) + if (x_index == patch_num_x - 1) and (y_index != patch_num_y - 1): + output[ + :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: + ] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + count_map[ + :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: + ] += one_map[ + :, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + elif (y_index == patch_num_y - 1) and ((x_index != patch_num_x - 1)): + output[ + :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix + ] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + count_map[ + :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix + ] += one_map[ + :, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + elif x_index == patch_num_x - 1 and y_index == patch_num_y - 1: + output[:, :, y_start:, x_start:] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + count_map[:, :, y_start:, x_start:] += one_map[ + :, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + else: + output[ + :, + :, + y_start : y_start + patch_shape_y - 2 * boundary_pix, + x_start : x_start + patch_shape_x - 2 * boundary_pix, + ] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + count_map[ + :, + :, + y_start : y_start + patch_shape_y - 2 * boundary_pix, + x_start : x_start + patch_shape_x - 2 * boundary_pix, + ] += one_map[ + :, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + return output / count_map + + +def stochastic_sampler( + net: Any, + latents: Tensor, + img_lr: Tensor, + class_labels: Optional[Tensor] = None, + randn_like: Callable[[Tensor], Tensor] = torch.randn_like, + img_shape: int = 448, + patch_shape: int = 448, + overlap_pix: int = 4, + boundary_pix: int = 2, + mean_hr: Optional[Tensor] = None, + lead_time_label: Optional[Tensor] = None, + num_steps: int = 18, + sigma_min: float = 0.002, + sigma_max: float = 800, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, +) -> Tensor: + """ + Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution and patch-based diffusion. + + Parameters + ---------- + net : Any + The neural network model that generates denoised images from noisy inputs. + latents : Tensor + The latent variables (e.g., noise) used as the initial input for the sampler. + img_lr : Tensor + Low-resolution input image for conditioning the super-resolution process. + class_labels : Optional[Tensor], optional + Class labels for conditional generation, if required by the model. By default None. + randn_like : Callable[[Tensor], Tensor] + Function to generate random noise with the same shape as the input tensor. + By default torch.randn_like. + img_shape : int + The height and width of the full image (assumed to be square). By default 448. + patch_shape : int + The height and width of each patch (assumed to be square). By default 448. + overlap_pix : int + Number of overlapping pixels between adjacent patches. By default 4. + boundary_pix : int + Number of pixels to be cropped as a boundary from each patch. By default 2. + mean_hr : Optional[Tensor], optional + Optional tensor containing mean high-resolution images for conditioning. By default None. + num_steps : int + Number of time steps for the sampler. By default 18. + sigma_min : float + Minimum noise level. By default 0.002. + sigma_max : float + Maximum noise level. By default 800. + rho : float + Exponent used in the time step discretization. By default 7. + S_churn : float + Churn parameter controlling the level of noise added in each step. By default 0. + S_min : float + Minimum time step for applying churn. By default 0. + S_max : float + Maximum time step for applying churn. By default float("inf"). + S_noise : float + Noise scaling factor applied during the churn step. By default 1. + + Returns + ------- + Tensor + The final denoised image produced by the sampler. + """ + + # Adjust noise levels based on what's supported by the network. + "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + if isinstance(img_shape, tuple): + img_shape_y, img_shape_x = img_shape + else: + img_shape_x = img_shape_y = img_shape + if patch_shape > img_shape_x or patch_shape > img_shape_y: + patch_shape = min(img_shape_x, img_shape_y) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat( + [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + ) # t_N = 0 + + b = latents.shape[0] + Nx = torch.arange(img_shape_x) + Ny = torch.arange(img_shape_y) + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ + None, + ].expand(b, -1, -1, -1) + + # conditioning = [mean_hr, img_lr, global_lr, pos_embd] + batch_size = img_lr.shape[0] + x_lr = img_lr + if mean_hr is not None: + x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) + global_index = None + + # input and position padding + patching + if patch_shape != img_shape_x or patch_shape != img_shape_y: + input_interp = torch.nn.functional.interpolate( + img_lr, (patch_shape, patch_shape), mode="bilinear" + ) + x_lr = image_batching( + x_lr, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + input_interp, + ) + global_index = image_batching( + grid.float(), + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ).int() + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + # Increase noise temporarily. + gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. Perform patching operation on score tensor if patch-based generation is used + # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr + + if patch_shape != img_shape_x or patch_shape != img_shape_y: + x_hat_batch = image_batching( + x_hat, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + else: + x_hat_batch = x_hat + x_hat_batch = x_hat_batch.to(latents.device) + x_lr = x_lr.to(latents.device) + if global_index is not None: + global_index = global_index.to(latents.device) + + if lead_time_label is not None: + denoised = net( + x_hat_batch, + x_lr, + t_hat, + class_labels, + lead_time_label=lead_time_label, + global_index=global_index, + ).to(torch.float64) + else: + denoised = net( + x_hat_batch, + x_lr, + t_hat, + class_labels, + global_index=global_index, + ).to(torch.float64) + if patch_shape != img_shape_x or patch_shape != img_shape_y: + + denoised = image_fuse( + denoised, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + if patch_shape != img_shape_x or patch_shape != img_shape_y: + x_next_batch = image_batching( + x_next, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + else: + x_next_batch = x_next + # ask about this fix + x_next_batch = x_next_batch.to(latents.device) + if lead_time_label is not None: + denoised = net( + x_next_batch, + x_lr, + t_next, + class_labels, + lead_time_label=lead_time_label, + global_index=global_index, + ).to(torch.float64) + else: + denoised = net( + x_next_batch, + x_lr, + t_next, + class_labels, + global_index=global_index, + ).to(torch.float64) + if patch_shape != img_shape_x or patch_shape != img_shape_y: + denoised = image_fuse( + denoised, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next diff --git a/src/utils/train_helpers.py b/src/utils/train_helpers.py new file mode 100644 index 0000000..d4529ac --- /dev/null +++ b/src/utils/train_helpers.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +from omegaconf import ListConfig + + +def set_patch_shape(img_shape, patch_shape): + img_shape_y, img_shape_x = img_shape + patch_shape_y, patch_shape_x = patch_shape + if (patch_shape_x is None) or (patch_shape_x > img_shape_x): + patch_shape_x = img_shape_x + if (patch_shape_y is None) or (patch_shape_y > img_shape_y): + patch_shape_y = img_shape_y + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x != patch_shape_y: + raise NotImplementedError("Rectangular patch not supported yet") + if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: + raise ValueError("Patch shape needs to be a multiple of 32") + return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + + +def set_seed(rank): + """ + Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings + """ + np.random.seed(rank % (1 << 31)) + torch.manual_seed(np.random.randint(1 << 31)) + + +def configure_cuda_for_consistent_precision(): + """ + Configures CUDA and cuDNN settings to ensure consistent precision by + disabling TensorFloat-32 (TF32) and reduced precision settings. + """ + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + +def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_size): + """ + Calculate the total batch size per GPU in a distributed setting, log the batch size per GPU, ensure it's within valid limits, + determine the number of accumulation rounds, and validate that the global batch size matches the expected value. + """ + batch_gpu_total = total_batch_size // world_size + batch_size_per_gpu = batch_size_per_gpu + if batch_size_per_gpu is None or batch_size_per_gpu > batch_gpu_total: + batch_size_per_gpu = batch_gpu_total + num_accumulation_rounds = batch_gpu_total // batch_size_per_gpu + if total_batch_size != batch_size_per_gpu * num_accumulation_rounds * world_size: + raise ValueError( + "total_batch_size must be equal to batch_size_per_gpu * num_accumulation_rounds * world_size" + ) + return batch_gpu_total, num_accumulation_rounds + + +def handle_and_clip_gradients(model, grad_clip_threshold=None): + """ + Handles NaNs and infinities in the gradients and optionally clips the gradients. + + Parameters: + - model (torch.nn.Module): The model whose gradients need to be processed. + - grad_clip_threshold (float, optional): The threshold for gradient clipping. If None, no clipping is performed. + """ + # Replace NaNs and infinities in gradients + for param in model.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0.0, posinf=1e5, neginf=-1e5, out=param.grad + ) + + # Clip gradients if a threshold is provided + if grad_clip_threshold is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) + + +def parse_model_args(args): + """Convert ListConfig values in args to tuples.""" + return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} + + +def is_time_for_periodic_task( + cur_nimg, freq, done, batch_size, rank, rank_0_only=False +): + """Should we perform a task that is done every `freq` samples?""" + if rank_0_only and rank != 0: + return False + elif done: # Run periodic tasks also at the end of training + return True + else: + return cur_nimg % freq < batch_size From 3b172ccc9093ceaa3b4a18a76e26d770a6301253 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 14 Apr 2025 17:14:37 +0200 Subject: [PATCH 006/189] restructure file system and add pyproject.toml --- README.md | 0 pyproject.toml | 24 +++++++++++ src/hirad/conf/train_regression.yaml | 0 src/{ => hirad}/distributed/__init__.py | 0 src/{ => hirad}/distributed/config.py | 0 src/{ => hirad}/distributed/manager.py | 0 src/hirad/losses/__init__.py | 0 src/{ => hirad}/losses/loss.py | 0 src/{ => hirad}/models/__init__.py | 0 .../__pycache__/__init__.cpython-312.pyc | Bin 0 -> 185 bytes .../models/__pycache__/dummy.cpython-312.pyc | Bin 0 -> 451 bytes .../models/__pycache__/unet.cpython-312.pyc | Bin 0 -> 8966 bytes src/{ => hirad}/models/layers.py | 0 src/{ => hirad}/models/preconditioning.py | 0 src/{ => hirad}/models/song_unet.py | 0 src/{ => hirad}/models/unet.py | 0 src/{ => hirad}/models/utils.py | 0 src/hirad/training/train.py | 0 src/hirad/utils/__init__.py | 0 src/{ => hirad}/utils/capture.py | 0 src/{ => hirad}/utils/checkpoint.py | 0 src/{ => hirad}/utils/console.py | 0 .../utils/deterministic_sampler.py | 0 src/{ => hirad}/utils/function_utils.py | 0 src/{ => hirad}/utils/inference_utils.py | 0 src/{ => hirad}/utils/model_utils.py | 0 src/{ => hirad}/utils/stochastic_sampler.py | 0 src/{ => hirad}/utils/train_helpers.py | 0 src/hirad_gen.egg-info/PKG-INFO | 38 ++++++++++++++++++ src/hirad_gen.egg-info/SOURCES.txt | 28 +++++++++++++ src/hirad_gen.egg-info/dependency_links.txt | 1 + src/hirad_gen.egg-info/top_level.txt | 7 ++++ 32 files changed, 98 insertions(+) create mode 100644 README.md create mode 100644 pyproject.toml create mode 100644 src/hirad/conf/train_regression.yaml rename src/{ => hirad}/distributed/__init__.py (100%) rename src/{ => hirad}/distributed/config.py (100%) rename src/{ => hirad}/distributed/manager.py (100%) create mode 100644 src/hirad/losses/__init__.py rename src/{ => hirad}/losses/loss.py (100%) rename src/{ => hirad}/models/__init__.py (100%) create mode 100644 src/hirad/models/__pycache__/__init__.cpython-312.pyc create mode 100644 src/hirad/models/__pycache__/dummy.cpython-312.pyc create mode 100644 src/hirad/models/__pycache__/unet.cpython-312.pyc rename src/{ => hirad}/models/layers.py (100%) rename src/{ => hirad}/models/preconditioning.py (100%) rename src/{ => hirad}/models/song_unet.py (100%) rename src/{ => hirad}/models/unet.py (100%) rename src/{ => hirad}/models/utils.py (100%) create mode 100644 src/hirad/training/train.py create mode 100644 src/hirad/utils/__init__.py rename src/{ => hirad}/utils/capture.py (100%) rename src/{ => hirad}/utils/checkpoint.py (100%) rename src/{ => hirad}/utils/console.py (100%) rename src/{ => hirad}/utils/deterministic_sampler.py (100%) rename src/{ => hirad}/utils/function_utils.py (100%) rename src/{ => hirad}/utils/inference_utils.py (100%) rename src/{ => hirad}/utils/model_utils.py (100%) rename src/{ => hirad}/utils/stochastic_sampler.py (100%) rename src/{ => hirad}/utils/train_helpers.py (100%) create mode 100644 src/hirad_gen.egg-info/PKG-INFO create mode 100644 src/hirad_gen.egg-info/SOURCES.txt create mode 100644 src/hirad_gen.egg-info/dependency_links.txt create mode 100644 src/hirad_gen.egg-info/top_level.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b2fa56c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "hirad-gen" +version = "0.1.0" +description = "High resolution atmospheric downscaling using generative machine learning" +authors = [ + { name="Petar Stamenkovic", email="petar.stamenkovic@meteoswiss.ch" } +] +readme = "README.md" +requires-python = ">=3.12" +license = {file = "LICENSE"} + +dependencies = [ + "torch>=2.6.0" +] + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] \ No newline at end of file diff --git a/src/hirad/conf/train_regression.yaml b/src/hirad/conf/train_regression.yaml new file mode 100644 index 0000000..e69de29 diff --git a/src/distributed/__init__.py b/src/hirad/distributed/__init__.py similarity index 100% rename from src/distributed/__init__.py rename to src/hirad/distributed/__init__.py diff --git a/src/distributed/config.py b/src/hirad/distributed/config.py similarity index 100% rename from src/distributed/config.py rename to src/hirad/distributed/config.py diff --git a/src/distributed/manager.py b/src/hirad/distributed/manager.py similarity index 100% rename from src/distributed/manager.py rename to src/hirad/distributed/manager.py diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/losses/loss.py b/src/hirad/losses/loss.py similarity index 100% rename from src/losses/loss.py rename to src/hirad/losses/loss.py diff --git a/src/models/__init__.py b/src/hirad/models/__init__.py similarity index 100% rename from src/models/__init__.py rename to src/hirad/models/__init__.py diff --git a/src/hirad/models/__pycache__/__init__.cpython-312.pyc b/src/hirad/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70d5748263686e7ee01a00c9c9ffd61e4261990b GIT binary patch literal 185 zcmX@j%ge<81eP(8=gQRm475F^(AnLQ) zQDjD9>o~pA)a3YQB1#<>S}gTyk;VCqxJWL~XOnrH3tdDayH#g=oRqaZdZNZE447`7==k|mj}EK82PsyZ3&%#yR( zo!Q)(l}KR)vxSQ;+q#mP6r@`haex4oDiu*4g2+$KQ(yWbg(|!mH~@j92#mZ?pxh>Y zNzS?RvrBqy2Tp*#_zF6E&;2>~@7#0GcZUBJk4G3t8zaA+{guElzd*)K0VlI@7c$F? z%*bq!$>OiC$d-IrA8VKSvwm8~mG~@A%lySaDVPnijE_;6Ledr3_?5Yvtj%S^iqFR! zWn}(MMh+;!n|{0erqehZp|wG%4Jmw~t0BY9xjnKm+A<6+BVW)mzQCrVAEAKhG;4)r z$&_+M$uNM4PwBE!oKj3_5=t*c43f)~u*k@{1LChUOJ?Jrfmw!CW?7lZ`p&=}W&H}L z@Urh&sFD3ySQBuZ!-XivJ6xD@0m=ol5z2*}x+vwsPF;+05vMLrxv0Y>C>L|MHp<13 zTS%natdP4Fi#s!|m=>o=CFO|aFQ}#!%~fPcoF!6u&ai?~sVwHomD6@^wp_77rcQEm zVt(G@X7VEkEKb+7c`K5yXgQHqSz)PS>L@{;PA9F9D57*xv?8JiyH+V;J|>Fasz^n< zhJ;{ZmftXm)h3FPrs<|+s=8)~A_-#^H>2xCf?7$UNMcChND@fekf279b|5#H4WP^? zgToa=A;xgoFkyV{{P2tFsqx97qlz{R2RLjW;BZA#%uIQKpkhQ*i@+!Autvv;O zAA=ri?I;CuU2EHW7mr`%R=M7K*UrUbHS<2Vlk~U)Z_*_mN-=G`2;^^?v?$h;iL^ED>V%E33oSzeCFp*PuVKnBgv25Zbbn+?6fq@yGV;-vrJ7$zb>Ix_^P zFMI>ex~vc(uM=THCuCC1=Q9!7;G{%g?@ZW6$__OW<<-)xNEAaaR`9rlF+tT#T7kk( z&nZGvrJff|C{eX?#T?cvCj1C$Nt#tMo~Ya$oITvCErNvUN@)h>2NH665hm%jYxGn# z(JkE4xzU%7{FxG9Mqez#nRcsLfUb%8^2mWP0d;jy(95_bQqkLv!}Ch6Vk&eKDw?d2 z;)1Hp3RDdR$rO&9962DAiIP(d==l7gFtZ@YN?xiIp-w6q&ORCHY)KMJsx~I%i@IdC z7->q?)KaA+7?pBaC#E6`nywm(P*g4`FtaAP`|KTbWJ>d&BU76H17*@#4fLy8zCYb8 z9u4Ye%Rf9{R&qF1xU8Bi5gFEgL8Vpz&mA?#BoZtPwhSe0S;o`$NOUtrbD^w^fgsXi zZzj;B3uty+v0Kk#ljQ!%IfZ+{f8d`rsS-7af^r5Z- zR>w5xwZi#t9SE)%RgP0iULlH>Q?}kQS~~@MA!+KMFqIh;4naHpqH10m6sCcl(B~9w zP3~7Iq)g|iCtz)r-EUYVpZlX zu&*XqK|9@IZ#td_%T#?;Ms#0t#>D4`rYdWC@^9j7Ad$8zJI_LU*DGQmk0aBhZL6}f z56%0g7Od=d%Dk3UKeWI%kftgJTw`1wIB(=?061Ki-7*LqzJ)dy0uCQVn#?{g_lTRT zhTWN8j{2BtROYKJjLltw232@X-tku>-)H{;SLFF;3rspNeS@_)cm^bT3?JcM-SAt9 z&DP(FHW^_n)@)BnH}puiTRd1Kg`h_O7ukfz9rw+Op&ZwGTM;mS8XAF|g69qX2v!?& zQdtq_o0$cx4QrYQO=3awz*fjL%U0MkoL*J}x5X|+i(&<5VTC66APH~@;0~X6ycLV5 z_dOlAI72Ds$zw2Sl17q2GK^#&k`W}MNcIDfc$2WdANfKCzgWuLf!Ao-3@i$5B0D04%E2&p`M2U zCf0RT`+n|zbZ>q1g}bpAYLoY4y({K_?;2j)HFRCNw`(6QsW&rgyE53jw=V3%-@dhO zp~n4vyxy}1eiH7_j>mv&`>?C;9qoEywd-Il`m>a<&hTB4`p$H{yZ`Rqv3uQP^`7*D zAe60#n6{k{!jP+RZ$+rjlXG1hSOtU1Pp|75T&-oegn!v%I#-fm3eI)IhHe7I_2qps zd&}pkG+Sj0%`%;)jf-FKuBj`9Vcd&<87{(??_Kb&YeTZ^ms$LG(=jub;gT1;i`-~O zo~&|48s_d@S!H+R#?KrNbrEeT3nto;`6f;c$mfXt|Dl}F=rIam$0loB}q4D8Fk zrupIR*^=Cu7`+VxDF>^8e*q}jS}xRt|CM)F|5q%u`tn*9s-hbnnI?N+UVtq9vReA+ zcaSU&Vb=|x#mnfvSOIi@zy!u@+fbBdz%Tx!$8(d=p`a}VQx|3w;gbmXE#Ohl>Zb5X z7`$x5R3teASdX_snLj6)me16!Aa&68k6J-+&e5#e#_JGF$l~bOc8ie%xk$kDqu&ZJ z5x^rl%KJyjL1<#j^B)1z?w=Ir4ZKIL^%JP|gZb(NJT3WBj9Ng4TY!c}KbVL`W-Z!& zH`@1MGPNAO7G4=y$*(+ledK;}s5ZV9i(i>sn*2^`x$j!vYHU}HtwXEOwa`1M@Akdj zx0>w#AUS+5Iego9d;Ipd?(AMoK3n7015A5rIsT9F?-g#p{NvYt_}Y)p{qWpR@;`n0 zrH6hd-f?B>Z>N6Q-h~p^gJ5{qV@z95O}iypLdLR0I@x%i@*cMrjFgqSuk75p_?0gW1T+P4U!NikSjx8NqN&PhT z$SwYZf$@6-((E`kfO$S-6{gdBgAPOsq2?79_TH zT*)uxzjJ2!+_iJ7iAQVxdRs@$F}cl#6nz~0zcB~|{G%^0x9$Z1^<|~w&^Y3r(}b>F zPzbng!YsIjnm|ulxPaO)V5Ct$CBalQ16)G%@)QHSV(QpC7!&1e!I@aJ1!`iEEqP2h zqk*$|mg=!kR1I)wsg7q((^1ZOGs8Zo;i7zcMmc-VO*1&9m=&VgE;#mchWBJOkKh%$ z1GHS83oyz=zo5zpZSV?s8)JL6C?V}5N8t=0ZeobGBq+fPL{0-S@X!S3GDL|EFhZn~ zE+y%OXewF3=AA&mDd=UAsVi?Ry|GeRZF~G~=yCey-e8Zw-9)2Luo*!EZUseA)^iY( z;pvW2cZ0l!)n~Cf?)WD;rC1cj8!Y)clp3flFWBx0y0J8_fbNiiDXTH==@p7_7(oygZ0j?dgo&wryl1c8`Eqb-}%4KCHN=o2nb5a!GABgMs@=&)0c>V zgzX3jjeP{s=vwht;1Ej_V6<@KBnz~ko}?WGx!4p1@fvH@KnHsL)L< zzY42-6;}BwtkM)qbMBiRG9VbuB?yJBkQ*+@H9PBQb|~<83M|@|I6>5j6S$TL;`le6 zR}x@{lVceZWzYmM0uRNzu*8cIcv1jedk#5`5xDVyMhjNvJ{32>xI|;5KS{*E-09BT z#C@Wv8ggfT8NRYqBXR(tR@E=_5FH3R(E$v&(*)6h;53qvq4=Xd-X_MwjE80P5Hg;)ch zT_-hz>@LsnZ$jJu2|bK|2ZCt3yMeafePi{JXDHCvksJPjM@SS0+emk zx0uysjzZZ)V=On!(1L zuktOWa1M}d%&yfJPG^w-WYJX61b3AKxEgFN7a{H@Aj|&|8W5-kz^4jJZ-Ek>s%V0Q zAzy@PLKy?>j#uJcC-7Ki!D*pEV(M-+dEVm?MIWP1kb5QPg@I;wJB{8Del7wifw-3n z*wg{Wc2K1-u=(=~^re;HqZuYfra?x&Ebbo_)Vy$-RFsS`X@6ufu*&(+;t4+`!H1JA zvuohoG2vL=nHokA1w_vR_Yy2b`XI5QOL04{;JfVGk?XsE(0`-<&hGaI-W#|dn|Q$U zfW`xe%Y&G6;+7PbC;kXr4st#lmy2(0hp&fR?E|BkSG_)Z&qVc&GZ3acgjU|3XVtSBlpjXfy_K|^Fi zRQuz9M$;@G{NG-lTQJm|p(!QZ`GbfI{Y-#X+KSuKoG7~ee52!I1*Yg<0Z`0n`wr+Q z*fYV9ofUc(0_5;_0?!c*SbV08IUwtPmSs2gFl_tJnccr&Vn1h6|HF)|wGUkF`JjE^ zUi-lH-qrR4iwQVBU-Y3b!S>cJZ!qw);fI2CEO@wQo@Jk?jb1spbnqd36=$D#c)>rz TvYl5CEg!pf>|+M=bSA$CZ1ZH& literal 0 HcmV?d00001 diff --git a/src/models/layers.py b/src/hirad/models/layers.py similarity index 100% rename from src/models/layers.py rename to src/hirad/models/layers.py diff --git a/src/models/preconditioning.py b/src/hirad/models/preconditioning.py similarity index 100% rename from src/models/preconditioning.py rename to src/hirad/models/preconditioning.py diff --git a/src/models/song_unet.py b/src/hirad/models/song_unet.py similarity index 100% rename from src/models/song_unet.py rename to src/hirad/models/song_unet.py diff --git a/src/models/unet.py b/src/hirad/models/unet.py similarity index 100% rename from src/models/unet.py rename to src/hirad/models/unet.py diff --git a/src/models/utils.py b/src/hirad/models/utils.py similarity index 100% rename from src/models/utils.py rename to src/hirad/models/utils.py diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/utils/__init__.py b/src/hirad/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/capture.py b/src/hirad/utils/capture.py similarity index 100% rename from src/utils/capture.py rename to src/hirad/utils/capture.py diff --git a/src/utils/checkpoint.py b/src/hirad/utils/checkpoint.py similarity index 100% rename from src/utils/checkpoint.py rename to src/hirad/utils/checkpoint.py diff --git a/src/utils/console.py b/src/hirad/utils/console.py similarity index 100% rename from src/utils/console.py rename to src/hirad/utils/console.py diff --git a/src/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py similarity index 100% rename from src/utils/deterministic_sampler.py rename to src/hirad/utils/deterministic_sampler.py diff --git a/src/utils/function_utils.py b/src/hirad/utils/function_utils.py similarity index 100% rename from src/utils/function_utils.py rename to src/hirad/utils/function_utils.py diff --git a/src/utils/inference_utils.py b/src/hirad/utils/inference_utils.py similarity index 100% rename from src/utils/inference_utils.py rename to src/hirad/utils/inference_utils.py diff --git a/src/utils/model_utils.py b/src/hirad/utils/model_utils.py similarity index 100% rename from src/utils/model_utils.py rename to src/hirad/utils/model_utils.py diff --git a/src/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py similarity index 100% rename from src/utils/stochastic_sampler.py rename to src/hirad/utils/stochastic_sampler.py diff --git a/src/utils/train_helpers.py b/src/hirad/utils/train_helpers.py similarity index 100% rename from src/utils/train_helpers.py rename to src/hirad/utils/train_helpers.py diff --git a/src/hirad_gen.egg-info/PKG-INFO b/src/hirad_gen.egg-info/PKG-INFO new file mode 100644 index 0000000..9c4c18f --- /dev/null +++ b/src/hirad_gen.egg-info/PKG-INFO @@ -0,0 +1,38 @@ +Metadata-Version: 2.4 +Name: hirad-gen +Version: 0.1.0 +Summary: High resolution atmospheric downscaling using generative machine learning +Author-email: Petar Stamenkovic +License: BSD 3-Clause License + + Copyright (c) 2025, MeteoSwiss + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Requires-Python: >=3.12 +Description-Content-Type: text/markdown +License-File: LICENSE +Dynamic: license-file diff --git a/src/hirad_gen.egg-info/SOURCES.txt b/src/hirad_gen.egg-info/SOURCES.txt new file mode 100644 index 0000000..1645d32 --- /dev/null +++ b/src/hirad_gen.egg-info/SOURCES.txt @@ -0,0 +1,28 @@ +LICENSE +README.md +pyproject.toml +src/distributed/__init__.py +src/distributed/config.py +src/distributed/manager.py +src/hirad_gen.egg-info/PKG-INFO +src/hirad_gen.egg-info/SOURCES.txt +src/hirad_gen.egg-info/dependency_links.txt +src/hirad_gen.egg-info/top_level.txt +src/losses/__init__.py +src/losses/loss.py +src/models/__init__.py +src/models/layers.py +src/models/preconditioning.py +src/models/song_unet.py +src/models/unet.py +src/models/utils.py +src/utils/__init__.py +src/utils/capture.py +src/utils/checkpoint.py +src/utils/console.py +src/utils/deterministic_sampler.py +src/utils/function_utils.py +src/utils/inference_utils.py +src/utils/model_utils.py +src/utils/stochastic_sampler.py +src/utils/train_helpers.py \ No newline at end of file diff --git a/src/hirad_gen.egg-info/dependency_links.txt b/src/hirad_gen.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/hirad_gen.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/hirad_gen.egg-info/top_level.txt b/src/hirad_gen.egg-info/top_level.txt new file mode 100644 index 0000000..b778691 --- /dev/null +++ b/src/hirad_gen.egg-info/top_level.txt @@ -0,0 +1,7 @@ +distributed +evaluation +losses +metrics +models +training +utils From 05e3a08a75e41c5c6a5c63422c1516dab6cdb091 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic <56728083+PetarStam@users.noreply.github.com> Date: Mon, 14 Apr 2025 17:15:40 +0200 Subject: [PATCH 007/189] Delete src/hirad_gen.egg-info directory --- src/hirad_gen.egg-info/PKG-INFO | 38 --------------------- src/hirad_gen.egg-info/SOURCES.txt | 28 --------------- src/hirad_gen.egg-info/dependency_links.txt | 1 - src/hirad_gen.egg-info/top_level.txt | 7 ---- 4 files changed, 74 deletions(-) delete mode 100644 src/hirad_gen.egg-info/PKG-INFO delete mode 100644 src/hirad_gen.egg-info/SOURCES.txt delete mode 100644 src/hirad_gen.egg-info/dependency_links.txt delete mode 100644 src/hirad_gen.egg-info/top_level.txt diff --git a/src/hirad_gen.egg-info/PKG-INFO b/src/hirad_gen.egg-info/PKG-INFO deleted file mode 100644 index 9c4c18f..0000000 --- a/src/hirad_gen.egg-info/PKG-INFO +++ /dev/null @@ -1,38 +0,0 @@ -Metadata-Version: 2.4 -Name: hirad-gen -Version: 0.1.0 -Summary: High resolution atmospheric downscaling using generative machine learning -Author-email: Petar Stamenkovic -License: BSD 3-Clause License - - Copyright (c) 2025, MeteoSwiss - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -Requires-Python: >=3.12 -Description-Content-Type: text/markdown -License-File: LICENSE -Dynamic: license-file diff --git a/src/hirad_gen.egg-info/SOURCES.txt b/src/hirad_gen.egg-info/SOURCES.txt deleted file mode 100644 index 1645d32..0000000 --- a/src/hirad_gen.egg-info/SOURCES.txt +++ /dev/null @@ -1,28 +0,0 @@ -LICENSE -README.md -pyproject.toml -src/distributed/__init__.py -src/distributed/config.py -src/distributed/manager.py -src/hirad_gen.egg-info/PKG-INFO -src/hirad_gen.egg-info/SOURCES.txt -src/hirad_gen.egg-info/dependency_links.txt -src/hirad_gen.egg-info/top_level.txt -src/losses/__init__.py -src/losses/loss.py -src/models/__init__.py -src/models/layers.py -src/models/preconditioning.py -src/models/song_unet.py -src/models/unet.py -src/models/utils.py -src/utils/__init__.py -src/utils/capture.py -src/utils/checkpoint.py -src/utils/console.py -src/utils/deterministic_sampler.py -src/utils/function_utils.py -src/utils/inference_utils.py -src/utils/model_utils.py -src/utils/stochastic_sampler.py -src/utils/train_helpers.py \ No newline at end of file diff --git a/src/hirad_gen.egg-info/dependency_links.txt b/src/hirad_gen.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/hirad_gen.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/hirad_gen.egg-info/top_level.txt b/src/hirad_gen.egg-info/top_level.txt deleted file mode 100644 index b778691..0000000 --- a/src/hirad_gen.egg-info/top_level.txt +++ /dev/null @@ -1,7 +0,0 @@ -distributed -evaluation -losses -metrics -models -training -utils From dbf101ef0c50e1dff50ba140b85124a6d407f965 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 14 Apr 2025 17:17:57 +0200 Subject: [PATCH 008/189] add gitignore --- .gitignore | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c514c5d --- /dev/null +++ b/.gitignore @@ -0,0 +1,171 @@ +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json \ No newline at end of file From 5de83bc5f7bb25f8bff91ffda69b2eedd97738fd Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 16 Apr 2025 11:57:45 +0200 Subject: [PATCH 009/189] add train script with missing parts --- .../{train_regression.yaml => training.yaml} | 0 src/hirad/distributed/manager.py | 6 +- src/hirad/losses/__init__.py | 1 + src/hirad/models/__init__.py | 1 + src/hirad/training/train.py | 462 ++++++++++++++++++ src/hirad/utils/capture.py | 2 +- 6 files changed, 468 insertions(+), 4 deletions(-) rename src/hirad/conf/{train_regression.yaml => training.yaml} (100%) diff --git a/src/hirad/conf/train_regression.yaml b/src/hirad/conf/training.yaml similarity index 100% rename from src/hirad/conf/train_regression.yaml rename to src/hirad/conf/training.yaml diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index facb466..e80ce13 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -25,7 +25,7 @@ import torch import torch.distributed as dist -from src.distributed.config import ProcessGroupConfig, ProcessGroupNode +from hirad.distributed.config import ProcessGroupConfig, ProcessGroupNode warnings.simplefilter("default", DeprecationWarning) @@ -393,7 +393,7 @@ def initialize(): else: os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" initialization_method = os.getenv( - "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD" + "DISTRIBUTED_INITIALIZATION_METHOD" ) if initialization_method is None: try: @@ -419,7 +419,7 @@ def initialize(): "Unknown initialization method " f"{initialization_method}. " "Supported values for " - "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD are " + "DISTRIBUTED_INITIALIZATION_METHOD are " "ENV, SLURM and OPENMPI" ) diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py index e69de29..185527b 100644 --- a/src/hirad/losses/__init__.py +++ b/src/hirad/losses/__init__.py @@ -0,0 +1 @@ +from .loss import ResLoss, RegressionLoss, RegressionLossCE \ No newline at end of file diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index 6b790ae..3b494c6 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,3 +1,4 @@ from .unet import UNet from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .preconditioning import EDMPrecondSR from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index e69de29..a47910b 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -0,0 +1,462 @@ +import os +import time + +import psutil +import hydra +from omegaconf import DictConfig, OmegaConf +import torch +from hydra.utils import to_absolute_path +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ + set_patch_shape, compute_num_accumulation_rounds, \ + is_time_for_periodic_task, handle_and_clip_gradients +from hirad.models import UNet, EDMPrecondSR +from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE + +@hydra.main(version_base=None, config_path="conf", config_name="training") +def main(cfg: DictConfig) -> None: + + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + if dist.rank==0: + writer = SummaryWriter(log_dir='tensorboard') + logger = PythonLogger("main") # general logger + logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger + + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if hasattr(cfg, "validation"): + train_test_split = True + validation_dataset_cfg = OmegaConf.to_container(cfg.validation) + else: + train_test_split = False + validation_dataset_cfg = None + fp_optimizations = cfg.training.perf.fp_optimizations + songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger.info(f"Saving the outputs in {os.getcwd()}") + checkpoint_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" + ) + if cfg.training.hp.batch_size_per_gpu == "auto": + cfg.training.hp.batch_size_per_gpu = ( + cfg.training.hp.total_batch_size // dist.world_size + ) + + set_seed(dist.rank) + configure_cuda_for_consistent_precision() + + ### Write our own dataloader ### + ( + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator + ) = None, None, None, None + + dataset_channels = None #len(dataset.input_channels()) + img_in_channels = None #dataset_channels + img_shape = None #dataset.image_shape() + img_out_channels = None #len(dataset.output_channels()) + + prob_channels = None + + # Parse the patch shape + if ( + cfg.model.name == "patched_diffusion" + or cfg.model.name == "lt_aware_patched_diffusion" + ): + patch_shape_x = cfg.training.hp.patch_shape_x + patch_shape_y = cfg.training.hp.patch_shape_y + else: + patch_shape_x = None + patch_shape_y = None + patch_shape = (patch_shape_y, patch_shape_x) + img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if patch_shape != img_shape: + logger0.info("Patch-based training enabled") + else: + logger0.info("Patch-based training disabled") + # interpolate global channel if patch-based model is used + if img_shape[1] != patch_shape[1]: + img_in_channels += dataset_channels + + + # Instantiate the model and move to device. + if cfg.model.name not in ( + "regression", + "lt_aware_ce_regression", + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + raise ValueError("Invalid model") + model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(img_shape), + "use_fp16": fp16, + } + standard_model_cfgs = { # default parameters for different network types + "regression": { + "img_channels": 4, + "N_grid_channels": 4, + "embedding_type": "zero", + "checkpoint_level": songunet_checkpoint_level, + }, + "lt_aware_ce_regression": { + "img_channels": 4, + "N_grid_channels": 4, + "embedding_type": "zero", + "lead_time_channels": 4, + "lead_time_steps": 9, + "prob_channels": prob_channels, + "checkpoint_level": songunet_checkpoint_level, + "model_type": "SongUNetPosLtEmbd", + }, + "diffusion": { + "img_channels": img_out_channels, + "gridtype": "sinusoidal", + "N_grid_channels": 4, + "checkpoint_level": songunet_checkpoint_level, + }, + "patched_diffusion": { + "img_channels": img_out_channels, + "gridtype": "learnable", + "N_grid_channels": 100, + "checkpoint_level": songunet_checkpoint_level, + }, + "lt_aware_patched_diffusion": { + "img_channels": img_out_channels, + "gridtype": "learnable", + "N_grid_channels": 100, + "lead_time_channels": 20, + "lead_time_steps": 9, + "checkpoint_level": songunet_checkpoint_level, + "model_type": "SongUNetPosLtEmbd", + }, + } + + + model_args.update(standard_model_cfgs[cfg.model.name]) + if cfg.model.name in ( + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + model_args["scale_cond_input"] = cfg.model.scale_cond_input + if hasattr(cfg.model, "model_args"): # override defaults from config file + model_args.update(OmegaConf.to_container(cfg.model.model_args)) + if cfg.model.name == "regression": + model = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + elif cfg.model.name == "lt_aware_ce_regression": + model = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + elif cfg.model.name == "lt_aware_patched_diffusion": + model = EDMPrecondSR( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + else: # diffusion or patched diffusion + model = EDMPrecondSR( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + + model.train().requires_grad_(True).to(dist.device) + + # Enable distributed data parallel if applicable + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=dist.find_unused_parameters, + ) + + # Load the regression checkpoint if applicable + if hasattr(cfg.training.io, "regression_checkpoint_path"): + regression_checkpoint_path = to_absolute_path( + cfg.training.io.regression_checkpoint_path + ) + if not os.path.exists(regression_checkpoint_path): + raise FileNotFoundError( + f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" + ) + regression_net = torch.nn.Module() #Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + regression_net.eval().requires_grad_(False).to(dist.device) + logger0.success("Loaded the pre-trained regression model") + + # Instantiate the loss function + patch_num = getattr(cfg.training.hp, "patch_num", 1) + if cfg.model.name in ( + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + loss_fn = ResLoss( + regression_net=regression_net, + img_shape_x=img_shape[1], + img_shape_y=img_shape[0], + patch_shape_x=patch_shape[1], + patch_shape_y=patch_shape[0], + patch_num=patch_num, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + ) + elif cfg.model.name == "regression": + loss_fn = RegressionLoss() + elif cfg.model.name == "lt_aware_ce_regression": + loss_fn = RegressionLossCE(prob_channels=prob_channels) + + # Instantiate the optimizer + optimizer = torch.optim.Adam( + params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8 + ) + + # Record the current time to measure the duration of subsequent operations. + start_time = time.time() + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + + ## Resume training from previous checkpoints if exists + if dist.world_size > 1: + torch.distributed.barrier() + try: + cur_nimg = 0 + # Fix loading and saving checkpoint + #load_checkpoint( + # path=checkpoint_dir, + # models=model, + # optimizer=optimizer, + # device=dist.device, + # ) + except: + cur_nimg = 0 + + ############################################################################ + # MAIN TRAINING LOOP # + ############################################################################ + + logger0.info(f"Training for {cfg.training.hp.training_duration} images...") + done = False + + # init variables to monitor running mean of average loss since last periodic + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for _ in range(num_accumulation_rounds): + img_clean, img_lr, labels, *lead_time_label = next(dataset_iterator) # what are labels and lead_time_label + img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() + img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() + labels = labels.to(dist.device).contiguous() + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "labels": labels, + "augment_pipe": None, + } + if lead_time_label: + lead_time_label = lead_time_label[0].to(dist.device).contiguous() + loss_fn_kwargs.update({"lead_time_label": lead_time_label}) + else: + lead_time_label = None + with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): + loss = loss_fn(**loss_fn_kwargs) + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + loss.backward() + + + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 + + if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) + writer.add_scalar( + "training_loss_running_mean", average_loss_running_mean, cur_nimg + ) + + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + img_clean_valid, img_lr_valid, labels_valid = next( + validation_dataset_iterator + ) + + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(torch.float32) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device).to(torch.float32).contiguous() + ) + labels_valid = labels_valid.to(dist.device).contiguous() + loss_valid = loss_fn( + net=model, + img_clean=img_clean_valid, + img_lr=img_lr_valid, + labels=labels_valid, + augment_pipe=None, + ) + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu).cpu().item() + ) + valid_loss_accum += ( + loss_valid / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + logger0.info(" ".join(fields)) + torch.cuda.reset_peak_memory_stats() + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # figure out how to do save and load checkpoint + #save_checkpoint( + # path=checkpoint_dir, + # models=model, + # optimizer=optimizer, + # epoch=cur_nimg, + # ) + pass + + # Done. + logger0.info("Training Completed.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/utils/capture.py b/src/hirad/utils/capture.py index 50057f9..9c38d5a 100644 --- a/src/hirad/utils/capture.py +++ b/src/hirad/utils/capture.py @@ -24,7 +24,7 @@ import torch -from src.distributed import DistributedManager +from hirad.distributed import DistributedManager float16 = NewType("float16", torch.float16) bfloat16 = NewType("bfloat16", torch.bfloat16) From 134ed835f7cfde077dddb5d8eeee2313cf994611 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 16 Apr 2025 14:33:48 +0200 Subject: [PATCH 010/189] add abstract dataset class --- src/hirad/datasets/base.py | 85 ++++++++++++++++++++++++++ src/hirad/datasets/dataset.py | 111 ++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 src/hirad/datasets/base.py create mode 100644 src/hirad/datasets/dataset.py diff --git a/src/hirad/datasets/base.py b/src/hirad/datasets/base.py new file mode 100644 index 0000000..22b00d2 --- /dev/null +++ b/src/hirad/datasets/base.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Tuple + +import numpy as np +import torch + + +@dataclass +class ChannelMetadata: + """Metadata describing a data channel.""" + + name: str + level: str = "" + auxiliary: bool = False + + +class DownscalingDataset(torch.utils.data.Dataset, ABC): + """An abstract class that defines the interface for downscaling datasets.""" + + @abstractmethod + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + pass + + @abstractmethod + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + pass + + @abstractmethod + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def time(self) -> List: + """Get time values from the dataset.""" + pass + + @abstractmethod + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + pass + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return x + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return x + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x + + def info(self) -> dict: + """Get information about the dataset.""" + return {} diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py new file mode 100644 index 0000000..2a26630 --- /dev/null +++ b/src/hirad/datasets/dataset.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Tuple, Union +import copy +import torch + +from hirad.utils.function_utils import InfiniteSampler +from hirad.distributed import DistributedManager + +from . import base, cwb, hrrrmini, gefs_hrrr + + +# this maps all known dataset types to the corresponding init function +known_datasets = { + "cwb": cwb.get_zarr_dataset, + "hrrr_mini": hrrrmini.HRRRMiniDataset, + "gefs_hrrr": gefs_hrrr.HrrrForecastGEFSDataset, +} + + +def init_train_valid_datasets_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, + validation_dataset_cfg: Union[dict, None] = None, + train_test_split: bool = True, +) -> Tuple[ + base.DownscalingDataset, + Iterable, + Union[base.DownscalingDataset, None], + Union[Iterable, None], +]: + """ + A wrapper function for managing the train-test split for the CWB dataset. + + Parameters: + - dataset_cfg (dict): Configuration for the dataset. + - dataloader_cfg (dict, optional): Configuration for the dataloader. Defaults to None. + - batch_size (int): The number of samples in each batch of data. Defaults to 1. + - seed (int): The random seed for dataset shuffling. Defaults to 0. + - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True. + + Returns: + - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True. + """ + + config = copy.deepcopy(dataset_cfg) + (dataset, dataset_iter) = init_dataset_from_config( + config, dataloader_cfg, batch_size=batch_size, seed=seed + ) + if train_test_split: + valid_dataset_cfg = copy.deepcopy(config) + if validation_dataset_cfg: + valid_dataset_cfg.update(validation_dataset_cfg) + (valid_dataset, valid_dataset_iter) = init_dataset_from_config( + valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed + ) + else: + valid_dataset = valid_dataset_iter = None + + return dataset, dataset_iter, valid_dataset, valid_dataset_iter + + +def init_dataset_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, +) -> Tuple[base.DownscalingDataset, Iterable]: + dataset_cfg = copy.deepcopy(dataset_cfg) + dataset_type = dataset_cfg.pop("type", "cwb") + if "train_test_split" in dataset_cfg: + # handled by init_train_valid_datasets_from_config + del dataset_cfg["train_test_split"] + dataset_init_func = known_datasets[dataset_type] + + dataset_obj = dataset_init_func(**dataset_cfg) + if dataloader_cfg is None: + dataloader_cfg = {} + + dist = DistributedManager() + dataset_sampler = InfiniteSampler( + dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed + ) + + dataset_iterator = iter( + torch.utils.data.DataLoader( + dataset=dataset_obj, + sampler=dataset_sampler, + batch_size=batch_size, + worker_init_fn=None, + **dataloader_cfg, + ) + ) + + return (dataset_obj, dataset_iterator) From de4e7c55374b45d637754f5bf8be4f6726e762c7 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 23 Apr 2025 16:09:28 +0200 Subject: [PATCH 011/189] adapt checkpoint saving and loading --- src/hirad/utils/checkpoint.py | 51 ++++++++++------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 8ec70fa..5ce194a 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -23,11 +23,11 @@ from torch.cuda.amp import GradScaler from torch.optim.lr_scheduler import _LRScheduler -from src.distributed import DistributedManager -from src.utils.console import PythonLogger -from src.utils.capture import _StaticCapture +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger +from hirad.utils.capture import _StaticCapture -optimizer = NewType("optimizer", torch.optim) +optimizer = NewType("optimizer", torch.optim.Optimizer) scheduler = NewType("scheduler", _LRScheduler) scaler = NewType("scaler", GradScaler) @@ -39,7 +39,7 @@ def _get_checkpoint_filename( base_name: str = "checkpoint", index: Union[int, None] = None, saving: bool = False, - model_type: str = "mdlus", + model_type: str = "pt", ) -> str: """Gets the file name /path of checkpoint @@ -91,7 +91,7 @@ def _get_checkpoint_filename( ) # File extension for PhysicsNeMo models or PyTorch models - file_extension = ".mdlus" if model_type == "mdlus" else ".pt" + file_extension = "."+model_type # If epoch is provided load that file if index is not None: @@ -157,8 +157,6 @@ def _unique_model_names( model0 = model0.module # Base name of model is meta.name unless pytorch model base_name = model0.__class__.__name__ - if isinstance(model0, physicsnemo.models.Module): - base_name = model0.meta.name # If we have multiple models of the same name, introduce another index if base_name in model_dict: model_dict[base_name].append(model0) @@ -189,8 +187,8 @@ def save_checkpoint( """Training checkpoint saving utility This will save a training checkpoint in the provided path following the file naming - convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint - method in PhysicsNeMo core can then be used to read this file. + convention "checkpoint.{model parallel id}.{epoch/index}.pt". The load checkpoint + method can then be used to read this file. Parameters ---------- @@ -224,21 +222,13 @@ def save_checkpoint( models = [models] models = _unique_model_names(models) for name, model in models.items(): - # Get model type - model_type = ( - "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" - ) - # Get full file path / name file_name = _get_checkpoint_filename( - path, name, index=epoch, saving=True, model_type=model_type + path, name, index=epoch, saving=True, model_type="pt" ) # Save state dictionary - if isinstance(model, physicsnemo.models.Module): - model.save(file_name) - else: - torch.save(model.state_dict(), file_name) + torch.save(model.state_dict(), file_name) checkpoint_logging.success(f"Saved model state dictionary: {file_name}") # == Saving training checkpoint == @@ -251,12 +241,9 @@ def save_checkpoint( if scheduler: checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() - # Scheduler state dict + # Scaler state dict if scaler: checkpoint_dict["scaler_state_dict"] = scaler.state_dict() - # Static capture is being used, save its grad scaler - if _StaticCapture._amp_scalers: - checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() # Output file name output_filename = _get_checkpoint_filename( @@ -288,8 +275,7 @@ def load_checkpoint( ) -> int: """Checkpoint loading utility - This loader is designed to be used with the save checkpoint utility in PhysicsNeMo - Launch. Given a path, this method will try to find a checkpoint and load state + This loader is designed to be used with the save checkpoint utility. Given a path, this method will try to find a checkpoint and load state dictionaries into the provided training objects. Parameters @@ -331,9 +317,7 @@ def load_checkpoint( models = _unique_model_names(models) for name, model in models.items(): # Get model type - model_type = ( - "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" - ) + model_type = "pt" # Get full file path / name file_name = _get_checkpoint_filename( @@ -345,10 +329,7 @@ def load_checkpoint( ) continue # Load state dictionary - if isinstance(model, physicsnemo.models.Module): - model.load(file_name) - else: - model.load_state_dict(torch.load(file_name, map_location=device)) + model.load_state_dict(torch.load(file_name, map_location=device)) checkpoint_logging.success( f"Loaded model state dictionary {file_name} to device {device}" @@ -382,10 +363,6 @@ def load_checkpoint( scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) checkpoint_logging.success("Loaded grad scaler state dictionary") - if "static_capture_state_dict" in checkpoint_dict: - _StaticCapture.load_state_dict(checkpoint_dict["static_capture_state_dict"]) - checkpoint_logging.success("Loaded static capture state dictionary") - epoch = 0 if "epoch" in checkpoint_dict: epoch = checkpoint_dict["epoch"] From cff9da486d60d77764e2996632ba12d4642fc47f Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:51:17 +0200 Subject: [PATCH 012/189] adapt checkpoint loading and saving --- src/hirad/models/__init__.py | 5 +- src/hirad/training/train.py | 105 ++++++++++++++++++++++--------- src/hirad/utils/checkpoint.py | 115 +++++++++++----------------------- 3 files changed, 115 insertions(+), 110 deletions(-) diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index 3b494c6..f17e5ce 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,4 +1,5 @@ from .unet import UNet from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd -from .preconditioning import EDMPrecondSR -from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding \ No newline at end of file +from .preconditioning import EDMPrecondSR, EDMPrecond +from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding +from .meta import ModelMetaData \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index a47910b..d6fe563 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -4,6 +4,7 @@ import psutil import hydra from omegaconf import DictConfig, OmegaConf +import json import torch from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter @@ -14,8 +15,10 @@ from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ set_patch_shape, compute_num_accumulation_rounds, \ is_time_for_periodic_task, handle_and_clip_gradients +from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.models import UNet, EDMPrecondSR from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE +from hirad.datasets import init_train_valid_datasets_from_config @hydra.main(version_base=None, config_path="conf", config_name="training") def main(cfg: DictConfig) -> None: @@ -54,20 +57,38 @@ def main(cfg: DictConfig) -> None: set_seed(dist.rank) configure_cuda_for_consistent_precision() - ### Write our own dataloader ### + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.training.perf.dataloader_workers, + "prefetch_factor": 2, + } ( - dataset, - dataset_iterator, - validation_dataset, - validation_dataset_iterator - ) = None, None, None, None + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator, + ) = init_train_valid_datasets_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.training.hp.batch_size_per_gpu, + seed=0, + validation_dataset_cfg=validation_dataset_cfg, + train_test_split=train_test_split, + ) - dataset_channels = None #len(dataset.input_channels()) - img_in_channels = None #dataset_channels - img_shape = None #dataset.image_shape() - img_out_channels = None #len(dataset.output_channels()) + # Parse image configuration & update model args + dataset_channels = len(dataset.input_channels()) + img_in_channels = dataset_channels + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + if cfg.model.hr_mean_conditioning: + img_in_channels += img_out_channels - prob_channels = None + if cfg.model.name == "lt_aware_ce_regression": + prob_channels = dataset.get_prob_channel_index() + else: + prob_channels = None # Parse the patch shape if ( @@ -181,6 +202,10 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) + if not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): + with open(os.path.join(checkpoint_dir, 'model_args.json'), 'w') as f: + json.dump(model_args, f) + # Enable distributed data parallel if applicable if dist.world_size > 1: model = DistributedDataParallel( @@ -196,11 +221,38 @@ def main(cfg: DictConfig) -> None: regression_checkpoint_path = to_absolute_path( cfg.training.io.regression_checkpoint_path ) - if not os.path.exists(regression_checkpoint_path): + if not os.path.isdir(regression_checkpoint_path): raise FileNotFoundError( f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" ) - regression_net = torch.nn.Module() #Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) + #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder + regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + if cfg.model.name == "lt_aware_patched_diffusion": + regression_net = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **regression_model_args, + ) + else: + regression_net = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **regression_model_args, + ) + + _ = load_checkpoint( + path=regression_checkpoint_path, + model=regression_net, + device=dist.device + ) regression_net.eval().requires_grad_(False).to(dist.device) logger0.success("Loaded the pre-trained regression model") @@ -248,14 +300,12 @@ def main(cfg: DictConfig) -> None: if dist.world_size > 1: torch.distributed.barrier() try: - cur_nimg = 0 - # Fix loading and saving checkpoint - #load_checkpoint( - # path=checkpoint_dir, - # models=model, - # optimizer=optimizer, - # device=dist.device, - # ) + cur_nimg = load_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + device=dist.device, + ) except: cur_nimg = 0 @@ -445,13 +495,12 @@ def main(cfg: DictConfig) -> None: dist.rank, rank_0_only=True, ): - # figure out how to do save and load checkpoint - #save_checkpoint( - # path=checkpoint_dir, - # models=model, - # optimizer=optimizer, - # epoch=cur_nimg, - # ) + save_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + epoch=cur_nimg, + ) pass # Done. diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 5ce194a..03b423d 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -17,7 +17,7 @@ import glob import re from pathlib import Path -from typing import Any, Dict, List, NewType, Optional, Union +from typing import Any, Dict, List, NewType, Optional, Union, Tuple import torch from torch.cuda.amp import GradScaler @@ -25,7 +25,6 @@ from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger -from hirad.utils.capture import _StaticCapture optimizer = NewType("optimizer", torch.optim.Optimizer) scheduler = NewType("scheduler", _LRScheduler) @@ -133,51 +132,9 @@ def _get_checkpoint_filename( return checkpoint_filename -def _unique_model_names( - models: List[torch.nn.Module], -) -> Dict[str, torch.nn.Module]: - """Util to clean model names and index if repeat names, will also strip DDP wrappers - if they exist. - - Parameters - ---------- - model : List[torch.nn.Module] - List of models to generate names for - - Returns - ------- - Dict[str, torch.nn.Module] - Dictionary of model names and respective modules - """ - # Loop through provided models and set up base names - model_dict = {} - for model0 in models: - if hasattr(model0, "module"): - # Strip out DDP layer - model0 = model0.module - # Base name of model is meta.name unless pytorch model - base_name = model0.__class__.__name__ - # If we have multiple models of the same name, introduce another index - if base_name in model_dict: - model_dict[base_name].append(model0) - else: - model_dict[base_name] = [model0] - - # Set up unique model names if needed - output_dict = {} - for key, model in model_dict.items(): - if len(model) > 1: - for i, model0 in enumerate(model): - output_dict[key + str(i)] = model0 - else: - output_dict[key] = model[0] - - return output_dict - - def save_checkpoint( path: str, - models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + model: Union[torch.nn.Module, None] = None, optimizer: Union[optimizer, None] = None, scheduler: Union[scheduler, None] = None, scaler: Union[scaler, None] = None, @@ -217,19 +174,20 @@ def save_checkpoint( Path(path).mkdir(parents=True, exist_ok=True) # == Saving model checkpoint == - if models: - if not isinstance(models, list): - models = [models] - models = _unique_model_names(models) - for name, model in models.items(): - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, saving=True, model_type="pt" - ) + if model: + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Base name of model is meta.name unless pytorch model + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type="pt" + ) - # Save state dictionary - torch.save(model.state_dict(), file_name) - checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + # Save state dictionary + torch.save(model.state_dict(), file_name) + checkpoint_logging.success(f"Saved model state dictionary: {file_name}") # == Saving training checkpoint == checkpoint_dict = {} @@ -265,7 +223,7 @@ def save_checkpoint( def load_checkpoint( path: str, - models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + model: torch.nn.Module, optimizer: Union[optimizer, None] = None, scheduler: Union[scheduler, None] = None, scaler: Union[scaler, None] = None, @@ -311,29 +269,26 @@ def load_checkpoint( return 0 # == Loading model checkpoint == - if models: - if not isinstance(models, list): - models = [models] - models = _unique_model_names(models) - for name, model in models.items(): - # Get model type - model_type = "pt" - - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, model_type=model_type - ) - if not Path(file_name).exists(): - checkpoint_logging.error( - f"Could not find valid model file {file_name}, skipping load" - ) - continue - # Load state dictionary - model.load_state_dict(torch.load(file_name, map_location=device)) + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Base name of model is meta.name unless pytorch model + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, + ) + if not Path(file_name).exists(): + checkpoint_logging.error( + f"Could not find valid model file {file_name}, skipping load" + ) + else: + # Load state dictionary + model.load_state_dict(torch.load(file_name, map_location=device)) - checkpoint_logging.success( - f"Loaded model state dictionary {file_name} to device {device}" - ) + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) # == Loading training checkpoint == checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") From 1029d4c7bb16f105e34181810a959797a1f32c8e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:52:27 +0200 Subject: [PATCH 013/189] fix model imports and dependency on module metadata --- src/hirad/models/meta.py | 50 +++++++++++++++++++++++++++++ src/hirad/models/preconditioning.py | 14 ++++---- src/hirad/models/song_unet.py | 6 ++-- src/hirad/models/unet.py | 6 ++-- 4 files changed, 63 insertions(+), 13 deletions(-) create mode 100644 src/hirad/models/meta.py diff --git a/src/hirad/models/meta.py b/src/hirad/models/meta.py new file mode 100644 index 0000000..aab8e45 --- /dev/null +++ b/src/hirad/models/meta.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class ModelMetaData: + """Data class for storing essential meta data needed for all Hirad Models""" + + # Model info + name: str = "HiradModule" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp: bool = False + amp_cpu: bool = None + amp_gpu: bool = None + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + onnx_gpu: bool = None + onnx_cpu: bool = None + onnx_runtime: bool = False + trt: bool = False + # Physics informed + var_dim: int = -1 + func_torch: bool = False + auto_grad: bool = False + + def __post_init__(self): + self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu + self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu + self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu + self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index 7b621e2..b0e924e 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -29,11 +29,11 @@ import torch import torch.nn as nn -from src.models import ( +from hirad.models import ( DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) -from physicsnemo.models.meta import ModelMetaData +from hirad.models import ModelMetaData network_module = importlib.import_module("physicsnemo.models.diffusion") @@ -105,7 +105,7 @@ def __init__( model_type: str = "SongUNet", **model_kwargs: dict, ): - super().__init__(meta=VPPrecondMetaData) + super().__init__() #meta=VPPrecondMetaData self.img_resolution = img_resolution self.img_channels = img_channels self.label_dim = label_dim @@ -282,7 +282,7 @@ def __init__( model_type: str = "SongUNet", **model_kwargs: dict, ): - super().__init__(meta=VEPrecondMetaData) + super().__init__() #meta=VEPrecondMetaData self.img_resolution = img_resolution self.img_channels = img_channels self.label_dim = label_dim @@ -414,7 +414,7 @@ def __init__( model_type="DhariwalUNet", **model_kwargs, ): - super().__init__(meta=iDDPMPrecondMetaData) + super().__init__() #meta=iDDPMPrecondMetaData self.img_resolution = img_resolution self.img_channels = img_channels self.label_dim = label_dim @@ -601,7 +601,7 @@ def __init__( img_out_channels=None, **model_kwargs, ): - super().__init__(meta=EDMPrecondMetaData) + super().__init__() #meta=EDMPrecondMetaData self.img_resolution = img_resolution if img_in_channels is not None: img_in_channels = img_in_channels @@ -767,7 +767,7 @@ def __init__( scale_cond_input=True, **model_kwargs, ): - super().__init__(meta=EDMPrecondSRMetaData) + super().__init__() #meta=EDMPrecondSRMetaData self.img_resolution = img_resolution self.img_channels = img_channels # TODO: this is not used, remove it self.img_in_channels = img_in_channels diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index 68adbda..5bfca8a 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -29,7 +29,7 @@ from torch.utils.checkpoint import checkpoint import torch.nn as nn -from src.models import ( +from hirad.models import ( Conv2d, FourierEmbedding, GroupNorm, @@ -37,7 +37,7 @@ PositionalEmbedding, UNetBlock, ) -from physicsnemo.models.meta import ModelMetaData +from hirad.models import ModelMetaData @dataclass @@ -175,7 +175,7 @@ def __init__( f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." ) - super().__init__(meta=MetaData()) + super().__init__() #meta=MetaData() self.label_dropout = label_dropout self.embedding_type = embedding_type emb_channels = model_channels * channel_mult_emb diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index db8e4f8..1333bda 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from physicsnemo.models.meta import ModelMetaData +from hirad.models import ModelMetaData network_module = importlib.import_module("src.models") @@ -92,7 +92,7 @@ def __init__( model_type="SongUNetPosEmbd", **model_kwargs, ): - super().__init__(meta=MetaData) + super().__init__() #meta=MetaData self.img_channels = img_channels @@ -207,7 +207,7 @@ def __init__( model_type="SongUNet", **model_kwargs, ): - super().__init__(meta=MetaData("StormCastUNet")) + super().__init__() #meta=MetaData("StormCastUNet") if isinstance(img_resolution, int): self.img_shape_x = self.img_shape_y = img_resolution From a1daebec9efd89f08604af2323d92eb39b04047d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:53:10 +0200 Subject: [PATCH 014/189] add generate utils --- src/hirad/utils/deterministic_sampler.py | 2 +- src/hirad/utils/generate_utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 src/hirad/utils/generate_utils.py diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py index 4b2f32b..9fcea1d 100644 --- a/src/hirad/utils/deterministic_sampler.py +++ b/src/hirad/utils/deterministic_sampler.py @@ -19,7 +19,7 @@ import nvtx import torch -from src.models import EDMPrecond +from hirad.models import EDMPrecond # ruff: noqa: E731 diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py new file mode 100644 index 0000000..29f7eb4 --- /dev/null +++ b/src/hirad/utils/generate_utils.py @@ -0,0 +1,24 @@ +import datetime +from hirad.datasets import init_dataset_from_config +from hirad.utils.function_utils import convert_datetime_to_cftime + + +def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): + """ + Get a dataset and sampler for generation. + """ + (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) + if has_lead_time: + plot_times = times + else: + plot_times = [ + convert_datetime_to_cftime( + datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + ) + for time in times + ] + all_times = dataset.time() + time_indices = [all_times.index(t) for t in plot_times] + sampler = time_indices + + return dataset, sampler \ No newline at end of file From 81647669d5cc090d59eae5611871b0b645b11ec0 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:53:40 +0200 Subject: [PATCH 015/189] add sceleton for era5-cosmo dataset --- src/hirad/datasets/__init__.py | 3 +++ src/hirad/datasets/dataset.py | 6 ++---- src/hirad/datasets/era5_cosmo.py | 5 +++++ 3 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 src/hirad/datasets/__init__.py create mode 100644 src/hirad/datasets/era5_cosmo.py diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py new file mode 100644 index 0000000..706284e --- /dev/null +++ b/src/hirad/datasets/__init__.py @@ -0,0 +1,3 @@ +from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config +from .era5_cosmo import ERA5_COSMO +from .base import DownscalingDataset \ No newline at end of file diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 2a26630..1928e4d 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -21,14 +21,12 @@ from hirad.utils.function_utils import InfiniteSampler from hirad.distributed import DistributedManager -from . import base, cwb, hrrrmini, gefs_hrrr +from hirad.datasets import ERA5_COSMO # this maps all known dataset types to the corresponding init function known_datasets = { - "cwb": cwb.get_zarr_dataset, - "hrrr_mini": hrrrmini.HRRRMiniDataset, - "gefs_hrrr": gefs_hrrr.HrrrForecastGEFSDataset, + "era5_cosmo": ERA5_COSMO, } diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py new file mode 100644 index 0000000..2ec94be --- /dev/null +++ b/src/hirad/datasets/era5_cosmo.py @@ -0,0 +1,5 @@ +from hirad.datasets.base import DownscalingDataset, ChannelMetadata + +class ERA5_COSMO(DownscalingDataset): + def __init__(self): + super().__init__() \ No newline at end of file From a9f70349f5b07c4c436da53fa239e1dbb5295e59 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 17:25:02 +0200 Subject: [PATCH 016/189] add in_channels to arg saving list --- src/hirad/training/train.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index d6fe563..0539aad 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -180,6 +180,7 @@ def main(cfg: DictConfig) -> None: img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] elif cfg.model.name == "lt_aware_ce_regression": model = UNet( img_in_channels=img_in_channels @@ -187,6 +188,7 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] elif cfg.model.name == "lt_aware_patched_diffusion": model = EDMPrecondSR( img_in_channels=img_in_channels @@ -194,15 +196,17 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] else: # diffusion or patched diffusion model = EDMPrecondSR( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) - + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model.train().requires_grad_(True).to(dist.device) - if not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): + if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): with open(os.path.join(checkpoint_dir, 'model_args.json'), 'w') as f: json.dump(model_args, f) @@ -235,18 +239,7 @@ def main(cfg: DictConfig) -> None: with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) - if cfg.model.name == "lt_aware_patched_diffusion": - regression_net = UNet( - img_in_channels=img_in_channels - + model_args["N_grid_channels"] - + model_args["lead_time_channels"], - **regression_model_args, - ) - else: - regression_net = UNet( - img_in_channels=img_in_channels + model_args["N_grid_channels"], - **regression_model_args, - ) + regression_net = UNet(**regression_model_args) _ = load_checkpoint( path=regression_checkpoint_path, From e0059c83ab8ad7f30d8c10ecc46b744811fe10c8 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Apr 2025 13:47:39 +0200 Subject: [PATCH 017/189] add dataset era5_cosmo --- src/hirad/datasets/dataset.py | 9 +-- src/hirad/datasets/era5_cosmo.py | 95 +++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 1928e4d..6e402d9 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -21,7 +21,8 @@ from hirad.utils.function_utils import InfiniteSampler from hirad.distributed import DistributedManager -from hirad.datasets import ERA5_COSMO +from .era5_cosmo import ERA5_COSMO +from .base import DownscalingDataset # this maps all known dataset types to the corresponding init function @@ -38,9 +39,9 @@ def init_train_valid_datasets_from_config( validation_dataset_cfg: Union[dict, None] = None, train_test_split: bool = True, ) -> Tuple[ - base.DownscalingDataset, + DownscalingDataset, Iterable, - Union[base.DownscalingDataset, None], + Union[DownscalingDataset, None], Union[Iterable, None], ]: """ @@ -79,7 +80,7 @@ def init_dataset_from_config( dataloader_cfg: Union[dict, None] = None, batch_size: int = 1, seed: int = 0, -) -> Tuple[base.DownscalingDataset, Iterable]: +) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) dataset_type = dataset_cfg.pop("type", "cwb") if "train_test_split" in dataset_cfg: diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 2ec94be..597bdfc 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -1,5 +1,94 @@ -from hirad.datasets.base import DownscalingDataset, ChannelMetadata +from .base import DownscalingDataset, ChannelMetadata +import os +import numpy as np +import torch +from typing import List, Tuple +import yaml class ERA5_COSMO(DownscalingDataset): - def __init__(self): - super().__init__() \ No newline at end of file + def __init__(self, dataset_path: str): + super().__init__() + + #TODO switch hanbdling paths to Path rather than pure strings + self._dataset_path = dataset_path + self._era5_path = os.path.join(dataset_path, 'era-interpolated') + self._cosmo_path = os.path.join(dataset_path, 'cosmo') + self._info_path = os.path.join(dataset_path, 'info') + + # load file list (each file is one date-time state) + self._file_list = os.listdir(self._cosmo_path) + + # Load cosmo info and channel names + with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file: + self._cosmo_info = yaml.safe_load(file) + self._cosmo_channels = [ChannelMetadata(name) for name in self._cosmo_info['select']] + + # Load era5 info and channel names + with open(os.path.join(self._info_path,'era.yaml'), 'r') as file: + self._era_info = yaml.safe_load(file) + self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1 + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._era_info['select']] + + # Load stats for normalizing channels of input and output + + cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False) + print(cosmo_stats) + + + def __len__(self): + return len(self._file_list) + + + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lon_lat[:,0] + + + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lon_lat[:,1] + + + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + return self._era_channels + + + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" + return self._cosmo_channels + + + def time(self) -> List: + """Get time values from the dataset.""" + #TODO Choose the time format and convert to that, currently it's a string from a filename + return [file.split('.')[0] for file in self._file_list] + + + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + #TODO load from info, I hardcode it for now + return 390,582 + + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return (x - self.input_mean) / self.input_std + + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x * self.input_std + self.input_mean + + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return (x - self.output_mean) / self.output_std + + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x * self.output_std + self.output_mean \ No newline at end of file From da4cb6ccc9abc901128ea1db2d401387ff0374e1 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Apr 2025 13:48:06 +0200 Subject: [PATCH 018/189] fix imports --- src/hirad/distributed/manager.py | 2 +- src/hirad/inference/generate.py | 159 ++++++++++++++++++++++++++++ src/hirad/models/layers.py | 2 +- src/hirad/models/preconditioning.py | 4 +- src/hirad/models/song_unet.py | 4 +- src/hirad/models/unet.py | 2 +- src/hirad/utils/checkpoint.py | 2 +- src/hirad/utils/generate_utils.py | 2 +- src/hirad/utils/inference_utils.py | 2 +- 9 files changed, 169 insertions(+), 10 deletions(-) create mode 100644 src/hirad/inference/generate.py diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index e80ce13..647d054 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -25,7 +25,7 @@ import torch import torch.distributed as dist -from hirad.distributed.config import ProcessGroupConfig, ProcessGroupNode +from .config import ProcessGroupConfig, ProcessGroupNode warnings.simplefilter("default", DeprecationWarning) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py new file mode 100644 index 0000000..6e5273a --- /dev/null +++ b/src/hirad/inference/generate.py @@ -0,0 +1,159 @@ +import hydra +import os +import json +from omegaconf import OmegaConf, DictConfig +import torch +import torch._dynamo +import nvtx +import numpy as np +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from einops import rearrange +from torch.distributed import gather + + +from hydra.utils import to_absolute_path +from hirad.models import EDMPrecond, UNet +from hirad.utils.stochastic_sampler import stochastic_sampler +from hirad.utils.deterministic_sampler import deterministic_sampler +from hirad.utils.inference_utils import ( + get_time_from_range, + regression_step, + diffusion_step, +) +from hirad.utils.checkpoint import load_checkpoint + + +from hirad.utils.generate_utils import ( + get_dataset_and_sampler +) + +from hirad.utils.train_helpers import set_patch_shape + + +@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + """Generate random dowscaled atmospheric states using the techniques described in the paper + "Elucidating the Design Space of Diffusion-Based Generative Models". + """ + + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + # Initialize logger + logger = PythonLogger("generate") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + logger.file_logging("generate.log") + + # Handle the batch size + seeds = list(np.arange(cfg.generation.num_ensembles)) + num_batches = ( + (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1 + ) * dist.world_size + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + rank_batches = all_batches[dist.rank :: dist.world_size] + + # Synchronize + if dist.world_size > 1: + torch.distributed.barrier() + + # Parse the inference input times + if cfg.generation.times_range and cfg.generation.times: + raise ValueError("Either times_range or times must be provided, but not both") + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range) #TODO check what time formats we are using and adapt + else: + times = cfg.generation.times + + # Create dataset object + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if "has_lead_time" in cfg.generation: + has_lead_time = cfg.generation["has_lead_time"] + else: + has_lead_time = False + dataset, sampler = get_dataset_and_sampler( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + + # Parse the patch shape + if hasattr(cfg.generation, "patch_shape_x"): # TODO better config handling + patch_shape_x = cfg.generation.patch_shape_x + else: + patch_shape_x = None + if hasattr(cfg.generation, "patch_shape_y"): + patch_shape_y = cfg.generation.patch_shape_y + else: + patch_shape_y = None + patch_shape = (patch_shape_y, patch_shape_x) + img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if patch_shape != img_shape: + logger0.info("Patch-based training enabled") + else: + logger0.info("Patch-based training disabled") + + # Parse the inference mode + if cfg.generation.inference_mode == "regression": + load_net_reg, load_net_res = True, False + elif cfg.generation.inference_mode == "diffusion": + load_net_reg, load_net_res = False, True + elif cfg.generation.inference_mode == "all": + load_net_reg, load_net_res = True, True + else: + raise ValueError(f"Invalid inference mode {cfg.generation.inference_mode}") + + # Load diffusion network, move to device, change precision + if load_net_res: + res_ckpt_path = cfg.generation.io.res_ckpt_path + logger0.info(f'Loading residual network from "{res_ckpt_path}"...') + + diffusion_model_args_path = os.path.join(res_ckpt_path, 'model_args.json') + if not os.path.isfile(diffusion_model_args_path): + raise FileNotFoundError(f"Missing config file at '{diffusion_model_args_path}'.") + with open(diffusion_model_args_path, 'r') as f: + diffusion_model_args = json.load(f) + + net_res = EDMPrecond(**diffusion_model_args) + + _ = load_checkpoint( + path=res_ckpt_path, + model=net_res, + device=dist.device + ) + + net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) + if cfg.generation.perf.force_fp16: + net_res.use_fp16 = True + else: + net_res = None + + # load regression network, move to device, change precision + if load_net_reg: + reg_ckpt_path = cfg.generation.io.reg_ckpt_path + logger0.info(f'Loading network from "{reg_ckpt_path}"...') + + + regression_model_args_path = os.path.join(reg_ckpt_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + net_reg = EDMPrecond(**regression_model_args) + + _ = load_checkpoint( + path=reg_ckpt_path, + model=net_reg, + device=dist.device + ) + + net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) + if cfg.generation.perf.force_fp16: + net_reg.use_fp16 = True + else: + net_reg = None \ No newline at end of file diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index d5a1ab2..ddb23b6 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -26,7 +26,7 @@ from einops import rearrange from torch.nn.functional import silu -from src.utils.model_utils import weight_init +from hirad.utils.model_utils import weight_init class Linear(torch.nn.Module): diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index b0e924e..9c10004 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -29,11 +29,11 @@ import torch import torch.nn as nn -from hirad.models import ( +from .song_unet import ( DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) -from hirad.models import ModelMetaData +from .meta import ModelMetaData network_module = importlib.import_module("physicsnemo.models.diffusion") diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index 5bfca8a..6267dfc 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -29,7 +29,7 @@ from torch.utils.checkpoint import checkpoint import torch.nn as nn -from hirad.models import ( +from .layers import ( Conv2d, FourierEmbedding, GroupNorm, @@ -37,7 +37,7 @@ PositionalEmbedding, UNetBlock, ) -from hirad.models import ModelMetaData +from .meta import ModelMetaData @dataclass diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index 1333bda..d81a734 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from hirad.models import ModelMetaData +from .meta import ModelMetaData network_module = importlib.import_module("src.models") diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 03b423d..e0f8d58 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -24,7 +24,7 @@ from torch.optim.lr_scheduler import _LRScheduler from hirad.distributed import DistributedManager -from hirad.utils.console import PythonLogger +from .console import PythonLogger optimizer = NewType("optimizer", torch.optim.Optimizer) scheduler = NewType("scheduler", _LRScheduler) diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py index 29f7eb4..b99852f 100644 --- a/src/hirad/utils/generate_utils.py +++ b/src/hirad/utils/generate_utils.py @@ -1,6 +1,6 @@ import datetime from hirad.datasets import init_dataset_from_config -from hirad.utils.function_utils import convert_datetime_to_cftime +from .function_utils import convert_datetime_to_cftime def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 842bdd3..b158ec0 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -21,7 +21,7 @@ import torch import tqdm -from src.utils.function_utils import StackedRandomGenerator, time_range +from .function_utils import StackedRandomGenerator, time_range ############################################################################ # CorrDiff Generation Utilities # From c284d0aa948bca579290fc9abe1ebb75212e3ad4 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 29 Apr 2025 14:51:43 +0200 Subject: [PATCH 019/189] add getitem to dataset --- src/hirad/datasets/era5_cosmo.py | 40 +++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 597bdfc..4d9187d 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -33,8 +33,28 @@ def __init__(self, dataset_path: str): # Load stats for normalizing channels of input and output cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False) - print(cosmo_stats) - + self.output_mean = cosmo_stats['mean'] + self.output_std = cosmo_stats['stdev'] + + era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) + #TODO Switch from cosmo to era stats once era-interpolated has all channels + self.input_mean = cosmo_stats['mean'] + self.input_std = cosmo_stats['stdev'] + + + def __getitem__(self, idx): + # get era5 data point + era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ + .squeeze()\ + .reshape(-1,*self.image_shape()) + era5_data = self.normalize_input(era5_data) + # get cosmo data point + cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ + .squeeze()\ + .reshape(-1,*self.image_shape()) + cosmo_data = self.normalize_output(cosmo_data) + # return samples + return cosmo_data, era5_data, 0 def __len__(self): return len(self._file_list) @@ -70,25 +90,29 @@ def time(self) -> List: def image_shape(self) -> Tuple[int, int]: """Get the (height, width) of the data (same for input and output).""" - #TODO load from info, I hardcode it for now - return 390,582 + #TODO load from info, I hardcode it for now (cosmo from anemoi-datasets minus trim-edge=20) + return 350,542 def normalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from physical units to normalized data.""" - return (x - self.input_mean) / self.input_std + return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \ + / self.input_std.reshape((self.input_std.shape[0],1,1)) def denormalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from normalized data to physical units.""" - return x * self.input_std + self.input_mean + return x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + + self.input_mean.reshape((self.input_mean.shape[0],1,1)) def normalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from physical units to normalized data.""" - return (x - self.output_mean) / self.output_std + return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \ + / self.output_std.reshape((self.output_std.shape[0],1,1)) def denormalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from normalized data to physical units.""" - return x * self.output_std + self.output_mean \ No newline at end of file + return x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ + + self.output_mean.reshape((self.output_mean.shape[0],1,1)) \ No newline at end of file From e7d5b1b6324a8affba89bc216074395ca7b8af25 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 29 Apr 2025 15:11:52 +0200 Subject: [PATCH 020/189] add grid flip to start at top left corner --- src/hirad/datasets/era5_cosmo.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 4d9187d..89a0581 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -44,14 +44,19 @@ def __init__(self, dataset_path: str): def __getitem__(self, idx): # get era5 data point - era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ - .squeeze()\ - .reshape(-1,*self.image_shape()) + # squeeze the ensemble dimesnsion + # reshape to image_shape + # flip so that it starts in top-left corner (by default it is bottom left) + era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ + .squeeze() \ + .reshape(-1,*self.image_shape()), + 1) era5_data = self.normalize_input(era5_data) # get cosmo data point - cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ - .squeeze()\ - .reshape(-1,*self.image_shape()) + cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ + .squeeze() \ + .reshape(-1,*self.image_shape()), + 1) cosmo_data = self.normalize_output(cosmo_data) # return samples return cosmo_data, era5_data, 0 From d093e662e23f079f891accbd656b258fa7e5b8ac Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 29 Apr 2025 15:26:51 +0200 Subject: [PATCH 021/189] small fix --- src/hirad/datasets/dataset.py | 2 +- src/hirad/datasets/era5_cosmo.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 6e402d9..6cc6165 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -82,7 +82,7 @@ def init_dataset_from_config( seed: int = 0, ) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) - dataset_type = dataset_cfg.pop("type", "cwb") + dataset_type = dataset_cfg.pop("type", "era5_cosmo") if "train_test_split" in dataset_cfg: # handled by init_train_valid_datasets_from_config del dataset_cfg["train_test_split"] diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 89a0581..e7de456 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -13,7 +13,7 @@ def __init__(self, dataset_path: str): self._dataset_path = dataset_path self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') - self._info_path = os.path.join(dataset_path, 'info') + self._info_path = os.path.join(dataset_path, 'old/info') # load file list (each file is one date-time state) self._file_list = os.listdir(self._cosmo_path) @@ -37,9 +37,8 @@ def __init__(self, dataset_path: str): self.output_std = cosmo_stats['stdev'] era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) - #TODO Switch from cosmo to era stats once era-interpolated has all channels - self.input_mean = cosmo_stats['mean'] - self.input_std = cosmo_stats['stdev'] + self.input_mean = era_stats['mean'] + self.input_std = era_stats['stdev'] def __getitem__(self, idx): @@ -59,7 +58,7 @@ def __getitem__(self, idx): 1) cosmo_data = self.normalize_output(cosmo_data) # return samples - return cosmo_data, era5_data, 0 + return torch.tensor(cosmo_data), torch.tensor(era5_data), 0 def __len__(self): return len(self._file_list) From 7e695f0305f6190aa4484c020926cdafb62d64e1 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 7 May 2025 15:19:32 +0200 Subject: [PATCH 022/189] update everything for training --- src/hirad/conf/dataset/era_cosmo.yaml | 2 + src/hirad/conf/model/era_cosmo_diffusion.yaml | 5 + .../conf/model/era_cosmo_regression.yaml | 2 + src/hirad/conf/training.yaml | 0 .../conf/training/era_cosmo_diffusion.yaml | 41 +++ .../conf/training/era_cosmo_regression.yaml | 38 +++ .../conf/training_era_cosmo_diffusion.yaml | 19 ++ .../conf/training_era_cosmo_regression.yaml | 19 ++ src/hirad/datasets/era5_cosmo.py | 13 +- src/hirad/distributed/config.py | 2 +- src/hirad/distributed/manager.py | 34 ++- src/hirad/models/__init__.py | 7 +- src/hirad/models/dhariwal_unet.py | 259 ++++++++++++++++++ src/hirad/models/preconditioning.py | 6 +- src/hirad/models/unet.py | 2 +- src/hirad/testrun.sh | 41 +++ src/hirad/training/train.py | 37 +-- 17 files changed, 485 insertions(+), 42 deletions(-) create mode 100644 src/hirad/conf/dataset/era_cosmo.yaml create mode 100644 src/hirad/conf/model/era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/model/era_cosmo_regression.yaml delete mode 100644 src/hirad/conf/training.yaml create mode 100644 src/hirad/conf/training/era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/training/era_cosmo_regression.yaml create mode 100644 src/hirad/conf/training_era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/training_era_cosmo_regression.yaml create mode 100644 src/hirad/models/dhariwal_unet.py create mode 100644 src/hirad/testrun.sh diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml new file mode 100644 index 0000000..854b775 --- /dev/null +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -0,0 +1,2 @@ +type: era5_cosmo +dataset_path: /store_new/mch/msopr/hirad-gen/basic-torch \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..06aa2a4 --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -0,0 +1,5 @@ +name: diffusion + # Name of the preconditioner +hr_mean_conditioning: True + # High-res mean (regression's output) as additional condition +scale_cond_input: False \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_regression.yaml b/src/hirad/conf/model/era_cosmo_regression.yaml new file mode 100644 index 0000000..487eb4b --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_regression.yaml @@ -0,0 +1,2 @@ +name: regression +hr_mean_conditioning: False \ No newline at end of file diff --git a/src/hirad/conf/training.yaml b/src/hirad/conf/training.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..b61603a --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -0,0 +1,41 @@ +# Hyperparameters +hp: + training_duration: 128 + # Training duration based on the number of processed samples + total_batch_size: 16 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: null + # no gradient clipping for defualt non-patch-based training + lr_decay: 1 + # LR decay rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 4 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + regression_checkpoint_path: /scratch/mch/pstamenk/output/regression/checkpoints_regression + # Where to load the regression checkpoint + print_progress_freq: 32 + # How often to print progress + save_checkpoint_freq: 5000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 5000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # how many loss evaluations are used to compute the validation loss per checkpoint + # how many loss evaluations are used to compute the validation loss per checkpoint + checkpoint_dir: /scratch/mch/pstamenk/output/diffusion \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml new file mode 100644 index 0000000..7c443f0 --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -0,0 +1,38 @@ +# Hyperparameters +hp: + training_duration: 16 + # Training duration based on the number of processed samples + total_batch_size: 16 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: null + # no gradient clipping for defualt non-patch-based training + lr_decay: 1 + # LR decay rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 4 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + print_progress_freq: 32 + # How often to print progress + save_checkpoint_freq: 5000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 5000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # how many loss evaluations are used to compute the validation loss per checkpoint + checkpoint_dir: /scratch/mch/pstamenk/output/regression \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml new file mode 100644 index 0000000..7ee7dba --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -0,0 +1,19 @@ +hydra: + job: + chdir: true + name: diffusion + run: + dir: /scratch/mch/pstamenk/output/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_diffusion + + # Training + - training/era_cosmo_diffusion \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml new file mode 100644 index 0000000..d857d12 --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -0,0 +1,19 @@ +hydra: + job: + chdir: true + name: regression + run: + dir: /scratch/mch/pstamenk/output/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_regression + + # Training + - training/era_cosmo_regression \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index e7de456..8b0d60f 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -4,6 +4,7 @@ import torch from typing import List, Tuple import yaml +import torch.nn.functional as F class ERA5_COSMO(DownscalingDataset): def __init__(self, dataset_path: str): @@ -42,23 +43,27 @@ def __init__(self, dataset_path: str): def __getitem__(self, idx): + """Get cosmo and era5 interpolated to cosmo grid""" # get era5 data point # squeeze the ensemble dimesnsion # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) + orig_shape = [350,542] #TODO currently padding to be divisible by 16 era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ .squeeze() \ - .reshape(-1,*self.image_shape()), + .reshape(-1,*orig_shape), 1) era5_data = self.normalize_input(era5_data) # get cosmo data point cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ .squeeze() \ - .reshape(-1,*self.image_shape()), + .reshape(-1,*orig_shape), 1) cosmo_data = self.normalize_output(cosmo_data) # return samples - return torch.tensor(cosmo_data), torch.tensor(era5_data), 0 + return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ + F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ + 0 def __len__(self): return len(self._file_list) @@ -95,7 +100,7 @@ def time(self) -> List: def image_shape(self) -> Tuple[int, int]: """Get the (height, width) of the data (same for input and output).""" #TODO load from info, I hardcode it for now (cosmo from anemoi-datasets minus trim-edge=20) - return 350,542 + return 352,544 #TODO 350,542 is orig size, UNet requires dimenions divisible by 16, for now, I just add zeros to orig images def normalize_input(self, x: np.ndarray) -> np.ndarray: diff --git a/src/hirad/distributed/config.py b/src/hirad/distributed/config.py index c5414b4..2808d92 100644 --- a/src/hirad/distributed/config.py +++ b/src/hirad/distributed/config.py @@ -84,7 +84,7 @@ class ProcessGroupConfig: Examples -------- - >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig + >>> from hirad.distributed import ProcessGroupNode, ProcessGroupConfig >>> >>> # Create world group that contains all processes that are part of this job >>> world = ProcessGroupNode("world") diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index 647d054..eca46c6 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -348,7 +348,7 @@ def initialize_slurm(port): rank = int(os.environ.get("SLURM_PROCID")) world_size = int(os.environ.get("SLURM_NPROCS")) local_rank = int(os.environ.get("SLURM_LOCALID")) - addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") + addr = os.environ.get("MASTER_ADDR") DistributedManager.setup( rank=rank, @@ -388,6 +388,7 @@ def initialize(): port = os.getenv("MASTER_PORT", "12355") # https://pytorch.org/docs/master/notes/cuda.html#id5 # was changed in version 2.2 + #TODO why is setting this important? if torch.__version__ < (2, 2): os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" else: @@ -542,22 +543,25 @@ def setup( f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu" ) + #TODO device_id makes the init hang, couldn't figure out why if manager._distributed: # Setup distributed process group - try: - dist.init_process_group( - backend, - rank=manager.rank, - world_size=manager.world_size, - device_id=manager.device, - ) - except TypeError: - # device_id only introduced in PyTorch 2.3 - dist.init_process_group( - backend, - rank=manager.rank, - world_size=manager.world_size, - ) + # try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) + # rank=manager.rank, + # world_size=manager.world_size, + # device_id=manager.device, + # except TypeError: + # # device_id only introduced in PyTorch 2.3 + # dist.init_process_group( + # backend, + # rank=manager.rank, + # world_size=manager.world_size, + # ) if torch.cuda.is_available(): # Set device for this process and empty cache to optimize memory usage diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index f17e5ce..3ab4a6f 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,5 +1,6 @@ -from .unet import UNet +from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding +from .meta import ModelMetaData from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .dhariwal_unet import DhariwalUNet +from .unet import UNet from .preconditioning import EDMPrecondSR, EDMPrecond -from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding -from .meta import ModelMetaData \ No newline at end of file diff --git a/src/hirad/models/dhariwal_unet.py b/src/hirad/models/dhariwal_unet.py new file mode 100644 index 0000000..3880cd0 --- /dev/null +++ b/src/hirad/models/dhariwal_unet.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +from torch.nn.functional import silu +import torch.nn as nn + +from .layers import ( + Conv2d, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from .meta import ModelMetaData + + +@dataclass +class MetaData(ModelMetaData): + name: str = "DhariwalUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class DhariwalUNet(nn.Module): + """ + Reimplementation of the ADM architecture, a U-Net variant, with optional + self-attention. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : int + The resolution of the input/output image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 192. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,3,4]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 3. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [32, 16, 8]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + + Reference + ---------- + Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat gans on image + synthesis. Advances in neural information processing systems, 34, pp.8780-8794. + + Note + ----- + Equivalent to the original implementation by Dhariwal and Nichol, available at + https://github.com/openai/guided-diffusion + + Example + -------- + >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + """ + + def __init__( + self, + img_resolution: int, + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 192, + channel_mult: List[int] = [1, 2, 3, 4], + channel_mult_emb: int = 4, + num_blocks: int = 3, + attn_resolutions: List[int] = [32, 16, 8], + dropout: float = 0.10, + label_dropout: float = 0.0, + ): + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + emb_channels = model_channels * channel_mult_emb + init = dict( + init_mode="kaiming_uniform", + init_weight=np.sqrt(1 / 3), + init_bias=np.sqrt(1 / 3), + ) + init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0) + block_kwargs = dict( + emb_channels=emb_channels, + channels_per_head=64, + dropout=dropout, + init=init, + init_zero=init_zero, + ) + + # Mapping. + self.map_noise = PositionalEmbedding(num_channels=model_channels) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=model_channels, + bias=False, + **init_zero, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=model_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + self.map_label = ( + Linear( + in_features=label_dim, + out_features=emb_channels, + bias=False, + init_mode="kaiming_normal", + init_weight=np.sqrt(label_dim), + ) + if label_dim + else None + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + skips = [block.out_channels for block in self.enc.values()] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + self.out_norm = GroupNorm(num_channels=cout) + self.out_conv = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + # Mapping. + emb = self.map_noise(noise_labels) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = self.map_layer1(emb) + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp) + emb = silu(emb) + + # Encoder. + skips = [] + for block in self.enc.values(): + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. + for block in self.dec.values(): + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + x = self.out_conv(silu(self.out_norm(x))) + return x diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index 9c10004..c66b6b6 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -30,12 +30,14 @@ import torch.nn as nn from .song_unet import ( - DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) +from .dhariwal_unet import ( + DhariwalUNet, # noqa: F401 for globals +) from .meta import ModelMetaData -network_module = importlib.import_module("physicsnemo.models.diffusion") +network_module = importlib.import_module("hirad.models") @dataclass diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index d81a734..10079ec 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -22,7 +22,7 @@ from .meta import ModelMetaData -network_module = importlib.import_module("src.models") +network_module = importlib.import_module("hirad.models") @dataclass diff --git a/src/hirad/testrun.sh b/src/hirad/testrun.sh new file mode 100644 index 0000000..ee4a977 --- /dev/null +++ b/src/hirad/testrun.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=16G +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/scratch/mch/pstamenk/logs/regression_test.log +#SBATCH --error=/scratch/mch/pstamenk/logs/regression_test.err + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=ENV + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Node: $(hostname)" +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# activate conda env +CONDA_ENV=train +source /users/pstamenk/.bashrc +mamba activate $CONDA_ENV + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +torchrun --nproc-per-node=$LOCAL_PROCS src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 0539aad..88b2118 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -9,6 +9,7 @@ from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter from torch.nn.parallel import DistributedDataParallel +from torchinfo import summary from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper @@ -20,9 +21,10 @@ from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE from hirad.datasets import init_train_valid_datasets_from_config -@hydra.main(version_base=None, config_path="conf", config_name="training") +from matplotlib import pyplot as plt + +@hydra.main(version_base=None, config_path="../conf", config_name="training") def main(cfg: DictConfig) -> None: - # Initialize distributed environment for training DistributedManager.initialize() dist = DistributedManager() @@ -45,10 +47,12 @@ def main(cfg: DictConfig) -> None: fp16 = fp_optimizations == "fp16" enable_amp = fp_optimizations.startswith("amp") amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 - logger.info(f"Saving the outputs in {os.getcwd()}") + logger0.info(f"Saving the outputs in {os.getcwd()}") checkpoint_dir = os.path.join( cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" ) + if dist.rank==0 and not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) # added creating checkpoint dir if cfg.training.hp.batch_size_per_gpu == "auto": cfg.training.hp.batch_size_per_gpu = ( cfg.training.hp.total_batch_size // dist.world_size @@ -85,12 +89,14 @@ def main(cfg: DictConfig) -> None: if cfg.model.hr_mean_conditioning: img_in_channels += img_out_channels + if cfg.model.name == "lt_aware_ce_regression": - prob_channels = dataset.get_prob_channel_index() + prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader else: prob_channels = None # Parse the patch shape + #TODO figure out patched diffusion and how to use it if ( cfg.model.name == "patched_diffusion" or cfg.model.name == "lt_aware_patched_diffusion" @@ -109,9 +115,8 @@ def main(cfg: DictConfig) -> None: # interpolate global channel if patch-based model is used if img_shape[1] != patch_shape[1]: img_in_channels += dataset_channels - - # Instantiate the model and move to device. + # Instantiate the model and move to device. if cfg.model.name not in ( "regression", "lt_aware_ce_regression", @@ -180,7 +185,7 @@ def main(cfg: DictConfig) -> None: img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] elif cfg.model.name == "lt_aware_ce_regression": model = UNet( img_in_channels=img_in_channels @@ -188,7 +193,7 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] elif cfg.model.name == "lt_aware_patched_diffusion": model = EDMPrecondSR( img_in_channels=img_in_channels @@ -196,18 +201,21 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] else: # diffusion or patched diffusion model = EDMPrecondSR( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] model.train().requires_grad_(True).to(dist.device) + # TODO write summry from rank=0 possibly + # summary(model, input_size=[(4,img_out_channels,*img_shape),(4,img_in_channels,*img_shape),(4,1),(4,1)]) + if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): - with open(os.path.join(checkpoint_dir, 'model_args.json'), 'w') as f: + with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: json.dump(model_args, f) # Enable distributed data parallel if applicable @@ -220,7 +228,7 @@ def main(cfg: DictConfig) -> None: find_unused_parameters=dist.find_unused_parameters, ) - # Load the regression checkpoint if applicable + # Load the regression checkpoint if applicable #TODO test when training correction if hasattr(cfg.training.io, "regression_checkpoint_path"): regression_checkpoint_path = to_absolute_path( cfg.training.io.regression_checkpoint_path @@ -286,8 +294,7 @@ def main(cfg: DictConfig) -> None: dist.world_size, ) batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu - logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation {"rounds" if num_accumulation_rounds>1 else "round"}.") ## Resume training from previous checkpoints if exists if dist.world_size > 1: @@ -313,7 +320,6 @@ def main(cfg: DictConfig) -> None: average_loss_running_mean = 0 n_average_loss_running_mean = 1 - while not done: tick_start_nimg = cur_nimg tick_start_time = time.time() @@ -494,7 +500,6 @@ def main(cfg: DictConfig) -> None: optimizer=optimizer, epoch=cur_nimg, ) - pass # Done. logger0.info("Training Completed.") From 7510cf1ad7f40c0794d881d06277de793bbdde94 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 7 May 2025 15:20:07 +0200 Subject: [PATCH 023/189] small fix --- src/hirad/utils/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index e0f8d58..a346b16 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -294,7 +294,7 @@ def load_checkpoint( checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") if not Path(checkpoint_filename).is_file(): checkpoint_logging.warning( - "Could not find valid checkpoint file, skipping load" + f"Could not find valid checkpoint file {checkpoint_filename} skipping load" ) return 0 From 75db04fb5b550787fac0cea394d34e090127ae77 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 7 May 2025 15:27:18 +0200 Subject: [PATCH 024/189] remove tracked .pyc files --- .../models/__pycache__/__init__.cpython-312.pyc | Bin 185 -> 0 bytes .../models/__pycache__/dummy.cpython-312.pyc | Bin 451 -> 0 bytes .../models/__pycache__/unet.cpython-312.pyc | Bin 8966 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/hirad/models/__pycache__/__init__.cpython-312.pyc delete mode 100644 src/hirad/models/__pycache__/dummy.cpython-312.pyc delete mode 100644 src/hirad/models/__pycache__/unet.cpython-312.pyc diff --git a/src/hirad/models/__pycache__/__init__.cpython-312.pyc b/src/hirad/models/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 70d5748263686e7ee01a00c9c9ffd61e4261990b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 185 zcmX@j%ge<81eP(8=gQRm475F^(AnLQ) zQDjD9>o~pA)a3YQB1#<>S}gTyk;VCqxJWL~XOnrH3tdDayH#g=oRqaZdZNZE447`7==k|mj}EK82PsyZ3&%#yR( zo!Q)(l}KR)vxSQ;+q#mP6r@`haex4oDiu*4g2+$KQ(yWbg(|!mH~@j92#mZ?pxh>Y zNzS?RvrBqy2Tp*#_zF6E&;2>~@7#0GcZUBJk4G3t8zaA+{guElzd*)K0VlI@7c$F? z%*bq!$>OiC$d-IrA8VKSvwm8~mG~@A%lySaDVPnijE_;6Ledr3_?5Yvtj%S^iqFR! zWn}(MMh+;!n|{0erqehZp|wG%4Jmw~t0BY9xjnKm+A<6+BVW)mzQCrVAEAKhG;4)r z$&_+M$uNM4PwBE!oKj3_5=t*c43f)~u*k@{1LChUOJ?Jrfmw!CW?7lZ`p&=}W&H}L z@Urh&sFD3ySQBuZ!-XivJ6xD@0m=ol5z2*}x+vwsPF;+05vMLrxv0Y>C>L|MHp<13 zTS%natdP4Fi#s!|m=>o=CFO|aFQ}#!%~fPcoF!6u&ai?~sVwHomD6@^wp_77rcQEm zVt(G@X7VEkEKb+7c`K5yXgQHqSz)PS>L@{;PA9F9D57*xv?8JiyH+V;J|>Fasz^n< zhJ;{ZmftXm)h3FPrs<|+s=8)~A_-#^H>2xCf?7$UNMcChND@fekf279b|5#H4WP^? zgToa=A;xgoFkyV{{P2tFsqx97qlz{R2RLjW;BZA#%uIQKpkhQ*i@+!Autvv;O zAA=ri?I;CuU2EHW7mr`%R=M7K*UrUbHS<2Vlk~U)Z_*_mN-=G`2;^^?v?$h;iL^ED>V%E33oSzeCFp*PuVKnBgv25Zbbn+?6fq@yGV;-vrJ7$zb>Ix_^P zFMI>ex~vc(uM=THCuCC1=Q9!7;G{%g?@ZW6$__OW<<-)xNEAaaR`9rlF+tT#T7kk( z&nZGvrJff|C{eX?#T?cvCj1C$Nt#tMo~Ya$oITvCErNvUN@)h>2NH665hm%jYxGn# z(JkE4xzU%7{FxG9Mqez#nRcsLfUb%8^2mWP0d;jy(95_bQqkLv!}Ch6Vk&eKDw?d2 z;)1Hp3RDdR$rO&9962DAiIP(d==l7gFtZ@YN?xiIp-w6q&ORCHY)KMJsx~I%i@IdC z7->q?)KaA+7?pBaC#E6`nywm(P*g4`FtaAP`|KTbWJ>d&BU76H17*@#4fLy8zCYb8 z9u4Ye%Rf9{R&qF1xU8Bi5gFEgL8Vpz&mA?#BoZtPwhSe0S;o`$NOUtrbD^w^fgsXi zZzj;B3uty+v0Kk#ljQ!%IfZ+{f8d`rsS-7af^r5Z- zR>w5xwZi#t9SE)%RgP0iULlH>Q?}kQS~~@MA!+KMFqIh;4naHpqH10m6sCcl(B~9w zP3~7Iq)g|iCtz)r-EUYVpZlX zu&*XqK|9@IZ#td_%T#?;Ms#0t#>D4`rYdWC@^9j7Ad$8zJI_LU*DGQmk0aBhZL6}f z56%0g7Od=d%Dk3UKeWI%kftgJTw`1wIB(=?061Ki-7*LqzJ)dy0uCQVn#?{g_lTRT zhTWN8j{2BtROYKJjLltw232@X-tku>-)H{;SLFF;3rspNeS@_)cm^bT3?JcM-SAt9 z&DP(FHW^_n)@)BnH}puiTRd1Kg`h_O7ukfz9rw+Op&ZwGTM;mS8XAF|g69qX2v!?& zQdtq_o0$cx4QrYQO=3awz*fjL%U0MkoL*J}x5X|+i(&<5VTC66APH~@;0~X6ycLV5 z_dOlAI72Ds$zw2Sl17q2GK^#&k`W}MNcIDfc$2WdANfKCzgWuLf!Ao-3@i$5B0D04%E2&p`M2U zCf0RT`+n|zbZ>q1g}bpAYLoY4y({K_?;2j)HFRCNw`(6QsW&rgyE53jw=V3%-@dhO zp~n4vyxy}1eiH7_j>mv&`>?C;9qoEywd-Il`m>a<&hTB4`p$H{yZ`Rqv3uQP^`7*D zAe60#n6{k{!jP+RZ$+rjlXG1hSOtU1Pp|75T&-oegn!v%I#-fm3eI)IhHe7I_2qps zd&}pkG+Sj0%`%;)jf-FKuBj`9Vcd&<87{(??_Kb&YeTZ^ms$LG(=jub;gT1;i`-~O zo~&|48s_d@S!H+R#?KrNbrEeT3nto;`6f;c$mfXt|Dl}F=rIam$0loB}q4D8Fk zrupIR*^=Cu7`+VxDF>^8e*q}jS}xRt|CM)F|5q%u`tn*9s-hbnnI?N+UVtq9vReA+ zcaSU&Vb=|x#mnfvSOIi@zy!u@+fbBdz%Tx!$8(d=p`a}VQx|3w;gbmXE#Ohl>Zb5X z7`$x5R3teASdX_snLj6)me16!Aa&68k6J-+&e5#e#_JGF$l~bOc8ie%xk$kDqu&ZJ z5x^rl%KJyjL1<#j^B)1z?w=Ir4ZKIL^%JP|gZb(NJT3WBj9Ng4TY!c}KbVL`W-Z!& zH`@1MGPNAO7G4=y$*(+ledK;}s5ZV9i(i>sn*2^`x$j!vYHU}HtwXEOwa`1M@Akdj zx0>w#AUS+5Iego9d;Ipd?(AMoK3n7015A5rIsT9F?-g#p{NvYt_}Y)p{qWpR@;`n0 zrH6hd-f?B>Z>N6Q-h~p^gJ5{qV@z95O}iypLdLR0I@x%i@*cMrjFgqSuk75p_?0gW1T+P4U!NikSjx8NqN&PhT z$SwYZf$@6-((E`kfO$S-6{gdBgAPOsq2?79_TH zT*)uxzjJ2!+_iJ7iAQVxdRs@$F}cl#6nz~0zcB~|{G%^0x9$Z1^<|~w&^Y3r(}b>F zPzbng!YsIjnm|ulxPaO)V5Ct$CBalQ16)G%@)QHSV(QpC7!&1e!I@aJ1!`iEEqP2h zqk*$|mg=!kR1I)wsg7q((^1ZOGs8Zo;i7zcMmc-VO*1&9m=&VgE;#mchWBJOkKh%$ z1GHS83oyz=zo5zpZSV?s8)JL6C?V}5N8t=0ZeobGBq+fPL{0-S@X!S3GDL|EFhZn~ zE+y%OXewF3=AA&mDd=UAsVi?Ry|GeRZF~G~=yCey-e8Zw-9)2Luo*!EZUseA)^iY( z;pvW2cZ0l!)n~Cf?)WD;rC1cj8!Y)clp3flFWBx0y0J8_fbNiiDXTH==@p7_7(oygZ0j?dgo&wryl1c8`Eqb-}%4KCHN=o2nb5a!GABgMs@=&)0c>V zgzX3jjeP{s=vwht;1Ej_V6<@KBnz~ko}?WGx!4p1@fvH@KnHsL)L< zzY42-6;}BwtkM)qbMBiRG9VbuB?yJBkQ*+@H9PBQb|~<83M|@|I6>5j6S$TL;`le6 zR}x@{lVceZWzYmM0uRNzu*8cIcv1jedk#5`5xDVyMhjNvJ{32>xI|;5KS{*E-09BT z#C@Wv8ggfT8NRYqBXR(tR@E=_5FH3R(E$v&(*)6h;53qvq4=Xd-X_MwjE80P5Hg;)ch zT_-hz>@LsnZ$jJu2|bK|2ZCt3yMeafePi{JXDHCvksJPjM@SS0+emk zx0uysjzZZ)V=On!(1L zuktOWa1M}d%&yfJPG^w-WYJX61b3AKxEgFN7a{H@Aj|&|8W5-kz^4jJZ-Ek>s%V0Q zAzy@PLKy?>j#uJcC-7Ki!D*pEV(M-+dEVm?MIWP1kb5QPg@I;wJB{8Del7wifw-3n z*wg{Wc2K1-u=(=~^re;HqZuYfra?x&Ebbo_)Vy$-RFsS`X@6ufu*&(+;t4+`!H1JA zvuohoG2vL=nHokA1w_vR_Yy2b`XI5QOL04{;JfVGk?XsE(0`-<&hGaI-W#|dn|Q$U zfW`xe%Y&G6;+7PbC;kXr4st#lmy2(0hp&fR?E|BkSG_)Z&qVc&GZ3acgjU|3XVtSBlpjXfy_K|^Fi zRQuz9M$;@G{NG-lTQJm|p(!QZ`GbfI{Y-#X+KSuKoG7~ee52!I1*Yg<0Z`0n`wr+Q z*fYV9ofUc(0_5;_0?!c*SbV08IUwtPmSs2gFl_tJnccr&Vn1h6|HF)|wGUkF`JjE^ zUi-lH-qrR4iwQVBU-Y3b!S>cJZ!qw);fI2CEO@wQo@Jk?jb1spbnqd36=$D#c)>rz TvYl5CEg!pf>|+M=bSA$CZ1ZH& From 4b5ecb4f7cfa0048c3d253ed50b9ae6537848245 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 9 May 2025 14:21:47 +0200 Subject: [PATCH 025/189] add small loggign changes --- src/hirad/training/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) mode change 100644 => 100755 src/hirad/training/train.py diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py old mode 100644 new mode 100755 index 88b2118..9ce619d --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -80,6 +80,7 @@ def main(cfg: DictConfig) -> None: validation_dataset_cfg=validation_dataset_cfg, train_test_split=train_test_split, ) + logger0.info(f"Training on dataset with size {len(dataset)}") # Parse image configuration & update model args dataset_channels = len(dataset.input_channels()) @@ -295,7 +296,7 @@ def main(cfg: DictConfig) -> None: ) batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu logger0.info(f"Using {num_accumulation_rounds} gradient accumulation {"rounds" if num_accumulation_rounds>1 else "round"}.") - + logger0.info(f"Batch size per gpu: {batch_size_per_gpu}") ## Resume training from previous checkpoints if exists if dist.world_size > 1: torch.distributed.barrier() From 3eb4d0a7b15af18b42ae647a2f08a1c5b45b5b8e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 9 May 2025 14:22:55 +0200 Subject: [PATCH 026/189] adapt sbatch script to slurm config --- src/hirad/testrun.sh | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/hirad/testrun.sh b/src/hirad/testrun.sh index ee4a977..ac631c8 100644 --- a/src/hirad/testrun.sh +++ b/src/hirad/testrun.sh @@ -5,20 +5,33 @@ ### HARDWARE ### #SBATCH --partition=debug #SBATCH --nodes=1 -#SBATCH --gres=gpu:4 -#SBATCH --ntasks-per-node=4 -#SBATCH --cpus-per-task=16 -#SBATCH --mem=16G +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/scratch/mch/pstamenk/logs/regression_test.log -#SBATCH --error=/scratch/mch/pstamenk/logs/regression_test.err +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a01 # Choose method to initialize dist in pythorch -export DISTRIBUTED_INITIALIZATION_METHOD=ENV +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" # Get number of physical cores using Python PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") @@ -27,15 +40,12 @@ LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} # Compute threads per process OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) export OMP_NUM_THREADS=$OMP_THREADS -echo "Node: $(hostname)" echo "Physical cores: $PHYSICAL_CORES" echo "Local processes: $LOCAL_PROCS" echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" -# activate conda env -CONDA_ENV=train -source /users/pstamenk/.bashrc -mamba activate $CONDA_ENV - # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -torchrun --nproc-per-node=$LOCAL_PROCS src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml \ No newline at end of file +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +" \ No newline at end of file From 9aaea9d2b2f7ddb9825613beb6303c32e10ed784 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 9 May 2025 14:23:40 +0200 Subject: [PATCH 027/189] adapt era5cosmo loader to trim_edge 19 --- src/hirad/datasets/era5_cosmo.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 8b0d60f..f8835d1 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -14,7 +14,7 @@ def __init__(self, dataset_path: str): self._dataset_path = dataset_path self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') - self._info_path = os.path.join(dataset_path, 'old/info') + self._info_path = os.path.join(dataset_path, 'info') # load file list (each file is one date-time state) self._file_list = os.listdir(self._cosmo_path) @@ -48,7 +48,8 @@ def __getitem__(self, idx): # squeeze the ensemble dimesnsion # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) - orig_shape = [350,542] #TODO currently padding to be divisible by 16 + # orig_shape = [350,542] #TODO currently padding to be divisible by 16 + orig_shape = self.image_shape() era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ .squeeze() \ .reshape(-1,*orig_shape), @@ -61,9 +62,12 @@ def __getitem__(self, idx): 1) cosmo_data = self.normalize_output(cosmo_data) # return samples - return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ - F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ + return torch.tensor(cosmo_data),\ + torch.tensor(era5_data),\ 0 + # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ + # F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ + # 0 def __len__(self): return len(self._file_list) From dca7ff446056ce2596248559e3db2b853f576cd4 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 12 May 2025 17:32:56 +0200 Subject: [PATCH 028/189] add inference --- src/hirad/conf/generate_era_cosmo.yaml | 20 ++ src/hirad/conf/generation/era_cosmo.yaml | 37 ++++ src/hirad/conf/sampler/deterministic.yaml | 4 + src/hirad/conf/sampler/stochastic.yaml | 3 + src/hirad/inference/generate.py | 253 +++++++++++++++++++++- src/hirad/utils/generate_utils.py | 18 +- 6 files changed, 319 insertions(+), 16 deletions(-) create mode 100644 src/hirad/conf/generate_era_cosmo.yaml create mode 100644 src/hirad/conf/generation/era_cosmo.yaml create mode 100644 src/hirad/conf/sampler/deterministic.yaml create mode 100644 src/hirad/conf/sampler/stochastic.yaml diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml new file mode 100644 index 0000000..03650e2 --- /dev/null +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -0,0 +1,20 @@ +hydra: + job: + chdir: true + name: generation + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + + # Dataset + - dataset/era_cosmo + + # Sampler + - sampler/stochastic + #- sampler/deterministic + + # Generation + - generation/era_cosmo + #- generation/patched_based \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml new file mode 100644 index 0000000..2e37a63 --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -0,0 +1,37 @@ +num_ensembles: 64 + # Number of ensembles to generate per input +seed_batch_size: 1 + # Size of the batched inference +inference_mode: regression + # Choose between "all" (regression + diffusion), "regression" or "diffusion" + # Patch size. Patch-based sampling will be utilized if these dimensions differ from + # img_shape_x and img_shape_y +overlap_pixels: 0 + # Number of overlapping pixels between adjacent patches +boundary_pixels: 0 + # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary + # artifact. +hr_mean_conditioning: False +sample_res: full + # Sampling resolution +times_range: null +times: + - 20160101-0000 + +perf: + force_fp16: false + # Whether to force fp16 precision for the model. If false, it'll use the precision + # specified upon training. + use_torch_compile: false + # whether to use torch.compile on the diffusion model + # this will make the first time stamp generation very slow due to compilation overheads + # but will significantly speed up subsequent inference runs + num_writer_workers: 1 + # number of workers to use for writing file + # To support multiple workers a threadsafe version of the netCDF library must be used + +io: + res_ckpt_path: diffusion_checkpoint + # Checkpoint filename for the diffusion model + reg_ckpt_path: regression_checkpoint + # Checkpoint filename for the mean predictor model \ No newline at end of file diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml new file mode 100644 index 0000000..35bc0f6 --- /dev/null +++ b/src/hirad/conf/sampler/deterministic.yaml @@ -0,0 +1,4 @@ +type: deterministic +num_steps: 9 + # Number of denoising steps +solver: euler \ No newline at end of file diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml new file mode 100644 index 0000000..5e8fa88 --- /dev/null +++ b/src/hirad/conf/sampler/stochastic.yaml @@ -0,0 +1,3 @@ +type: stochastic +boundary_pix: 2 +overlap_pix: 4 \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 6e5273a..adb882e 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -10,6 +10,9 @@ from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor from functools import partial + +from matplotlib import pyplot as plt +import cartopy.crs as ccrs from einops import rearrange from torch.distributed import gather @@ -57,7 +60,7 @@ def main(cfg: DictConfig) -> None: all_batches = torch.as_tensor(seeds).tensor_split(num_batches) rank_batches = all_batches[dist.rank :: dist.world_size] - # Synchronize + # Synchronize if dist.world_size > 1: torch.distributed.barrier() @@ -65,7 +68,7 @@ def main(cfg: DictConfig) -> None: if cfg.generation.times_range and cfg.generation.times: raise ValueError("Either times_range or times must be provided, but not both") if cfg.generation.times_range: - times = get_time_from_range(cfg.generation.times_range) #TODO check what time formats we are using and adapt + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt else: times = cfg.generation.times @@ -110,7 +113,7 @@ def main(cfg: DictConfig) -> None: # Load diffusion network, move to device, change precision if load_net_res: res_ckpt_path = cfg.generation.io.res_ckpt_path - logger0.info(f'Loading residual network from "{res_ckpt_path}"...') + logger0.info(f'Loading correction network from "{res_ckpt_path}"...') diffusion_model_args_path = os.path.join(res_ckpt_path, 'model_args.json') if not os.path.isfile(diffusion_model_args_path): @@ -135,7 +138,7 @@ def main(cfg: DictConfig) -> None: # load regression network, move to device, change precision if load_net_reg: reg_ckpt_path = cfg.generation.io.reg_ckpt_path - logger0.info(f'Loading network from "{reg_ckpt_path}"...') + logger0.info(f'Loading regression network from "{reg_ckpt_path}"...') regression_model_args_path = os.path.join(reg_ckpt_path, 'model_args.json') @@ -144,7 +147,7 @@ def main(cfg: DictConfig) -> None: with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) - net_reg = EDMPrecond(**regression_model_args) + net_reg = UNet(**regression_model_args) _ = load_checkpoint( path=reg_ckpt_path, @@ -156,4 +159,242 @@ def main(cfg: DictConfig) -> None: if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True else: - net_reg = None \ No newline at end of file + net_reg = None + + # Reset since we are using a different mode. + if cfg.generation.perf.use_torch_compile: + torch._dynamo.reset() + # Only compile residual network + # Overhead of compiling regression network outweights any benefits + if net_res: + net_res = torch.compile(net_res, mode="reduce-overhead") + + # Partially instantiate the sampler based on the configs + if cfg.sampler.type == "deterministic": + if cfg.generation.hr_mean_conditioning: + raise NotImplementedError( + "High-res mean conditioning is not yet implemented for the deterministic sampler" + ) + sampler_fn = partial( + deterministic_sampler, + num_steps=cfg.sampler.num_steps, + # num_ensembles=cfg.generation.num_ensembles, + solver=cfg.sampler.solver, + ) + elif cfg.sampler.type == "stochastic": + sampler_fn = partial( + stochastic_sampler, + img_shape=img_shape[1], + patch_shape=patch_shape[1], + boundary_pix=cfg.sampler.boundary_pix, + overlap_pix=cfg.sampler.overlap_pix, + ) + else: + raise ValueError(f"Unknown sampling method {cfg.sampling.type}") + + + # Main generation definition + def generate_fn(image_lr, lead_time_label): + img_shape_y, img_shape_x = img_shape + with nvtx.annotate("generate_fn", color="green"): + if cfg.generation.sample_res == "full": + image_lr_patch = image_lr + else: + torch.cuda.nvtx.range_push("rearrange") + image_lr_patch = rearrange( + image_lr, + "b c (h1 h) (w1 w) -> (b h1 w1) c h w", + h1=img_shape_y // patch_shape[0], + w1=img_shape_x // patch_shape[1], + ) + torch.cuda.nvtx.range_pop() + image_lr_patch = image_lr_patch.to(memory_format=torch.channels_last) + + if net_reg: + with nvtx.annotate("regression_model", color="yellow"): + image_reg = regression_step( + net=net_reg, + img_lr=image_lr_patch, + latents_shape=( + cfg.generation.seed_batch_size, + img_out_channels, + img_shape[0], + img_shape[1], + ), + lead_time_label=lead_time_label, + ) + if net_res: + if cfg.generation.hr_mean_conditioning: + mean_hr = image_reg[0:1] + else: + mean_hr = None + with nvtx.annotate("diffusion model", color="purple"): + image_res = diffusion_step( + net=net_res, + sampler_fn=sampler_fn, + seed_batch_size=cfg.generation.seed_batch_size, + img_shape=img_shape, + img_out_channels=img_out_channels, + rank_batches=rank_batches, + img_lr=image_lr_patch.expand( + cfg.generation.seed_batch_size, -1, -1, -1 + ).to(memory_format=torch.channels_last), + rank=dist.rank, + device=device, + hr_mean=mean_hr, + lead_time_label=lead_time_label, + ) + if cfg.generation.inference_mode == "regression": + image_out = image_reg + elif cfg.generation.inference_mode == "diffusion": + image_out = image_res + else: + image_out = image_reg + image_res + + if cfg.generation.sample_res != "full": + image_out = rearrange( + image_out, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h1=img_shape_y // patch_shape[0], + w1=img_shape_x // patch_shape[1], + ) + # Gather tensors on rank 0 + if dist.world_size > 1: + if dist.rank == 0: + gathered_tensors = [ + torch.zeros_like( + image_out, dtype=image_out.dtype, device=image_out.device + ) + for _ in range(dist.world_size) + ] + else: + gathered_tensors = None + + torch.distributed.barrier() + gather( + image_out, + gather_list=gathered_tensors if dist.rank == 0 else None, + dst=0, + ) + + if dist.rank == 0: + return torch.cat(gathered_tensors) + else: + return None + else: + return image_out + + # generate images + output_path = getattr(cfg.generation.io, "output_path", "./outputs") + logger0.info(f"Generating images, saving results to {output_path}...") + batch_size = 1 + warmup_steps = min(len(times) - 1, 2) + # Generates model predictions from the input data using the specified + # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates + # through the dataset using a data loader, computes predictions, and saves them along + # with associated metadata. + + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + + data_loader = torch.utils.data.DataLoader( + dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True + ) + time_index = -1 + if dist.rank == 0: + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + times = dataset.time() + for image_tar, image_lr, index, *lead_time_label in iter(data_loader): + time_index += 1 + if dist.rank == 0: + logger0.info(f"starting index: {time_index}") + + if time_index == warmup_steps: + start.record() + + # continue + if lead_time_label: + lead_time_label = lead_time_label[0].to(dist.device).contiguous() + else: + lead_time_label = None + image_lr = ( + image_lr.to(device=device) + .to(torch.float32) + .to(memory_format=torch.channels_last) + ) + image_tar = image_tar.to(device=device).to(torch.float32) + image_out = generate_fn(image_lr,lead_time_label) + if dist.rank == 0: + batch_size = image_out.shape[0] + # write out data in a seperate thread so we don't hold up inferencing + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + dataset, + image_out.cpu(), + image_tar.cpu(), + image_lr.cpu(), + ) + ) + end.record() + end.synchronize() + elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s + timed_steps = time_index + 1 - warmup_steps + if dist.rank == 0: + average_time_per_batch_element = elapsed_time / timed_steps / batch_size + logger.info( + f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" + ) + logger.info( + f"Average time per batch element = {average_time_per_batch_element} s" + ) + + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() + + if dist.rank == 0: + f.close() + logger0.info("Generation Completed.") + +def save_images(output_path, dataset, image_pred, image_hr, image_lr): + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + image_pred = np.flip(dataset.denormalize_output(image_pred.numpy()),1).reshape(len(output_channels),-1) + image_hr = np.flip(dataset.denormalize_output(image_hr.numpy()),1).reshape(len(output_channels),-1) + image_lr = np.flip(dataset.denormalize_input(image_lr.numpy()),1).reshape(len(input_channels),-1) + for idx, channel in enumerate(output_channels): + input_channel_idx = input_channels.index(channel) + _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{channel.name}-lr.jpg')) + _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{channel.name}-hr.jpg')) + _plot_projection(longitudes,latitudes,image_pred[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) + +def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + + """Plot observed or interpolated data in a scatter plot.""" + # TODO: Refactor this somehow, it's not really generalizing well across variables. + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="K", orientation="horizontal") + plt.savefig(filename) + plt.close('all') + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py index b99852f..43f83b6 100644 --- a/src/hirad/utils/generate_utils.py +++ b/src/hirad/utils/generate_utils.py @@ -8,17 +8,15 @@ def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): Get a dataset and sampler for generation. """ (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) - if has_lead_time: - plot_times = times - else: - plot_times = [ - convert_datetime_to_cftime( - datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") - ) - for time in times - ] + # if has_lead_time: + # plot_times = times + # else: + # plot_times = [ + # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + # for time in times + # ] all_times = dataset.time() - time_indices = [all_times.index(t) for t in plot_times] + time_indices = [all_times.index(t) for t in times] sampler = time_indices return dataset, sampler \ No newline at end of file From 8ba0c5a21e21a09ff1b248221f01378e4bc30754 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 12 May 2025 18:38:19 +0200 Subject: [PATCH 029/189] Plot absolute error onto a projection for a given date --- src/hirad/eval/metrics.py | 21 ++++++++++++++++++++ src/hirad/eval/plotting.py | 16 ++++++++++++++++ src/hirad/eval/run_scoring.py | 36 +++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 src/hirad/eval/metrics.py create mode 100644 src/hirad/eval/plotting.py create mode 100644 src/hirad/eval/run_scoring.py diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py new file mode 100644 index 0000000..c8fda43 --- /dev/null +++ b/src/hirad/eval/metrics.py @@ -0,0 +1,21 @@ +import numpy as np +import torch + + +# set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) + +# input will be a 2D tensor of values with the COSMO lat/lon. + +# Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py + +def absolute_error(pred, target): + return torch.abs(pred-target) + +def compute_mae(pred, target): + # Exclude any target NaNs (not expected, but precautionary) + mask = ~np.isnan(target) + pred = pred[:, mask] + target = target[mask] + + return torch.mean(absolute_error(pred, target)) + diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py new file mode 100644 index 0000000..141b5c9 --- /dev/null +++ b/src/hirad/eval/plotting.py @@ -0,0 +1,16 @@ +import logging + +import cartopy.crs as ccrs +import matplotlib.pyplot as plt +import numpy as np + +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str): + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + logging.info(f'plotting values to {filename}') + p = ax.scatter(x=longitudes, y=latitudes, c=values) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="absolute error", orientation="horizontal") + plt.savefig(filename) + plt.close('all') diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py new file mode 100644 index 0000000..57ca69f --- /dev/null +++ b/src/hirad/eval/run_scoring.py @@ -0,0 +1,36 @@ +import os +import sys + +import metrics +import plotting +import torch +import yaml + + +def main(): + if len(sys.argv) < 4: + raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [date]') + + input_directory = sys.argv[1] + predictions_directory = sys.argv[2] + date = sys.argv[3] + + target = torch.load(os.path.join(input_directory, 'cosmo', date), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', date), weights_only=False) + #prediction_file = torch.load(os.path.join(predictions_directory, date), weights_only=False) + prediction = torch.load(os.path.join(input_directory, 'cosmo', '20160101-0000'), weights_only=False) + lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) + + # Reshape grides to be the same as prediction + #target = target.squeeze().reshape(-1,*prediction.shape), + target = torch.from_numpy(target) + prediction = torch.from_numpy(prediction) + #prediction = prediction.squeeze().reshape(-1,*prediction.shape) + latitudes = lat_lon[:,0] #.squeeze().reshape(-1,*prediction.shape) + longitudes = lat_lon[:,1] #squeeze().reshape(-1,*prediction.shape) + + errors = metrics.absolute_error(prediction[0,:,:], target[0,:,:]) + plotting.plot_error_projection(errors, latitudes, longitudes, os.path.join('plots/errors/', date)) + +if __name__ == "__main__": + main() \ No newline at end of file From c6e632a1bdc52d9f932baf15beab93abc187d94f Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 12 May 2025 18:39:35 +0200 Subject: [PATCH 030/189] Adjust how reshaping is done before error calcs --- src/hirad/eval/run_scoring.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index 57ca69f..ce2c343 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -17,18 +17,25 @@ def main(): target = torch.load(os.path.join(input_directory, 'cosmo', date), weights_only=False) baseline = torch.load(os.path.join(input_directory, 'era-interpolated', date), weights_only=False) - #prediction_file = torch.load(os.path.join(predictions_directory, date), weights_only=False) - prediction = torch.load(os.path.join(input_directory, 'cosmo', '20160101-0000'), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, date), weights_only=False) lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) - # Reshape grides to be the same as prediction - #target = target.squeeze().reshape(-1,*prediction.shape), + with open(os.path.join(input_directory), 'info', 'cosmo.yaml') as cosmo_file: + cosmo_config = yaml.safe_load(cosmo_file) + channels = cosmo_config['select'] + + # Reshape predictions, if necessary + # target is shape [channels, ensembles, points] + # prediction is shape [channels, ensembles, x, y] + prediction = prediction.reshape(*target.shape) + + latitudes = lat_lon[:,0] + longitudes = lat_lon[:,1] + + # convert to torch target = torch.from_numpy(target) prediction = torch.from_numpy(prediction) - #prediction = prediction.squeeze().reshape(-1,*prediction.shape) - latitudes = lat_lon[:,0] #.squeeze().reshape(-1,*prediction.shape) - longitudes = lat_lon[:,1] #squeeze().reshape(-1,*prediction.shape) - + errors = metrics.absolute_error(prediction[0,:,:], target[0,:,:]) plotting.plot_error_projection(errors, latitudes, longitudes, os.path.join('plots/errors/', date)) From d21ec35ab055d313d7128163a941b909a3b44745 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 12 May 2025 18:58:47 +0200 Subject: [PATCH 031/189] plot for all channels, and against baseline --- src/hirad/eval/run_scoring.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index ce2c343..0cc3c1d 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -20,9 +20,13 @@ def main(): prediction = torch.load(os.path.join(predictions_directory, date), weights_only=False) lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) - with open(os.path.join(input_directory), 'info', 'cosmo.yaml') as cosmo_file: + with open(os.path.join(input_directory, 'info', 'cosmo.yaml')) as cosmo_file: cosmo_config = yaml.safe_load(cosmo_file) - channels = cosmo_config['select'] + target_channels = cosmo_config['select'] + + with open(os.path.join(input_directory, 'info', 'era.yaml')) as era_file: + era_config = yaml.safe_load(era_file) + input_channels = era_config['select'] # Reshape predictions, if necessary # target is shape [channels, ensembles, points] @@ -34,10 +38,18 @@ def main(): # convert to torch target = torch.from_numpy(target) + baseline = torch.from_numpy(baseline) prediction = torch.from_numpy(prediction) - errors = metrics.absolute_error(prediction[0,:,:], target[0,:,:]) - plotting.plot_error_projection(errors, latitudes, longitudes, os.path.join('plots/errors/', date)) + # plot baseline error + for t_c in range(len(target_channels)): + b_c = input_channels.index(target_channels[t_c]) + if b_c > -1: + baseline_errors = metrics.absolute_error(baseline[b_c,:,:], target[t_c,:,:]) + plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) + prediction_errors = metrics.absolute_error(prediction[t_c,:,:], target[t_c,:,:]) + plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + if __name__ == "__main__": main() \ No newline at end of file From 5b32566bf13bc19ae4de8c0f957ac0f44c6ce892 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 13 May 2025 09:14:41 +0200 Subject: [PATCH 032/189] Add MAE output --- src/hirad/eval/metrics.py | 6 ++++-- src/hirad/eval/run_scoring.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index c8fda43..133cc21 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -8,7 +8,7 @@ # Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py -def absolute_error(pred, target): +def absolute_error(pred, target) -> tuple[float, np.ndarray]: return torch.abs(pred-target) def compute_mae(pred, target): @@ -17,5 +17,7 @@ def compute_mae(pred, target): pred = pred[:, mask] target = target[mask] - return torch.mean(absolute_error(pred, target)) + ae = absolute_error(pred, target) + + return torch.mean(absolute_error(pred, target)), ae diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index 0cc3c1d..fee984e 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -41,15 +41,15 @@ def main(): baseline = torch.from_numpy(baseline) prediction = torch.from_numpy(prediction) - # plot baseline error + # plot errors for t_c in range(len(target_channels)): b_c = input_channels.index(target_channels[t_c]) if b_c > -1: - baseline_errors = metrics.absolute_error(baseline[b_c,:,:], target[t_c,:,:]) + baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - prediction_errors = metrics.absolute_error(prediction[t_c,:,:], target[t_c,:,:]) + prediction_mae, prediction_errors = metrics.compute_mae(prediction[t_c,:,:], target[t_c,:,:]) plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - + print(f'baseline MAE={baseline_mae}, prediction MAE={prediction_mae}') if __name__ == "__main__": main() \ No newline at end of file From 60c8affbc46d240940d84255c9ab3b32e82cfc8b Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 13 May 2025 09:23:15 +0200 Subject: [PATCH 033/189] Fix indexing error --- src/hirad/eval/metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index 133cc21..e6e1afb 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -13,8 +13,9 @@ def absolute_error(pred, target) -> tuple[float, np.ndarray]: def compute_mae(pred, target): # Exclude any target NaNs (not expected, but precautionary) + # TODO: Fix the deprecated warning (index with dtype torch.bool instead of torch.uint8) mask = ~np.isnan(target) - pred = pred[:, mask] + pred = pred[mask] target = target[mask] ae = absolute_error(pred, target) From 5cc42e9c0de8b04f2f865c90a12bf928443e9ad5 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 14 May 2025 10:08:43 +0200 Subject: [PATCH 034/189] Try adding spectral graph --- .gitignore | 17 ++++++++++++++++- src/hirad/eval/plotting.py | 7 +++++++ src/hirad/eval/run_scoring.py | 19 +++++++++++++++---- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index c514c5d..dee6b07 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,19 @@ poetry.toml .ruff_cache/ # LSP config files -pyrightconfig.json \ No newline at end of file +pyrightconfig.json + +# output files +*.out +*.torch +plots/* +*.npz + +# conda +.conda/* + +# temp +temp.* + +# local script +interpolate.sh \ No newline at end of file diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 141b5c9..262109f 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -14,3 +14,10 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.colorbar(p, label="absolute error", orientation="horizontal") plt.savefig(filename) plt.close('all') + +def plot_power_spectrum(x, filename): + fig = plt.figure() + plt.psd(x) + logging.info(f'plotting values to {filename}') + plt.savefig(filename) + plt.close('all') \ No newline at end of file diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index fee984e..df37a9a 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -31,7 +31,8 @@ def main(): # Reshape predictions, if necessary # target is shape [channels, ensembles, points] # prediction is shape [channels, ensembles, x, y] - prediction = prediction.reshape(*target.shape) + prediction_1d = prediction.reshape(*target.shape) + prediction_2d = prediction.reshape(prediction.shape[0],352,544) latitudes = lat_lon[:,0] longitudes = lat_lon[:,1] @@ -39,7 +40,7 @@ def main(): # convert to torch target = torch.from_numpy(target) baseline = torch.from_numpy(baseline) - prediction = torch.from_numpy(prediction) + prediction_1d = torch.from_numpy(prediction_1d) # plot errors for t_c in range(len(target_channels)): @@ -47,9 +48,19 @@ def main(): if b_c > -1: baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - prediction_mae, prediction_errors = metrics.compute_mae(prediction[t_c,:,:], target[t_c,:,:]) - plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) + prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) + plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) print(f'baseline MAE={baseline_mae}, prediction MAE={prediction_mae}') + # Plot power spectra + freq, power = metrics.compute_power_spectrum(prediction, 1) + plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') + plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') + + + if __name__ == "__main__": main() \ No newline at end of file From 5b77dbebd984a35e3797ed40e97c6cdb1e03669c Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 15 May 2025 15:38:46 +0200 Subject: [PATCH 035/189] fix inference for diffusion --- pyproject.toml | 10 ++- src/hirad/conf/dataset/era_cosmo.yaml | 2 +- src/hirad/conf/generate_era_cosmo.yaml | 6 +- src/hirad/conf/generation/era_cosmo.yaml | 7 ++- .../conf/training/era_cosmo_diffusion.yaml | 10 +-- .../conf/training/era_cosmo_regression.yaml | 11 ++-- src/hirad/datasets/era5_cosmo.py | 8 +-- src/hirad/generate.sh | 51 +++++++++++++++ src/hirad/inference/generate.py | 62 ++++++++++++------- src/hirad/models/layers.py | 3 +- src/hirad/train.sh | 51 +++++++++++++++ src/hirad/training/train.py | 2 +- src/hirad/utils/inference_utils.py | 13 ++-- src/hirad/utils/stochastic_sampler.py | 60 ++++++++++-------- 14 files changed, 218 insertions(+), 78 deletions(-) create mode 100644 src/hirad/generate.sh create mode 100644 src/hirad/train.sh diff --git a/pyproject.toml b/pyproject.toml index b2fa56c..1477899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,15 @@ requires-python = ">=3.12" license = {file = "LICENSE"} dependencies = [ - "torch>=2.6.0" + "cartopy>=0.24.1", + "cftime>=1.6.4", + "hydra-core>=1.3.2", + "matplotlib>=3.10.1", + "omegaconf>=2.3.0", + "tensorboard>=2.19.0", + "termcolor>=3.1.0", + "torchinfo>=1.8.0", + "treelib>=1.7.1" ] [tool.setuptools] diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index 854b775..63d7361 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,2 +1,2 @@ type: era5_cosmo -dataset_path: /store_new/mch/msopr/hirad-gen/basic-torch \ No newline at end of file +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_overfit \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index 03650e2..5d7649d 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -1,13 +1,13 @@ hydra: job: chdir: true - name: generation + name: generation_full run: - dir: ./outputs/${hydra:job.name} + dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: - + - _self_ # Dataset - dataset/era_cosmo diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 2e37a63..5179520 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -31,7 +31,8 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: diffusion_checkpoint + res_ckpt_path: null # Checkpoint filename for the diffusion model - reg_ckpt_path: regression_checkpoint - # Checkpoint filename for the mean predictor model \ No newline at end of file + reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + # Checkpoint filename for the mean predictor model + output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index b61603a..b06ec61 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,8 +1,8 @@ # Hyperparameters hp: - training_duration: 128 + training_duration: 16 # Training duration based on the number of processed samples - total_batch_size: 16 + total_batch_size: 4 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU @@ -20,14 +20,14 @@ perf: fp_optimizations: amp-bf16 # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 4 + dataloader_workers: 8 # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint # I/O io: - regression_checkpoint_path: /scratch/mch/pstamenk/output/regression/checkpoints_regression + regression_checkpoint_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression # Where to load the regression checkpoint print_progress_freq: 32 # How often to print progress @@ -38,4 +38,4 @@ io: validation_steps: 10 # how many loss evaluations are used to compute the validation loss per checkpoint # how many loss evaluations are used to compute the validation loss per checkpoint - checkpoint_dir: /scratch/mch/pstamenk/output/diffusion \ No newline at end of file + checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 7c443f0..76bdc4e 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,12 +1,13 @@ # Hyperparameters hp: - training_duration: 16 + training_duration: 8 # Training duration based on the number of processed samples - total_batch_size: 16 + total_batch_size: 4 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU - lr: 0.0002 + lr: 0.001 + #0.0002 # Learning rate grad_clip_threshold: null # no gradient clipping for defualt non-patch-based training @@ -27,7 +28,7 @@ perf: # I/O io: - print_progress_freq: 32 + print_progress_freq: 128 # How often to print progress save_checkpoint_freq: 5000 # How often to save the checkpoints, measured in number of processed samples @@ -35,4 +36,4 @@ io: # how often to record the validation loss, measured in number of processed samples validation_steps: 10 # how many loss evaluations are used to compute the validation loss per checkpoint - checkpoint_dir: /scratch/mch/pstamenk/output/regression \ No newline at end of file + checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index f8835d1..674dbf0 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -75,14 +75,14 @@ def __len__(self): def longitude(self) -> np.ndarray: """Get longitude values from the dataset.""" - lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) - return lon_lat[:,0] + lat_lon = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lat_lon[:,1] def latitude(self) -> np.ndarray: """Get latitude values from the dataset.""" - lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) - return lon_lat[:,1] + lat_lon = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lat_lon[:,0] def input_channels(self) -> List[ChannelMetadata]: diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh new file mode 100644 index 0000000..87c8979 --- /dev/null +++ b/src/hirad/generate.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index adb882e..5558a20 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -11,14 +11,14 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial -from matplotlib import pyplot as plt import cartopy.crs as ccrs +from matplotlib import pyplot as plt from einops import rearrange from torch.distributed import gather from hydra.utils import to_absolute_path -from hirad.models import EDMPrecond, UNet +from hirad.models import EDMPrecondSR, UNet from hirad.utils.stochastic_sampler import stochastic_sampler from hirad.utils.deterministic_sampler import deterministic_sampler from hirad.utils.inference_utils import ( @@ -36,12 +36,12 @@ from hirad.utils.train_helpers import set_patch_shape -@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate") +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig) -> None: """Generate random dowscaled atmospheric states using the techniques described in the paper "Elucidating the Design Space of Diffusion-Based Generative Models". """ - + torch.backends.cudnn.enabled = False # Initialize distributed manager DistributedManager.initialize() dist = DistributedManager() @@ -50,7 +50,7 @@ def main(cfg: DictConfig) -> None: # Initialize logger logger = PythonLogger("generate") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) - logger.file_logging("generate.log") + # logger.file_logging("generate.log") # Handle the batch size seeds = list(np.arange(cfg.generation.num_ensembles)) @@ -121,7 +121,7 @@ def main(cfg: DictConfig) -> None: with open(diffusion_model_args_path, 'r') as f: diffusion_model_args = json.load(f) - net_res = EDMPrecond(**diffusion_model_args) + net_res = EDMPrecondSR(**diffusion_model_args) _ = load_checkpoint( path=res_ckpt_path, @@ -129,7 +129,8 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) + #TODO fix to use channels_last which is optimal for H100 + net_res = net_res.eval().to(device)#.to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True else: @@ -155,7 +156,7 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) + net_reg = net_reg.eval().to(device)#.to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True else: @@ -184,8 +185,9 @@ def main(cfg: DictConfig) -> None: elif cfg.sampler.type == "stochastic": sampler_fn = partial( stochastic_sampler, - img_shape=img_shape[1], - patch_shape=patch_shape[1], + img_shape=img_shape, + patch_shape_x=patch_shape[0], + patch_shape_y=patch_shape[1], boundary_pix=cfg.sampler.boundary_pix, overlap_pix=cfg.sampler.overlap_pix, ) @@ -194,7 +196,7 @@ def main(cfg: DictConfig) -> None: # Main generation definition - def generate_fn(image_lr, lead_time_label): + def generate_fn(image_lr, labels, lead_time_label): img_shape_y, img_shape_x = img_shape with nvtx.annotate("generate_fn", color="green"): if cfg.generation.sample_res == "full": @@ -208,13 +210,14 @@ def generate_fn(image_lr, lead_time_label): w1=img_shape_x // patch_shape[1], ) torch.cuda.nvtx.range_pop() - image_lr_patch = image_lr_patch.to(memory_format=torch.channels_last) + image_lr_patch = image_lr_patch #.to(memory_format=torch.channels_last) if net_reg: with nvtx.annotate("regression_model", color="yellow"): image_reg = regression_step( net=net_reg, img_lr=image_lr_patch, + labels=labels, latents_shape=( cfg.generation.seed_batch_size, img_out_channels, @@ -238,7 +241,7 @@ def generate_fn(image_lr, lead_time_label): rank_batches=rank_batches, img_lr=image_lr_patch.expand( cfg.generation.seed_batch_size, -1, -1, -1 - ).to(memory_format=torch.channels_last), + ), #.to(memory_format=torch.channels_last), rank=dist.rank, device=device, hr_mean=mean_hr, @@ -282,7 +285,10 @@ def generate_fn(image_lr, lead_time_label): else: return None else: - return image_out + #TODO do this for multi-gpu setting above too + if cfg.generation.inference_mode != "regression": + return image_out, image_reg + return image_out, None # generate images output_path = getattr(cfg.generation.io, "output_path", "./outputs") @@ -311,7 +317,7 @@ def generate_fn(image_lr, lead_time_label): end = torch.cuda.Event(enable_timing=True) times = dataset.time() - for image_tar, image_lr, index, *lead_time_label in iter(data_loader): + for image_tar, image_lr, labels, *lead_time_label in iter(data_loader): time_index += 1 if dist.rank == 0: logger0.info(f"starting index: {time_index}") @@ -327,10 +333,11 @@ def generate_fn(image_lr, lead_time_label): image_lr = ( image_lr.to(device=device) .to(torch.float32) - .to(memory_format=torch.channels_last) + #.to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - image_out = generate_fn(image_lr,lead_time_label) + labels = labels.to(device).to(torch.float32).contiguous() + image_out, image_reg = generate_fn(image_lr,labels,lead_time_label) if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing @@ -342,6 +349,7 @@ def generate_fn(image_lr, lead_time_label): image_out.cpu(), image_tar.cpu(), image_lr.cpu(), + image_reg.cpu(), ) ) end.record() @@ -368,19 +376,29 @@ def generate_fn(image_lr, lead_time_label): f.close() logger0.info("Generation Completed.") -def save_images(output_path, dataset, image_pred, image_hr, image_lr): +def save_images(output_path, dataset, image_pred, image_hr, image_lr, mean_pred): longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() - image_pred = np.flip(dataset.denormalize_output(image_pred.numpy()),1).reshape(len(output_channels),-1) - image_hr = np.flip(dataset.denormalize_output(image_hr.numpy()),1).reshape(len(output_channels),-1) - image_lr = np.flip(dataset.denormalize_input(image_lr.numpy()),1).reshape(len(input_channels),-1) + image_pred = image_pred.numpy() + image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) + image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) + image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) + os.makedirs(output_path, exist_ok=True) for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{channel.name}-lr.jpg')) _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{channel.name}-hr.jpg')) - _plot_projection(longitudes,latitudes,image_pred[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) + _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) + _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-0.jpg')) + _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-mid.jpg')) + if mean_pred is not None: + _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{channel.name}-mean-pred.jpg')) def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index ddb23b6..8612da7 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -221,6 +221,8 @@ def forward(self, x): padding=f_pad, ) if w is not None: + #TODO during inference, model breaks here for some reason + # current fix is to disable torch.backends.cudnn.enabled = False x = torch.nn.functional.conv2d(x, w, padding=w_pad) if b is not None: x = x.add_(b.reshape(1, -1, 1, 1)) @@ -473,7 +475,6 @@ def forward(self, x, emb): torch.cuda.nvtx.range_push("UNetBlock") orig = x x = self.conv0(silu(self.norm0(x))) - params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) if self.adaptive_scale: scale, shift = params.chunk(chunks=2, dim=1) diff --git a/src/hirad/train.sh b/src/hirad/train.sh new file mode 100644 index 0000000..a31cec0 --- /dev/null +++ b/src/hirad/train.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 9ce619d..37e6110 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -213,7 +213,7 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) # TODO write summry from rank=0 possibly - # summary(model, input_size=[(4,img_out_channels,*img_shape),(4,img_in_channels,*img_shape),(4,1),(4,1)]) + # summary(model, input_size=[(1,img_out_channels,*img_shape),(1,img_in_channels,*img_shape),(1,1)]) if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index b158ec0..4831bdd 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -31,6 +31,7 @@ def regression_step( net: torch.nn.Module, img_lr: torch.Tensor, + labels: torch.Tensor, latents_shape: torch.Size, lead_time_label: torch.Tensor = None, ) -> torch.Tensor: @@ -50,15 +51,15 @@ def regression_step( torch.Tensor: Predicted output at the next time step. """ # Create a tensor of zeros with the given shape and move it to the appropriate device - x_hat = torch.zeros(latents_shape, dtype=torch.float64, device=net.device) - t_hat = torch.tensor(1.0, dtype=torch.float64, device=net.device) + x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) + t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device).reshape((1,1,1,1)) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat[0:1], img_lr, t_hat, lead_time_label=lead_time_label) + x = net(x_hat, img_lr, t_hat, labels, lead_time_label=lead_time_label) else: - x = net(x_hat[0:1], img_lr, t_hat) + x = net(x_hat, img_lr, t_hat, labels) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -100,7 +101,7 @@ def diffusion_step( # TODO generalize the module and add defaults torch.Tensor: Generated images concatenated across batches. """ - img_lr = img_lr.to(memory_format=torch.channels_last) + img_lr = img_lr #.to(memory_format=torch.channels_last) # Handling of the high-res mean additional_args = {} @@ -128,7 +129,7 @@ def diffusion_step( # TODO generalize the module and add defaults img_shape[1], ], device=device, - ).to(memory_format=torch.channels_last) + )#.to(memory_format=torch.channels_last) with torch.inference_mode(): images = sampler_fn( diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py index ddcf9cc..ac5c13b 100644 --- a/src/hirad/utils/stochastic_sampler.py +++ b/src/hirad/utils/stochastic_sampler.py @@ -292,8 +292,9 @@ def stochastic_sampler( img_lr: Tensor, class_labels: Optional[Tensor] = None, randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - img_shape: int = 448, - patch_shape: int = 448, + img_shape: tuple[int,int] = (448,448), + patch_shape_x: int = 448, + patch_shape_y: int = 448, overlap_pix: int = 4, boundary_pix: int = 2, mean_hr: Optional[Tensor] = None, @@ -360,12 +361,13 @@ def stochastic_sampler( "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." sigma_min = max(sigma_min, net.sigma_min) sigma_max = min(sigma_max, net.sigma_max) - if isinstance(img_shape, tuple): - img_shape_y, img_shape_x = img_shape - else: - img_shape_x = img_shape_y = img_shape - if patch_shape > img_shape_x or patch_shape > img_shape_y: - patch_shape = min(img_shape_x, img_shape_y) + # if isinstance(img_shape, tuple): + # img_shape_y, img_shape_x = img_shape + # else: + # img_shape_x = img_shape_y = img_shape + img_shape_x, img_shape_y = img_shape + patch_shape_x = min(img_shape_x, patch_shape_x) + patch_shape_y = min(img_shape_y, patch_shape_y) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) @@ -394,16 +396,16 @@ def stochastic_sampler( global_index = None # input and position padding + patching - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: input_interp = torch.nn.functional.interpolate( - img_lr, (patch_shape, patch_shape), mode="bilinear" + img_lr, (patch_shape_x, patch_shape_y), mode="bilinear" ) x_lr = image_batching( x_lr, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -413,8 +415,8 @@ def stochastic_sampler( grid.float(), img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -433,13 +435,13 @@ def stochastic_sampler( # Euler step. Perform patching operation on score tensor if patch-based generation is used # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: x_hat_batch = image_batching( x_hat, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -461,6 +463,12 @@ def stochastic_sampler( global_index=global_index, ).to(torch.float64) else: + # print("Sizes") + # print(x_hat_batch.shape) + # print(x_lr.shape) + # print(t_hat) + # print(class_labels) + # print(global_index) denoised = net( x_hat_batch, x_lr, @@ -468,14 +476,14 @@ def stochastic_sampler( class_labels, global_index=global_index, ).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: denoised = image_fuse( denoised, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -485,13 +493,13 @@ def stochastic_sampler( # Apply 2nd order correction. if i < num_steps - 1: - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: x_next_batch = image_batching( x_next, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -517,13 +525,13 @@ def stochastic_sampler( class_labels, global_index=global_index, ).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: denoised = image_fuse( denoised, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, From 83716f45821e6f15f04422b140bc90f493a3be79 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 15 May 2025 15:55:32 +0200 Subject: [PATCH 036/189] clean up --- src/hirad/conf/generation/era_cosmo.yaml | 11 ++--- src/hirad/testrun.sh | 51 ------------------------ 2 files changed, 6 insertions(+), 56 deletions(-) delete mode 100644 src/hirad/testrun.sh diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 5179520..a0c5a40 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -2,7 +2,7 @@ num_ensembles: 64 # Number of ensembles to generate per input seed_batch_size: 1 # Size of the batched inference -inference_mode: regression +inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y @@ -11,7 +11,7 @@ overlap_pixels: 0 boundary_pixels: 0 # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary # artifact. -hr_mean_conditioning: False +hr_mean_conditioning: True sample_res: full # Sampling resolution times_range: null @@ -19,10 +19,10 @@ times: - 20160101-0000 perf: - force_fp16: false + force_fp16: False # Whether to force fp16 precision for the model. If false, it'll use the precision # specified upon training. - use_torch_compile: false + use_torch_compile: False # whether to use torch.compile on the diffusion model # this will make the first time stamp generation very slow due to compilation overheads # but will significantly speed up subsequent inference runs @@ -31,8 +31,9 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: null + res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_test/checkpoints_diffusion # Checkpoint filename for the diffusion model reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model output_path: ./images \ No newline at end of file diff --git a/src/hirad/testrun.sh b/src/hirad/testrun.sh deleted file mode 100644 index ac631c8..0000000 --- a/src/hirad/testrun.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="testrun" - -### HARDWARE ### -#SBATCH --partition=debug -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=00:30:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err - -### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a01 - -# Choose method to initialize dist in pythorch -export DISTRIBUTED_INITIALIZATION_METHOD=SLURM - -MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" -echo "Master node : $MASTER_ADDR" -# Get IP for hostname. -MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" -echo "Master address : $MASTER_ADDR" -export MASTER_ADDR -export MASTER_PORT=29500 -echo "Master port: $MASTER_PORT" - -# Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# Use SLURM_NTASKS (number of processes to be launched by torchrun) -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute threads per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS -echo "Physical cores: $PHYSICAL_CORES" -echo "Local processes: $LOCAL_PROCS" -echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" - -# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate - python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml -" \ No newline at end of file From 69f10ddc0f3964eebb116233489d0d55f6b34077 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 15 May 2025 16:45:29 +0200 Subject: [PATCH 037/189] Add power spectrum plots --- src/hirad/eval/metrics.py | 39 +++++++- src/hirad/eval/plotting.py | 11 ++- src/hirad/eval/run_scoring.py | 165 +++++++++++++++++++++++++++------- 3 files changed, 178 insertions(+), 37 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index e6e1afb..1170ea1 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -1,6 +1,10 @@ +import logging + import numpy as np import torch +from scipy.signal import periodogram + # set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) @@ -9,7 +13,7 @@ # Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py def absolute_error(pred, target) -> tuple[float, np.ndarray]: - return torch.abs(pred-target) + return np.abs(pred-target) def compute_mae(pred, target): # Exclude any target NaNs (not expected, but precautionary) @@ -20,5 +24,34 @@ def compute_mae(pred, target): ae = absolute_error(pred, target) - return torch.mean(absolute_error(pred, target)), ae - + # TODO, consider adding axis=-1 to choose what axis to average + return np.mean(absolute_error(pred, target)), ae + +def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default + """ + Compute the average power spectrum of a data array. + + This function calculates the power spectrum for each row of the input data and + then averages them to obtain the overall power spectrum, repeating until + dimensionality is reduced to 1D. + The power spectrum represents the distribution of signal power as a function of frequency. + + Parameters: + data (numpy.ndarray): Input data array. + d (float): Sampling interval (time between data points). + + Returns: + tuple: A tuple containing the frequency values and the average power spectrum. + - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum. + - power_spectra (numpy.ndarray): Average power spectrum of the input data. + """ + # Compute the power spectrum along the highest dimension for each row + freqs, power_spectra = periodogram(data, fs=1 / d, axis=-1) + logging.info(f'freqs.shape={freqs.shape}, power_spectra.shape={power_spectra.shape}') + + # Average along the first dimension + while power_spectra.ndim > 1: + power_spectra = power_spectra.mean(axis=0) + logging.info(f'power spectra shape={power_spectra.shape}') + + return freqs, power_spectra diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 262109f..1ca11c2 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -15,9 +15,16 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.savefig(filename) plt.close('all') -def plot_power_spectrum(x, filename): +def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): fig = plt.figure() - plt.psd(x) + for k in freqs.keys(): + plt.loglog(freqs[k], spec[k], label=k) + plt.title(channel_name) + plt.legend() + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + #plt.psd(x) logging.info(f'plotting values to {filename}') plt.savefig(filename) plt.close('all') \ No newline at end of file diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index df37a9a..da98a04 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -2,23 +2,22 @@ import sys import metrics +import numpy as np import plotting import torch import yaml +X = 352 # length of grid from N-S +Y = 544 # length of grid from E-W def main(): - if len(sys.argv) < 4: - raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [date]') + # TODO: Better arg parsing. + if len(sys.argv) < 3: + raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [output plot directory]') input_directory = sys.argv[1] predictions_directory = sys.argv[2] - date = sys.argv[3] - - target = torch.load(os.path.join(input_directory, 'cosmo', date), weights_only=False) - baseline = torch.load(os.path.join(input_directory, 'era-interpolated', date), weights_only=False) - prediction = torch.load(os.path.join(predictions_directory, date), weights_only=False) - lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) + output_directory = sys.argv[3] with open(os.path.join(input_directory, 'info', 'cosmo.yaml')) as cosmo_file: cosmo_config = yaml.safe_load(cosmo_file) @@ -28,37 +27,139 @@ def main(): era_config = yaml.safe_load(era_file) input_channels = era_config['select'] - # Reshape predictions, if necessary - # target is shape [channels, ensembles, points] - # prediction is shape [channels, ensembles, x, y] - prediction_1d = prediction.reshape(*target.shape) - prediction_2d = prediction.reshape(prediction.shape[0],352,544) - + lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) latitudes = lat_lon[:,0] longitudes = lat_lon[:,1] - - # convert to torch - target = torch.from_numpy(target) - baseline = torch.from_numpy(baseline) - prediction_1d = torch.from_numpy(prediction_1d) - # plot errors + # Iterate over all files in the ground truth directory + files = os.listdir(os.path.join(input_directory, 'cosmo')) + files = sorted(files) + + + # Plot power spectra + # TODO: Handle ensembles + prediction_tensor = np.ndarray([len(files), len(target_channels), X, Y]) + baseline_tensor = np.ndarray([len(files), len(input_channels), X, Y]) + target_tensor = np.ndarray([len(files), len(target_channels), X, Y]) + + for i in range(len(files)): + datetime = files[i] + target = torch.load(os.path.join(input_directory, 'cosmo', datetime), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', datetime), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, datetime), weights_only=False) + + # TODO: Handle ensembles + prediction_1d = prediction.reshape(prediction.shape[0], X*Y) + prediction_2d = prediction.reshape(prediction.shape[0], X, Y) + + baseline_1d = baseline.reshape(baseline.shape[0], X*Y) + baseline_2d = baseline.reshape(baseline.shape[0], X, Y) + + target_1d = target.reshape(target.shape[0], X*Y) + target_2d = target.reshape(target.shape[0], X, Y) + + baseline_tensor[i, :] = baseline_2d + prediction_tensor[i, :] = prediction_2d + target_tensor[i,:] = target_2d + + + # Calc spectra for t_c in range(len(target_channels)): b_c = input_channels.index(target_channels[t_c]) + freqs = {} + power = {} if b_c > -1: - baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) - plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) - prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) - plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) - plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) - print(f'baseline MAE={baseline_mae}, prediction MAE={prediction_mae}') + b_freq, b_power = metrics.average_power_spectrum(baseline_tensor[:,b_c,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = metrics.average_power_spectrum(target_tensor[:,t_c,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + #p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + #freqs['prediction'] = p_freq + #power['prediction'] = p_power + plotting.plot_power_spectra(freqs, power, target_channels[t_c], os.path.join(output_directory, 'spectra', target_channels[t_c] + '-alldates')) + + # store MAE as tensor of date:channel:ensembles:points + # TODO: Handle ensembles + baseline_absolute_error = np.ndarray([len(files),len(target_channels),1,X*Y]) + prediction_absolute_error = np.ndarray([len(files),len(target_channels),1,X*Y]) + + for i in range(len(files)): + datetime = files[i] + target = torch.load(os.path.join(input_directory, 'cosmo', datetime), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', datetime), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, datetime), weights_only=False) + + + prediction_1d = prediction.reshape(prediction.shape[0], 1, X*Y) + prediction_2d = prediction.reshape(prediction.shape[0], 1, X, Y) + + # Get MAE + for t_c in range(len(target_channels)): + b_c = input_channels.index(target_channels[t_c]) + if b_c > -1: + _, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) + baseline_absolute_error[i, t_c, :, :] = baseline_errors + #plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) + #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) + _, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + prediction_absolute_error[i, t_c, :, :] = prediction_errors + #plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) + #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) + + + print(f'baseline_absolute_error.shape={baseline_absolute_error.shape}, prediction_absolute_error.shape={prediction_absolute_error.shape}') + # Average errors over ensembles + baseline_mae = np.mean(baseline_absolute_error, axis=2) + prediction_mae = np.mean(prediction_absolute_error, axis=2) + + # Average errors over time + baseline_mae = np.mean(baseline_mae, axis=0) + prediction_mae = np.mean(prediction_mae, axis = 0) + + print(f'baseline mean error = {np.mean(baseline_mae, axis=-1)}') + print(f'prediction mean error = {np.mean(prediction_mae, axis=-1)}') + + # Plot the mean error onto the grid. + for t_c in range(len(target_channels)): + plotting.plot_error_projection(baseline_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'baseline-error' + target_channels[t_c] + '-' + 'average_over_time')) + plotting.plot_error_projection(prediction_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'prediction-error' + target_channels[t_c] + '-' + 'average_over_time')) + + + + + #for i in range(4): + # dates = ['20160101-0000', '20160115-0000', '20160201-0000', '20160215-0000'] + # pred = torch.load(os.path.join(predictions_directory, dates[i]), weights_only=False) + # base = torch.load(os.path.join(input_directory, 'era-interpolated', dates[i]), weights_only=False) + # pred_2d = pred.reshape(pred.shape[0],352,544) + # base_2d = base.reshape(baseline.shape[0],352,544) + # base_2d = np.transpose(base_2d, (0,-1,-2)) + # preds_tensor[i,:] = pred_2d + # baseline_tensor[i,:] = base_2d + #for t_c in range(len(target_channels)): + # freq, power = metrics.average_power_spectrum(baseline_tensor[:,t_c,:,:].squeeze(), 2) + # b_c = input_channels.index(target_channels[t_c]) + ## if b_c > -1: + # plotting.plot_power_spectrum(freq, power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + date)) + + + # plot errors + #for t_c in range(len(target_channels)): + # b_c = input_channels.index(target_channels[t_c]) + # if b_c > -1: + # baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) + # plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) + # #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) + # prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + # plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + # #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) + #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) + - # Plot power spectra - freq, power = metrics.compute_power_spectrum(prediction, 1) - plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') - plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') From b4c97c51baf4d543f6365d193afeab68b01dd317 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 15 May 2025 16:46:17 +0200 Subject: [PATCH 038/189] clean up a bit --- src/hirad/eval/run_scoring.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index da98a04..fd7c2ab 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -129,39 +129,5 @@ def main(): plotting.plot_error_projection(prediction_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'prediction-error' + target_channels[t_c] + '-' + 'average_over_time')) - - - #for i in range(4): - # dates = ['20160101-0000', '20160115-0000', '20160201-0000', '20160215-0000'] - # pred = torch.load(os.path.join(predictions_directory, dates[i]), weights_only=False) - # base = torch.load(os.path.join(input_directory, 'era-interpolated', dates[i]), weights_only=False) - # pred_2d = pred.reshape(pred.shape[0],352,544) - # base_2d = base.reshape(baseline.shape[0],352,544) - # base_2d = np.transpose(base_2d, (0,-1,-2)) - # preds_tensor[i,:] = pred_2d - # baseline_tensor[i,:] = base_2d - #for t_c in range(len(target_channels)): - # freq, power = metrics.average_power_spectrum(baseline_tensor[:,t_c,:,:].squeeze(), 2) - # b_c = input_channels.index(target_channels[t_c]) - ## if b_c > -1: - # plotting.plot_power_spectrum(freq, power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + date)) - - - # plot errors - #for t_c in range(len(target_channels)): - # b_c = input_channels.index(target_channels[t_c]) - # if b_c > -1: - # baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) - # plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - # #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) - # prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) - # plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - # #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) - #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) - - - - - if __name__ == "__main__": main() \ No newline at end of file From 90c7e28719a83d8534c102240faa756b0883f01e Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 15 May 2025 16:47:28 +0200 Subject: [PATCH 039/189] clean up a bit --- src/hirad/eval/run_scoring.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index fd7c2ab..4f2fcd8 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -76,7 +76,8 @@ def main(): t_freq, t_power = metrics.average_power_spectrum(target_tensor[:,t_c,:,:].squeeze(), 2.0) freqs['target'] = t_freq power['target'] = t_power - #p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + # TODO: Uncomment when we have predictions #freqs['prediction'] = p_freq #power['prediction'] = p_power plotting.plot_power_spectra(freqs, power, target_channels[t_c], os.path.join(output_directory, 'spectra', target_channels[t_c] + '-alldates')) @@ -102,14 +103,9 @@ def main(): if b_c > -1: _, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) baseline_absolute_error[i, t_c, :, :] = baseline_errors - #plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) _, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) prediction_absolute_error[i, t_c, :, :] = prediction_errors - #plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) - #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) - + print(f'baseline_absolute_error.shape={baseline_absolute_error.shape}, prediction_absolute_error.shape={prediction_absolute_error.shape}') # Average errors over ensembles From a9056027b6ea337c6417a2997e6a9419740e33b8 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 15 May 2025 17:49:32 +0200 Subject: [PATCH 040/189] add readme for training --- README.md | 110 ++++++++++++++++++ .../conf/training/era_cosmo_diffusion.yaml | 2 +- src/hirad/train_diffusion.sh | 45 +++++++ src/hirad/{train.sh => train_regression.sh} | 16 +-- 4 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 src/hirad/train_diffusion.sh rename src/hirad/{train.sh => train_regression.sh} (73%) diff --git a/README.md b/README.md index e69de29..3e66062 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,110 @@ +# HiRAD-Gen + +HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model. + +## Installation (Alps) + +To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps: + +1. **Start the PyTorch user environment**: + ```bash + uenv start pytorch/v2.6.0:v1 --view=default + ``` + +2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name): + ```bash + python -m venv ./{env_name} + ``` + +3. **Activate the virtual environment**: + ```bash + source ./{env_name}/bin/activate + ``` + +4. **Install project dependencies**: + ```bash + pip install -e . + ``` + +This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure. + +## Run regression model training (Alps) + +1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. +Inside this script set the following: +```bash +### OUTPUT ### +#SBATCH --output=your_path_to_output_log +#SBATCH --error=your_path_to_output_error +``` +```bash +#SBATCH -A your_compute_group +``` +```bash +srun bash -c " + . ./{your_env_name}/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +" +``` + +2. Setup the following config files in `src/hirad/conf`: + +- In `training_era_cosmo_regression.yaml` set: +``` +hydra: + run: + dir: your_path_to_save_training_output +``` +- In `training/era_cosmo_regression.yaml` set: +``` +hp: + training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4) +``` +- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. + +3. Submit the job with: +```bash +sbatch src/hirad/train_regression.sh +``` + +## Run diffusion model training (Alps) +Before training diffusion model, checkpoint for regression model has to exist. + +1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. +Inside this script set the following: +```bash +### OUTPUT ### +#SBATCH --output=your_path_to_output_log +#SBATCH --error=your_path_to_output_error +``` +```bash +#SBATCH -A your_compute_group +``` +```bash +srun bash -c " + . ./{your_env_name}/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" +``` + +2. Setup the following config files in `src/hirad/conf`: + +- In `training_era_cosmo_diffusion.yaml` set: +``` +hydra: + run: + dir: your_path_to_save_training_output +``` +- In `training/era_cosmo_regression.yaml` set: +``` +hp: + training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4) +io: + regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints +``` +- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. + +3. Submit the job with: +```bash +sbatch src/hirad/train_diffusion.sh +``` \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index b06ec61..f8d19e6 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -29,7 +29,7 @@ perf: io: regression_checkpoint_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression # Where to load the regression checkpoint - print_progress_freq: 32 + print_progress_freq: 128 # How often to print progress save_checkpoint_freq: 5000 # How often to save the checkpoints, measured in number of processed samples diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh new file mode 100644 index 0000000..cf2f88f --- /dev/null +++ b/src/hirad/train_diffusion.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute cores per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" \ No newline at end of file diff --git a/src/hirad/train.sh b/src/hirad/train_regression.sh similarity index 73% rename from src/hirad/train.sh rename to src/hirad/train_regression.sh index a31cec0..c065477 100644 --- a/src/hirad/train.sh +++ b/src/hirad/train_regression.sh @@ -13,8 +13,8 @@ #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression.err ### ENVIRONMENT #### #SBATCH --uenv=pytorch/v2.6.0:/user-environment @@ -24,28 +24,22 @@ # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM +# Get master node. MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" -echo "Master node : $MASTER_ADDR" # Get IP for hostname. MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" -echo "Master address : $MASTER_ADDR" export MASTER_ADDR export MASTER_PORT=29500 -echo "Master port: $MASTER_PORT" # Get number of physical cores using Python PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# Use SLURM_NTASKS (number of processes to be launched by torchrun) LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute threads per process +# Compute cores per process OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) export OMP_NUM_THREADS=$OMP_THREADS -echo "Physical cores: $PHYSICAL_CORES" -echo "Local processes: $LOCAL_PROCS" -echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml srun bash -c " . ./train_env/bin/activate - python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml " \ No newline at end of file From 573dc2387af0851ee16de51c6bbda4e16519d57d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 16 May 2025 13:16:20 +0200 Subject: [PATCH 041/189] update readme for inference --- README.md | 65 +++++++++++++++++-- .../conf/training_era_cosmo_diffusion.yaml | 4 +- .../conf/training_era_cosmo_regression.yaml | 2 +- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 3e66062..b0dbd2e 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,9 @@ To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure. -## Run regression model training (Alps) +## Training + +### Run regression model training (Alps) 1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. Inside this script set the following: @@ -47,7 +49,7 @@ srun bash -c " " ``` -2. Setup the following config files in `src/hirad/conf`: +2. Set up the following config files in `src/hirad/conf`: - In `training_era_cosmo_regression.yaml` set: ``` @@ -67,7 +69,7 @@ hp: sbatch src/hirad/train_regression.sh ``` -## Run diffusion model training (Alps) +### Run diffusion model training (Alps) Before training diffusion model, checkpoint for regression model has to exist. 1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. @@ -87,7 +89,7 @@ srun bash -c " " ``` -2. Setup the following config files in `src/hirad/conf`: +2. Set up the following config files in `src/hirad/conf`: - In `training_era_cosmo_diffusion.yaml` set: ``` @@ -107,4 +109,59 @@ io: 3. Submit the job with: ```bash sbatch src/hirad/train_diffusion.sh +``` + +## Inference + +### Running inference on Alps + +1. Script for running the inference is in `src/hirad/generate.sh`. +Inside this script set the following: +```bash +### OUTPUT ### +#SBATCH --output=your_path_to_output_log +#SBATCH --error=your_path_to_output_error +``` +```bash +#SBATCH -A your_compute_group +``` +```bash +srun bash -c " + . ./{your_env_name}/bin/activate + python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml +" +``` + +2. Set up the following config files in `src/hirad/conf`: + +- In `generate_era_cosmo.yaml` set: +``` +hydra: + run: + dir: your_path_to_save_inference_output +``` +- In `generation/era_cosmo.yaml`: +Choose the inference mode: +``` +inference_mode: all/regression/diffusion +``` +by default `all` does both regression and diffusion. Depending on mode, regression and/or diffusion model pretrained weights should be provided: +``` +io: + res_ckpt_path: path_to_directory_containing_diffusion_training_model_checkpoints + reg_ckpt_path: path_to_directory_containing_regression_training_model_checkpoints +``` +Finally, from the dataset, subset of time steps can be chosen to do inference for. + +One way is to list steps under `times:` in format `%Y%m%d-%H%M` for era5_cosmo dataset. + +The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset (6 for era_cosmo). + +By default, inference is done for one time step `20160101-0000` + +- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. + +3. Submit the job with: +```bash +sbatch src/hirad/generate.sh ``` \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 7ee7dba..2c8d37f 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: diffusion + name: diffusion_test run: - dir: /scratch/mch/pstamenk/output/${hydra:job.name} + dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index d857d12..dc498ce 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -3,7 +3,7 @@ hydra: chdir: true name: regression run: - dir: /scratch/mch/pstamenk/output/${hydra:job.name} + dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: From f4d856b0e562e78fa002b5f9780d9c7d878a14d9 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 16 May 2025 18:04:37 +0200 Subject: [PATCH 042/189] small fix for inference on multiple time steps --- .../conf/training_era_cosmo_diffusion.yaml | 2 +- src/hirad/inference/generate.py | 23 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 2c8d37f..4271e44 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -1,7 +1,7 @@ hydra: job: chdir: true - name: diffusion_test + name: diffusion run: dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 5558a20..7cb9685 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -345,11 +345,12 @@ def generate_fn(image_lr, labels, lead_time_label): writer_executor.submit( save_images, output_path, + times[sampler[time_index]], dataset, image_out.cpu(), image_tar.cpu(), image_lr.cpu(), - image_reg.cpu(), + image_reg.cpu() if image_reg is not None else None, ) ) end.record() @@ -376,15 +377,16 @@ def generate_fn(image_lr, labels, lead_time_label): f.close() logger0.info("Generation Completed.") -def save_images(output_path, dataset, image_pred, image_hr, image_lr, mean_pred): +def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() image_pred = image_pred.numpy() image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) + if image_pred.shape[0]>1: + image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) if mean_pred is not None: @@ -392,13 +394,14 @@ def save_images(output_path, dataset, image_pred, image_hr, image_lr, mean_pred) os.makedirs(output_path, exist_ok=True) for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) - _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{channel.name}-lr.jpg')) - _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{channel.name}-hr.jpg')) - _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) - _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-0.jpg')) - _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-mid.jpg')) + _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) + _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) + _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) + if image_pred.shape[0]>1: + _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) + _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) if mean_pred is not None: - _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{channel.name}-mean-pred.jpg')) + _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): From dcc2a067c1e7f0a7dc1a2da7b58ee5568d4dd6a6 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 21 May 2025 13:00:23 +0200 Subject: [PATCH 043/189] enable validation during training --- src/hirad/datasets/dataset.py | 8 ++++---- src/hirad/inference/generate.py | 7 ++++--- src/hirad/training/train.py | 5 +---- src/hirad/utils/inference_utils.py | 6 +++--- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 6cc6165..380f797 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -36,7 +36,6 @@ def init_train_valid_datasets_from_config( dataloader_cfg: Union[dict, None] = None, batch_size: int = 1, seed: int = 0, - validation_dataset_cfg: Union[dict, None] = None, train_test_split: bool = True, ) -> Tuple[ DownscalingDataset, @@ -59,13 +58,14 @@ def init_train_valid_datasets_from_config( """ config = copy.deepcopy(dataset_cfg) + del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( config, dataloader_cfg, batch_size=batch_size, seed=seed ) if train_test_split: - valid_dataset_cfg = copy.deepcopy(config) - if validation_dataset_cfg: - valid_dataset_cfg.update(validation_dataset_cfg) + valid_dataset_cfg = copy.deepcopy(dataset_cfg) + valid_dataset_cfg["dataset_path"] = valid_dataset_cfg["validation_path"] + del valid_dataset_cfg['validation_path'] (valid_dataset, valid_dataset_iter) = init_dataset_from_config( valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed ) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 7cb9685..ce8ed7b 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -50,7 +50,6 @@ def main(cfg: DictConfig) -> None: # Initialize logger logger = PythonLogger("generate") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) - # logger.file_logging("generate.log") # Handle the batch size seeds = list(np.arange(cfg.generation.num_ensembles)) @@ -252,7 +251,7 @@ def generate_fn(image_lr, labels, lead_time_label): elif cfg.generation.inference_mode == "diffusion": image_out = image_res else: - image_out = image_reg + image_res + image_out = image_reg[0:1,::] + image_res if cfg.generation.sample_res != "full": image_out = rearrange( @@ -385,8 +384,9 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, image_pred = image_pred.numpy() image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) if image_pred.shape[0]>1: + image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) if mean_pred is not None: @@ -398,6 +398,7 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) if image_pred.shape[0]>1: + _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) if mean_pred is not None: diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 37e6110..664a6a5 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -36,12 +36,10 @@ def main(cfg: DictConfig) -> None: OmegaConf.resolve(cfg) dataset_cfg = OmegaConf.to_container(cfg.dataset) - if hasattr(cfg, "validation"): + if hasattr(cfg.dataset, "validation_path"): train_test_split = True - validation_dataset_cfg = OmegaConf.to_container(cfg.validation) else: train_test_split = False - validation_dataset_cfg = None fp_optimizations = cfg.training.perf.fp_optimizations songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level fp16 = fp_optimizations == "fp16" @@ -77,7 +75,6 @@ def main(cfg: DictConfig) -> None: data_loader_kwargs, batch_size=cfg.training.hp.batch_size_per_gpu, seed=0, - validation_dataset_cfg=validation_dataset_cfg, train_test_split=train_test_split, ) logger0.info(f"Training on dataset with size {len(dataset)}") diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 4831bdd..ace05ba 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -52,14 +52,14 @@ def regression_step( """ # Create a tensor of zeros with the given shape and move it to the appropriate device x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) - t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device).reshape((1,1,1,1)) + t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device)#.reshape((1,1,1,1)) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat, img_lr, t_hat, labels, lead_time_label=lead_time_label) + x = net(x_hat[0:1], img_lr, t_hat, labels, lead_time_label=lead_time_label) else: - x = net(x_hat, img_lr, t_hat, labels) + x = net(x_hat[0:1], img_lr, t_hat, labels) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: From cadccd502fb913e18db7d7ea4eba2787f8fca5c5 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 22 May 2025 15:57:25 +0200 Subject: [PATCH 044/189] change generate eval to new functions --- src/hirad/inference/generate.py | 107 +++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 37 deletions(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index ce8ed7b..fbfd8cf 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -35,6 +35,7 @@ from hirad.utils.train_helpers import set_patch_shape +from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig) -> None: @@ -346,10 +347,10 @@ def generate_fn(image_lr, labels, lead_time_label): output_path, times[sampler[time_index]], dataset, - image_out.cpu(), - image_tar.cpu(), - image_lr.cpu(), - image_reg.cpu() if image_reg is not None else None, + image_out.cpu().numpy(), + image_tar.cpu().numpy(), + image_lr.cpu().numpy(), + image_reg.cpu().numpy() if image_reg is not None else None, ) ) end.record() @@ -381,41 +382,73 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() - image_pred = image_pred.numpy() - image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) - if image_pred.shape[0]>1: - image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) - image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) - image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) - image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) - if mean_pred is not None: - mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) - os.makedirs(output_path, exist_ok=True) + + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + + freqs = {} + power = {} for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) - _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) - _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) - _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) - if image_pred.shape[0]>1: - _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) - _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) - _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) - if mean_pred is not None: - _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) - -def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): - - """Plot observed or interpolated data in a scatter plot.""" - # TODO: Refactor this somehow, it's not really generalizing well across variables. - fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - ax.coastlines() - ax.gridlines(draw_labels=True) - plt.colorbar(p, label="K", orientation="horizontal") - plt.savefig(filename) - plt.close('all') + _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) + _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) + + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) + + b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) + freqs['prediction'] = p_freq + power['prediction'] = p_power + plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) + +# def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): +# longitudes = dataset.longitude() +# latitudes = dataset.latitude() +# input_channels = dataset.input_channels() +# output_channels = dataset.output_channels() +# image_pred = image_pred.numpy() +# image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) +# if image_pred.shape[0]>1: +# image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) +# image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) +# image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) +# image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) +# image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) +# if mean_pred is not None: +# mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) +# os.makedirs(output_path, exist_ok=True) +# for idx, channel in enumerate(output_channels): +# input_channel_idx = input_channels.index(channel) +# _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) +# _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) +# if image_pred.shape[0]>1: +# _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) +# if mean_pred is not None: +# _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) + +# def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + +# """Plot observed or interpolated data in a scatter plot.""" +# # TODO: Refactor this somehow, it's not really generalizing well across variables. +# fig = plt.figure() +# fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +# p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) +# ax.coastlines() +# ax.gridlines(draw_labels=True) +# plt.colorbar(p, label="K", orientation="horizontal") +# plt.savefig(filename) +# plt.close('all') if __name__ == "__main__": main() From 92b08b7fad3c2fdf09b07ea2a82a69ba9175f62e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 23 May 2025 14:28:30 +0200 Subject: [PATCH 045/189] fix average training loss tracking --- src/hirad/training/train.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 664a6a5..559d800 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -366,18 +366,6 @@ def main(cfg: DictConfig) -> None: "training_loss_running_mean", average_loss_running_mean, cur_nimg ) - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 # Update weights. lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate @@ -481,6 +469,19 @@ def main(cfg: DictConfig) -> None: logger0.info(" ".join(fields)) torch.cuda.reset_peak_memory_stats() + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + # Save checkpoints if dist.world_size > 1: torch.distributed.barrier() From 97970fcc634f7ca1934cda32ee58b8c27d79dc02 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 23 May 2025 14:33:58 +0200 Subject: [PATCH 046/189] fix validation bug --- src/hirad/datasets/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 380f797..7ba8833 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -58,7 +58,8 @@ def init_train_valid_datasets_from_config( """ config = copy.deepcopy(dataset_cfg) - del config['validation_path'] + if 'validation_path': + del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( config, dataloader_cfg, batch_size=batch_size, seed=seed ) @@ -83,6 +84,8 @@ def init_dataset_from_config( ) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) dataset_type = dataset_cfg.pop("type", "era5_cosmo") + if "validation_path" in dataset_cfg: + del dataset_cfg['validation_path'] if "train_test_split" in dataset_cfg: # handled by init_train_valid_datasets_from_config del dataset_cfg["train_test_split"] From 937e7c97e129ee6c0d696ef318300372b17f29a3 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 26 May 2025 18:20:09 +0200 Subject: [PATCH 047/189] update to latest corrdiff version --- src/hirad/conf/generation/era_cosmo.yaml | 19 +- src/hirad/conf/model/era_cosmo_diffusion.yaml | 12 +- .../conf/model/era_cosmo_regression.yaml | 10 +- src/hirad/conf/model_size/mini.yaml | 26 + src/hirad/conf/model_size/normal.yaml | 26 + src/hirad/conf/sampler/stochastic.yaml | 4 +- .../conf/training_era_cosmo_regression.yaml | 2 + src/hirad/datasets/era5_cosmo.py | 3 +- src/hirad/inference/generate.py | 132 ++- src/hirad/losses/__init__.py | 2 +- src/hirad/losses/loss.py | 630 ++++++----- src/hirad/models/__init__.py | 12 +- src/hirad/models/layers.py | 354 +++++-- src/hirad/models/preconditioning.py | 352 +++++-- src/hirad/models/song_unet.py | 996 ++++++++++++------ src/hirad/models/unet.py | 177 +++- src/hirad/training/train.py | 741 ++++++++----- src/hirad/utils/deterministic_sampler.py | 162 ++- src/hirad/utils/function_utils.py | 37 +- src/hirad/utils/inference_utils.py | 145 ++- src/hirad/utils/patching.py | 767 ++++++++++++++ src/hirad/utils/stochastic_sampler.py | 524 +++------ src/hirad/utils/train_helpers.py | 14 +- 23 files changed, 3584 insertions(+), 1563 deletions(-) create mode 100644 src/hirad/conf/model_size/mini.yaml create mode 100644 src/hirad/conf/model_size/normal.yaml create mode 100644 src/hirad/utils/patching.py diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index a0c5a40..be4219d 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -1,22 +1,26 @@ -num_ensembles: 64 +num_ensembles: 8 # Number of ensembles to generate per input -seed_batch_size: 1 +seed_batch_size: 4 # Size of the batched inference inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y -overlap_pixels: 0 +# overlap_pixels: 0 # Number of overlapping pixels between adjacent patches -boundary_pixels: 0 +# boundary_pixels: 0 # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary # artifact. +patching: False hr_mean_conditioning: True -sample_res: full +# sample_res: full # Sampling resolution times_range: null times: - 20160101-0000 + # - 20160101-0600 + # - 20160101-1200 +has_laed_time: False perf: force_fp16: False @@ -31,9 +35,10 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_test/checkpoints_diffusion + res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_refactoring/checkpoints_diffusion + # res_ckpt_path: null # Checkpoint filename for the diffusion model - reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_refactoring/checkpoints_regression # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml index 06aa2a4..441239e 100644 --- a/src/hirad/conf/model/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -2,4 +2,14 @@ name: diffusion # Name of the preconditioner hr_mean_conditioning: True # High-res mean (regression's output) as additional condition -scale_cond_input: False \ No newline at end of file + +# Standard model parameters. +model_args: + gridtype: "sinusoidal" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 4 + # Number of channels for positional grid embeddings + embedding_type: "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_regression.yaml b/src/hirad/conf/model/era_cosmo_regression.yaml index 487eb4b..29b43e8 100644 --- a/src/hirad/conf/model/era_cosmo_regression.yaml +++ b/src/hirad/conf/model/era_cosmo_regression.yaml @@ -1,2 +1,10 @@ name: regression -hr_mean_conditioning: False \ No newline at end of file +hr_mean_conditioning: False + +# Default regression model parameters. Do not modify. +model_args: + "N_grid_channels": 4 + # Number of channels for positional grid embeddings + "embedding_type": "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model_size/mini.yaml b/src/hirad/conf/model_size/mini.yaml new file mode 100644 index 0000000..2eb8f8a --- /dev/null +++ b/src/hirad/conf/model_size/mini.yaml @@ -0,0 +1,26 @@ +# @package _global_.model + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 64 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2] + # Resolutions at which self-attention layers are applied. + attn_resolutions: [16] \ No newline at end of file diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml new file mode 100644 index 0000000..b81fe15 --- /dev/null +++ b/src/hirad/conf/model_size/normal.yaml @@ -0,0 +1,26 @@ +# @package _global_.model + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 128 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2, 2, 2] + # Resolutions at which self-attention layers are applied. + attn_resolutions: [28] \ No newline at end of file diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml index 5e8fa88..2481cd3 100644 --- a/src/hirad/conf/sampler/stochastic.yaml +++ b/src/hirad/conf/sampler/stochastic.yaml @@ -1,3 +1,3 @@ type: stochastic -boundary_pix: 2 -overlap_pix: 4 \ No newline at end of file +# boundary_pix: 2 +# overlap_pix: 4 \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index dc498ce..1de83d9 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -15,5 +15,7 @@ defaults: # Model - model/era_cosmo_regression + - model_size/normal + # Training - training/era_cosmo_regression \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 674dbf0..f97dbc6 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -63,8 +63,7 @@ def __getitem__(self, idx): cosmo_data = self.normalize_output(cosmo_data) # return samples return torch.tensor(cosmo_data),\ - torch.tensor(era5_data),\ - 0 + torch.tensor(era5_data), # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ # F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ # 0 diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index fbfd8cf..8fed809 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -6,6 +6,8 @@ import torch._dynamo import nvtx import numpy as np +import contextlib + from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor @@ -18,7 +20,8 @@ from hydra.utils import to_absolute_path -from hirad.models import EDMPrecondSR, UNet +from hirad.models import EDMPrecondSuperResolution, UNet +from hirad.utils.patching import GridPatching2D from hirad.utils.stochastic_sampler import stochastic_sampler from hirad.utils.deterministic_sampler import deterministic_sampler from hirad.utils.inference_utils import ( @@ -85,19 +88,23 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) # Parse the patch shape - if hasattr(cfg.generation, "patch_shape_x"): # TODO better config handling + if cfg.generation.patching: patch_shape_x = cfg.generation.patch_shape_x - else: - patch_shape_x = None - if hasattr(cfg.generation, "patch_shape_y"): patch_shape_y = cfg.generation.patch_shape_y else: - patch_shape_y = None + patch_shape_x, patch_shape_y = None, None patch_shape = (patch_shape_y, patch_shape_x) - img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if patch_shape != img_shape: + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + patching = GridPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) logger0.info("Patch-based training enabled") else: + patching = None logger0.info("Patch-based training disabled") # Parse the inference mode @@ -121,7 +128,7 @@ def main(cfg: DictConfig) -> None: with open(diffusion_model_args_path, 'r') as f: diffusion_model_args = json.load(f) - net_res = EDMPrecondSR(**diffusion_model_args) + net_res = EDMPrecondSuperResolution(**diffusion_model_args) _ = load_checkpoint( path=res_ckpt_path, @@ -130,9 +137,13 @@ def main(cfg: DictConfig) -> None: ) #TODO fix to use channels_last which is optimal for H100 - net_res = net_res.eval().to(device)#.to(memory_format=torch.channels_last) + net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True + + # Disable AMP for inference (even if model is trained with AMP) + if hasattr(net_res, "amp_mode"): + net_res.amp_mode = False else: net_res = None @@ -156,9 +167,13 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_reg = net_reg.eval().to(device)#.to(memory_format=torch.channels_last) + net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True + + # Disable AMP for inference (even if model is trained with AMP) + if hasattr(net_reg, "amp_mode"): + net_reg.amp_mode = False else: net_reg = None @@ -183,47 +198,28 @@ def main(cfg: DictConfig) -> None: solver=cfg.sampler.solver, ) elif cfg.sampler.type == "stochastic": - sampler_fn = partial( - stochastic_sampler, - img_shape=img_shape, - patch_shape_x=patch_shape[0], - patch_shape_y=patch_shape[1], - boundary_pix=cfg.sampler.boundary_pix, - overlap_pix=cfg.sampler.overlap_pix, - ) + sampler_fn = partial(stochastic_sampler, patching=patching) else: raise ValueError(f"Unknown sampling method {cfg.sampling.type}") # Main generation definition - def generate_fn(image_lr, labels, lead_time_label): - img_shape_y, img_shape_x = img_shape + def generate_fn(image_lr, lead_time_label): with nvtx.annotate("generate_fn", color="green"): - if cfg.generation.sample_res == "full": - image_lr_patch = image_lr - else: - torch.cuda.nvtx.range_push("rearrange") - image_lr_patch = rearrange( - image_lr, - "b c (h1 h) (w1 w) -> (b h1 w1) c h w", - h1=img_shape_y // patch_shape[0], - w1=img_shape_x // patch_shape[1], - ) - torch.cuda.nvtx.range_pop() - image_lr_patch = image_lr_patch #.to(memory_format=torch.channels_last) + # (1, C, H, W) + image_lr = image_lr.to(memory_format=torch.channels_last) if net_reg: with nvtx.annotate("regression_model", color="yellow"): image_reg = regression_step( net=net_reg, - img_lr=image_lr_patch, - labels=labels, + img_lr=image_lr, latents_shape=( cfg.generation.seed_batch_size, img_out_channels, img_shape[0], img_shape[1], - ), + ), # (batch_size, C, H, W) lead_time_label=lead_time_label, ) if net_res: @@ -235,16 +231,15 @@ def generate_fn(image_lr, labels, lead_time_label): image_res = diffusion_step( net=net_res, sampler_fn=sampler_fn, - seed_batch_size=cfg.generation.seed_batch_size, img_shape=img_shape, img_out_channels=img_out_channels, rank_batches=rank_batches, - img_lr=image_lr_patch.expand( + img_lr=image_lr.expand( cfg.generation.seed_batch_size, -1, -1, -1 ), #.to(memory_format=torch.channels_last), rank=dist.rank, device=device, - hr_mean=mean_hr, + mean_hr=mean_hr, lead_time_label=lead_time_label, ) if cfg.generation.inference_mode == "regression": @@ -254,13 +249,6 @@ def generate_fn(image_lr, labels, lead_time_label): else: image_out = image_reg[0:1,::] + image_res - if cfg.generation.sample_res != "full": - image_out = rearrange( - image_out, - "(b h1 w1) c h w -> b c (h1 h) (w1 w)", - h1=img_shape_y // patch_shape[0], - w1=img_shape_x // patch_shape[1], - ) # Gather tensors on rank 0 if dist.world_size > 1: if dist.rank == 0: @@ -300,8 +288,18 @@ def generate_fn(image_lr, labels, lead_time_label): # through the dataset using a data loader, computes predictions, and saves them along # with associated metadata. - with torch.cuda.profiler.profile(): - with torch.autograd.profiler.emit_nvtx(): + torch_cuda_profiler = ( + torch.cuda.profiler.profile() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + torch_nvtx_profiler = ( + torch.autograd.profiler.emit_nvtx() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + with torch_cuda_profiler: + with torch_nvtx_profiler: data_loader = torch.utils.data.DataLoader( dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True @@ -313,11 +311,29 @@ def generate_fn(image_lr, labels, lead_time_label): ) writer_threads = [] - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) + # Create timer objects only if CUDA is available + use_cuda_timing = torch.cuda.is_available() + if use_cuda_timing: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + else: + # Dummy no-op functions for CPU case + class DummyEvent: + def record(self): + pass + + def synchronize(self): + pass + + def elapsed_time(self, _): + return 0 + + start = end = DummyEvent() times = dataset.time() - for image_tar, image_lr, labels, *lead_time_label in iter(data_loader): + for index, (image_tar, image_lr, *lead_time_label) in enumerate( + iter(data_loader) + ): time_index += 1 if dist.rank == 0: logger0.info(f"starting index: {time_index}") @@ -333,11 +349,10 @@ def generate_fn(image_lr, labels, lead_time_label): image_lr = ( image_lr.to(device=device) .to(torch.float32) - #.to(memory_format=torch.channels_last) + .to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - labels = labels.to(device).to(torch.float32).contiguous() - image_out, image_reg = generate_fn(image_lr,labels,lead_time_label) + image_out, image_reg = generate_fn(image_lr,lead_time_label) if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing @@ -355,9 +370,11 @@ def generate_fn(image_lr, labels, lead_time_label): ) end.record() end.synchronize() - elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s + elapsed_time = ( + start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0 + ) # Convert ms to s timed_steps = time_index + 1 - warmup_steps - if dist.rank == 0: + if dist.rank == 0 and use_cuda_timing: average_time_per_batch_element = elapsed_time / timed_steps / batch_size logger.info( f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" @@ -378,6 +395,9 @@ def generate_fn(image_lr, labels, lead_time_label): logger0.info("Generation Completed.") def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + + os.makedirs(output_path, exist_ok=True) + longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py index 185527b..868ffdf 100644 --- a/src/hirad/losses/__init__.py +++ b/src/hirad/losses/__init__.py @@ -1 +1 @@ -from .loss import ResLoss, RegressionLoss, RegressionLossCE \ No newline at end of file +from .loss import ResidualLoss, RegressionLoss, RegressionLossCE \ No newline at end of file diff --git a/src/hirad/losses/loss.py b/src/hirad/losses/loss.py index 18dde13..fb65960 100644 --- a/src/hirad/losses/loss.py +++ b/src/hirad/losses/loss.py @@ -18,12 +18,12 @@ """Loss functions used in the paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" -import random -from typing import Callable, Optional, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch +from hirad.utils.patching import RandomPatching2D class VPLoss: """ @@ -333,7 +333,7 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - # augment for conditional generaiton + # augment for conditional generation img_tot = torch.cat((img_clean, img_lr), dim=1) y_tot, augment_labels = ( augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) @@ -349,16 +349,13 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): class RegressionLoss: """ - Regression loss function for the U-Net for deterministic predictions. + Regression loss function for the deterministic predictions. + Note: this loss does not apply any reduction. - Parameters + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. + sigma_data: float + Standard deviation for data. Deprecated and ignored. Note ---- @@ -368,43 +365,68 @@ class RegressionLoss: arXiv preprint arXiv:2309.15214. """ - def __init__( - self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 - ): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data + def __init__(self): + """ + Arguments + ---------- + """ + return - def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + def __call__( + self, + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: """ - Calculate and return the loss for the U-Net for deterministic predictions. + Calculate and return the regression loss for + deterministic predictions. - Parameters: + Parameters ---------- - net: torch.nn.Module + net : torch.nn.Module The neural network model that will make predictions. + Expected signature: `net(x, img_lr, + augment_labels=augment_labels, force_fp32=False)`, where: + x (torch.Tensor): Tensor of shape (B, C_hr, H, W). Is zero-filled. + img_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + force_fp32 (bool, optional): Whether to force the model to use + fp32, by default False. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + augment_pipe : callable, optional + An optional data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. - - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + A tensor representing the per-sample element-wise squared + difference between the network's predictions and the high + resolution images `img_clean` (possibly data-augmented by + `augment_pipe`). + Shape: (B, C_hr, H, W), same as `img_clean`. """ - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = ( 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 ) @@ -416,100 +438,214 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): y = y_tot[:, : img_clean.shape[1], :, :] y_lr = y_tot[:, img_clean.shape[1] :, :, :] - input = torch.zeros_like(y, device=img_clean.device) - D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels) + zero_input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels) loss = weight * ((D_yn - y) ** 2) return loss -class ResLoss: +class ResidualLoss: """ Mixture loss function for denoising score matching. - Parameters + This class implements a loss function that combines deterministic + regression with denoising score matching. It uses a pre-trained regression + network to compute residuals before applying the diffusion process. + + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. + regression_net : torch.nn.Module + The regression network used for computing residuals. + P_mean : float + Mean value for noise level computation. + P_std : float + Standard deviation for noise level computation. + sigma_data : float + Standard deviation for data weighting. + hr_mean_conditioning : bool + Flag indicating whether to use high-resolution mean for conditioning. Note ---- Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., - Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. - Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. - arXiv preprint arXiv:2309.15214. + Liu, C.C., Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric + Downscaling. arXiv preprint arXiv:2309.15214. """ def __init__( self, - regression_net, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - patch_num, + regression_net: torch.nn.Module, P_mean: float = 0.0, P_std: float = 1.2, sigma_data: float = 0.5, hr_mean_conditioning: bool = False, ): - self.unet = regression_net + """ + Arguments + ---------- + regression_net : torch.nn.Module + Pre-trained regression network used to compute residuals. + Expected signature: `net(zero_input, y_lr, + lead_time_label=lead_time_label, augment_labels=augment_labels)` or + `net(zero_input, y_lr, augment_labels=augment_labels)`, where: + zero_input (torch.Tensor): Zero tensor of shape (B, C_hr, H, W) + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time labels + augment_labels (torch.Tensor, optional): Optional augmentation labels + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + P_mean : float, optional + Mean value for noise level computation, by default 0.0. + + P_std : float, optional + Standard deviation for noise level computation, by default 1.2. + + sigma_data : float, optional + Standard deviation for data weighting, by default 0.5. + + hr_mean_conditioning : bool, optional + Whether to use high-resolution mean for conditioning predicted, by default False. + When True, the mean prediction from `regression_net` is channel-wise + concatenated with `img_lr` for conditioning. + """ + self.regression_net = regression_net self.P_mean = P_mean self.P_std = P_std self.sigma_data = sigma_data - self.img_shape_x = img_shape_x - self.img_shape_y = img_shape_y - self.patch_shape_x = patch_shape_x - self.patch_shape_y = patch_shape_y - self.patch_num = patch_num self.hr_mean_conditioning = hr_mean_conditioning + self.y_mean = None def __call__( self, - net, - img_clean, - img_lr, - labels=None, - lead_time_label=None, - augment_pipe=None, - ): + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + patching: Optional[RandomPatching2D] = None, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + use_patch_grad_acc: bool = False, + ) -> torch.Tensor: """ Calculate and return the loss for denoising score matching. - Parameters: - ---------- - net: torch.nn.Module - The neural network model that will make predictions. - - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - labels: torch.Tensor - Ground truth labels for the input images. + This method computes a mixture loss that combines deterministic + regression with denoising score matching. It first computes residuals + using the regression network, then applies the diffusion process to + these residuals. + + In addition to the standard denoising score matching loss, this method + also supports optional patching for multi-diffusion. In this case, the spatial + dimensions of the input are decomposed into `P` smaller patches of shape + (H_patch, W_patch), that are grouped along the batch dimension, and the + model is applied to each patch individually. In the following, if `patching` + is not provided, then the input is not patched and `P=1` and `(H_patch, + W_patch) = (H, W)`. When patching is used, the original non-patched conditioning is + interpolated onto a spatial grid of shape `(H_patch, W_patch)` and channel-wise + concatenated to the patched conditioning. This ensures that each patch + maintains global information from the entire domain. + + The diffusion model `net` is expected to be conditioned on an input with + `C_cond` channels, which should be: + - `C_cond = C_lr` if `hr_mean_conditioning` is `False` and + `patching` is None. + - `C_cond = C_hr + C_lr` if `hr_mean_conditioning` is `True` and + `patching` is None. + - `C_cond = C_hr + 2*C_lr` if `hr_mean_conditioning` is `True` and + `patching` is not None. + - `C_cond = 2*C_lr` if `hr_mean_conditioning` is `False` and + `patching` is not None. + Additionally, `C_cond` should also include any embedding channels, + such as positional embeddings or time embeddings. + + Note: this loss function does not apply any reduction. - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. + Parameters + ---------- + net : torch.nn.Module + The neural network model for the diffusion process. + Expected signature: `net(latent, y_lr, sigma, + embedding_selector=embedding_selector, lead_time_label=lead_time_label, + augment_labels=augment_labels)`, where: + latent (torch.Tensor): Noisy input of shape (B[*P], C_hr, H_patch, W_patch) + y_lr (torch.Tensor): Conditioning of shape (B[*P], C_cond, H_patch, W_patch) + sigma (torch.Tensor): Noise level of shape (B[*P], 1, 1, 1) + embedding_selector (callable, optional): Function to select + positional embeddings. Only used if `patching` is provided. + lead_time_label (torch.Tensor, optional): Lead time labels. + augment_labels (torch.Tensor, optional): Augmentation labels + Returns: + torch.Tensor: Predictions of shape (B[*P], C_hr, H_patch, W_patch) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the regression network and conditioning for the + diffusion process. + + patching : Optional[RandomPatching2D], optional + Patching strategy for processing large images, by default None. See + :class:`physicsnemo.utils.patching.RandomPatching2D` for details. + When provided, the patching strategy is used for both image patches + and positional embeddings selection in the diffusion model `net`. + Transforms tensors from shape (B, C, H, W) to (B*P, C, H_patch, + W_patch). + + lead_time_label : Optional[torch.Tensor], optional + Labels for lead-time aware predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution images + of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels + use_patch_grad_acc: bool, optional + A boolean flag indicating whether to enable multi-iterations of patching accumulations + for amortizing regression cost. Default False. - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + If patching is not used: + A tensor of shape (B, C_hr, H, W) representing the per-sample loss. + If patching is used: + A tensor of shape (B*P, C_hr, H_patch, W_patch) representing + the per-patch loss. + + Raises + ------ + ValueError + If patching is provided but is not an instance of RandomPatching2D. + If shapes of img_clean and img_lr are incompatible. """ - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + # Safety check: enforce patching object + if patching and not isinstance(patching, RandomPatching2D): + raise ValueError("patching must be a 'RandomPatching2D' object.") + # Safety check: enforce shapes + if ( + img_clean.shape[0] != img_lr.shape[0] + or img_clean.shape[2:] != img_lr.shape[2:] + ): + raise ValueError( + f"Shape mismatch between img_clean {img_clean.shape} and " + f"img_lr {img_lr.shape}. " + f"Batch size, height and width must match." + ) - # augment for conditional generaiton + # augment for conditional generation img_tot = torch.cat((img_clean, img_lr), dim=1) y_tot, augment_labels = ( augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) @@ -517,114 +653,71 @@ def __call__( y = y_tot[:, : img_clean.shape[1], :, :] y_lr = y_tot[:, img_clean.shape[1] :, :, :] y_lr_res = y_lr - - # global index - b = y.shape[0] - Nx = torch.arange(self.img_shape_x).int() - Ny = torch.arange(self.img_shape_y).int() - grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) - - # form residual - if lead_time_label is not None: - y_mean = self.unet( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - sigma, - labels, - lead_time_label=lead_time_label, - augment_labels=augment_labels, - ) + batch_size = y.shape[0] + + # if using multi-iterations of patching, switch to optimized version + if use_patch_grad_acc: + # form residual + if self.y_mean is None: + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) + self.y_mean = y_mean + + # if on full domain, or if using patching without multi-iterations else: - y_mean = self.unet( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - sigma, - labels, - augment_labels=augment_labels, - ) + # form residual + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) - y = y - y_mean + self.y_mean = y_mean + + y = y - self.y_mean if self.hr_mean_conditioning: - y_lr = torch.cat((y_mean, y_lr), dim=1).contiguous() - global_index = None + y_lr = torch.cat((self.y_mean, y_lr), dim=1) + # patchified training # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 - if ( - self.img_shape_x != self.patch_shape_x - or self.img_shape_y != self.patch_shape_y - ): - c_in = y_lr.shape[1] - c_out = y.shape[1] - rnd_normal = torch.randn( - [img_clean.shape[0] * self.patch_num, 1, 1, 1], device=img_clean.device - ) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / ( - sigma * self.sigma_data - ) ** 2 - - # global interpolation - input_interp = torch.nn.functional.interpolate( - img_lr, - (self.patch_shape_y, self.patch_shape_x), - mode="bilinear", - ) + # removed patch_embedding_selector due to compilation issue with dynamo. + if patching: + # Patched residual + # (batch_size * patch_num, c_out, patch_shape_y, patch_shape_x) + y_patched = patching.apply(input=y) + # Patched conditioning on y_lr and interp(img_lr) + # (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x) + y_lr_patched = patching.apply(input=y_lr, additional_input=img_lr) + + y = y_patched + y_lr = y_lr_patched + + # Noise + rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - # patch generation from a single sample (not from random samples due to memory consumption of regression) - y_new = torch.zeros( - b * self.patch_num, - c_out, - self.patch_shape_y, - self.patch_shape_x, - device=img_clean.device, - ) - y_lr_new = torch.zeros( - b * self.patch_num, - c_in + input_interp.shape[1], - self.patch_shape_y, - self.patch_shape_x, - device=img_clean.device, - ) - global_index = torch.zeros( - b * self.patch_num, - 2, - self.patch_shape_y, - self.patch_shape_x, - dtype=torch.int, - device=img_clean.device, - ) - for i in range(self.patch_num): - rnd_x = random.randint(0, self.img_shape_x - self.patch_shape_x) - rnd_y = random.randint(0, self.img_shape_y - self.patch_shape_y) - y_new[b * i : b * (i + 1),] = y[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ] - global_index[b * i : b * (i + 1),] = grid[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ] - y_lr_new[b * i : b * (i + 1),] = torch.cat( - ( - y_lr[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ], - input_interp, - ), - 1, - ) - y = y_new - y_lr = y_lr_new + # Input + noise latent = y + torch.randn_like(y) * sigma if lead_time_label is not None: @@ -632,8 +725,10 @@ def __call__( latent, y_lr, sigma, - labels, - global_index=global_index, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -642,8 +737,10 @@ def __call__( latent, y_lr, sigma, - labels, - global_index=global_index, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, augment_labels=augment_labels, ) loss = weight * ((D_yn - y) ** 2) @@ -651,6 +748,7 @@ def __call__( return loss + class VELoss_dfsr: """ Loss function for dfsr model, modified from class VELoss. @@ -792,20 +890,19 @@ def __call__(self, net, images, labels, augment_pipe=None): class RegressionLossCE: """ - A regression loss function for the GEFS-HRRR model with probability channels, adapted - from RegressionLoss. In this version, probability channels are evaluated using - CrossEntropyLoss instead of MSELoss. - - Parameters + A regression loss function for deterministic predictions with probability + channels and lead time labels. Adapted from + :class:`physicsnemo.metrics.diffusion.loss.RegressionLoss`. In this version, + probability channels are evaluated using CrossEntropyLoss instead of + squared error. + Note: this loss does not apply any reduction. + + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. - prob_channels: list, optional - A index list of output probability channels. + entropy : torch.nn.CrossEntropyLoss + Cross entropy loss function used for probability channels. + prob_channels : list[int] + List of channel indices to be treated as probability channels. Note ---- @@ -817,62 +914,86 @@ class RegressionLossCE: def __init__( self, - P_mean: float = -1.2, - P_std: float = 1.2, - sigma_data: float = 0.5, - prob_channels: list = [4, 5, 6, 7, 8], + prob_channels: list[int] = [4, 5, 6, 7, 8], ): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data + """ + Arguments + ---------- + prob_channels: list[int], optional + List of channel indices from the target tensor to be treated as + probability channels. Cross entropy loss is computed over these + channels, while the remaining channels are treated as scalar + channels and the squared error loss is computed over them. By + default, [4, 5, 6, 7, 8]. + """ self.entropy = torch.nn.CrossEntropyLoss(reduction="none") self.prob_channels = prob_channels def __call__( self, - net, - img_clean, - img_lr, - lead_time_label=None, - labels=None, - augment_pipe=None, - ): + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: """ - Calculate and return the loss for the U-Net for deterministic predictions. + Calculate and return the loss for deterministic + predictions, treating specific channels as probability distributions. - Parameters: + Parameters ---------- - net: torch.nn.Module + net : torch.nn.Module The neural network model that will make predictions. + Expected signature: `net(input, img_lr, lead_time_label=lead_time_label, augment_labels=augment_labels)`, + where: + input (torch.Tensor): Tensor of shape (B, C_hr, H, W). Zero-filled. + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time + labels. If provided, should be of shape (B,). + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if `augment_pipe` is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + lead_time_label : Optional[torch.Tensor], optional + Lead time labels for temporal predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W). + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - lead_time_label: torch.Tensor - Lead time labels for input batches. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. - - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + A tensor of shape (B, C_loss, H, W) representing the pixel-wise + loss., where `C_loss = C_hr - len(prob_channels) + 1`. More + specifically, the last channel of the output tensor corresponds to + the cross-entropy loss computed over the channels specified in + `prob_channels`, while the first `C_hr - len(prob_channels)` + channels of the output tensor correspond to the squared error loss. """ all_channels = list(range(img_clean.shape[1])) # [0, 1, 2, ..., 10] scalar_channels = [ item for item in all_channels if item not in self.prob_channels ] - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = ( 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 ) @@ -890,8 +1011,6 @@ def __call__( D_yn = net( input, y_lr, - sigma, - labels, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -899,11 +1018,10 @@ def __call__( D_yn = net( input, y_lr, - sigma, - labels, + lead_time_label=lead_time_label, augment_labels=augment_labels, ) - loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2) + loss1 = weight * (D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2 loss2 = ( weight * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[ @@ -911,4 +1029,4 @@ def __call__( ] ) loss = torch.cat((loss1, loss2), dim=1) - return loss + return loss \ No newline at end of file diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index 3ab4a6f..b00a477 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,6 +1,14 @@ -from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding +from .layers import ( + Linear, + Conv2d, + GroupNorm, + AttentionOp, + UNetBlock, + PositionalEmbedding, + FourierEmbedding +) from .meta import ModelMetaData from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd from .dhariwal_unet import DhariwalUNet from .unet import UNet -from .preconditioning import EDMPrecondSR, EDMPrecond +from .preconditioning import EDMPrecondSuperResolution, EDMPrecondSR, EDMPrecond diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index 8612da7..d7e63d7 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -19,15 +19,27 @@ Diffusion-Based Generative Models". """ +import contextlib +import importlib from typing import Any, Dict, List import numpy as np +import nvtx import torch +import torch.cuda.amp as amp from einops import rearrange -from torch.nn.functional import silu +from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh from hirad.utils.model_utils import weight_init +_is_apex_available = False +if torch.cuda.is_available(): + try: + apex_gn_module = importlib.import_module("apex.contrib.group_norm") + ApexGroupNorm = getattr(apex_gn_module, "GroupNorm") + _is_apex_available = True + except ImportError: + pass class Linear(torch.nn.Module): """ @@ -56,6 +68,8 @@ class Linear(torch.nn.Module): A scaling factor to multiply with the initialized weights. By default 1. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -66,10 +80,12 @@ def __init__( init_mode: str = "kaiming_normal", init_weight: int = 1, init_bias: int = 0, + amp_mode: bool = False, ): super().__init__() self.in_features = in_features self.out_features = out_features + self.amp_mode = amp_mode init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) self.weight = torch.nn.Parameter( weight_init([out_features, in_features], **init_kwargs) * init_weight @@ -81,9 +97,16 @@ def __init__( ) def forward(self, x): - x = x @ self.weight.to(x.dtype).t() + weight, bias = self.weight, self.bias + # pdb.set_trace() + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + x = x @ weight.t() if self.bias is not None: - x = x.add_(self.bias.to(x.dtype)) + x = x.add_(bias) return x @@ -128,6 +151,10 @@ class Conv2d(torch.nn.Module): A scaling factor to multiply with the initialized weights. By default 1.0. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0.0. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -143,9 +170,16 @@ def __init__( init_mode: str = "kaiming_normal", init_weight: float = 1.0, init_bias: float = 0.0, + fused_conv_bias: bool = False, + amp_mode: bool = False, ): if up and down: raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + if not kernel and fused_conv_bias: + print( + "Warning: Kernel is required when fused_conv_bias is enabled. Setting fused_conv_bias to False." + ) + fused_conv_bias = False super().__init__() self.in_channels = in_channels @@ -153,6 +187,8 @@ def __init__( self.up = up self.down = down self.fused_resample = fused_resample + self.fused_conv_bias = fused_conv_bias + self.amp_mode = amp_mode init_kwargs = dict( mode=init_mode, fan_in=in_channels * kernel * kernel, @@ -176,13 +212,21 @@ def __init__( self.register_buffer("resample_filter", f if up or down else None) def forward(self, x): - w = self.weight.to(x.dtype) if self.weight is not None else None - b = self.bias.to(x.dtype) if self.bias is not None else None - f = ( - self.resample_filter.to(x.dtype) - if self.resample_filter is not None - else None - ) + weight, bias, resample_filter = self.weight, self.bias, self.resample_filter + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if ( + self.resample_filter is not None + and self.resample_filter.dtype != x.dtype + ): + resample_filter = self.resample_filter.to(x.dtype) + + w = weight if weight is not None else None + b = bias if bias is not None else None + f = resample_filter if resample_filter is not None else None w_pad = w.shape[-1] // 2 if w is not None else 0 f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 @@ -194,15 +238,29 @@ def forward(self, x): stride=2, padding=max(f_pad - w_pad, 0), ) - x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, w, padding=max(w_pad - f_pad, 0), bias=b + ) + else: + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) elif self.fused_resample and self.down and w is not None: x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) - x = torch.nn.functional.conv2d( - x, - f.tile([self.out_channels, 1, 1, 1]), - groups=self.out_channels, - stride=2, - ) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + bias=b, + ) + else: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) else: if self.up: x = torch.nn.functional.conv_transpose2d( @@ -220,11 +278,15 @@ def forward(self, x): stride=2, padding=f_pad, ) - if w is not None: - #TODO during inference, model breaks here for some reason - # current fix is to disable torch.backends.cudnn.enabled = False - x = torch.nn.functional.conv2d(x, w, padding=w_pad) - if b is not None: + + #TODO during inference, model breaks here for some reason + # current fix is to disable torch.backends.cudnn.enabled = False + if w is not None: # ask in corrdiff channel whether w will ever be none + if self.fused_conv_bias: + x = torch.nn.functional.conv2d(x, w, padding=w_pad, bias=b) + else: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None and not self.fused_conv_bias: x = x.add_(b.reshape(1, -1, 1, 1)) return x @@ -251,7 +313,15 @@ class GroupNorm(torch.nn.Module): eps : float, optional A small number added to the variance to prevent division by zero, by default 1e-5. - + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + fused_act : bool, optional + Whether to fuse the activation function with GroupNorm. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. Notes ----- If `num_channels` is not divisible by `num_groups`, the actual number of groups @@ -264,28 +334,71 @@ def __init__( num_groups: int = 32, min_channels_per_group: int = 4, eps: float = 1e-5, + use_apex_gn: bool = False, + fused_act: bool = False, + act: str = None, + amp_mode: bool = False, ): + if fused_act and act is None: + raise ValueError("'act' must be specified when 'fused_act' is set to True.") + super().__init__() self.num_groups = min(num_groups, num_channels // min_channels_per_group) self.eps = eps self.weight = torch.nn.Parameter(torch.ones(num_channels)) self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + if use_apex_gn and not _is_apex_available: + raise ValueError("'apex' is not installed, set `use_apex_gn=False`") + self.use_apex_gn = use_apex_gn + self.fused_act = fused_act + self.act = act.lower() if act else act + self.act_fn = None + self.amp_mode = amp_mode + if self.use_apex_gn: + if self.act: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + act=self.act, + ) + + else: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + ) + if self.fused_act: + self.act_fn = self.get_activation_function() def forward(self, x): - if self.training: + weight, bias = self.weight, self.bias + if not self.amp_mode: + if not self.use_apex_gn: + if weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if self.use_apex_gn: + x = self.gn(x) + elif self.training: # Use default torch implementation of GroupNorm for training # This does not support channels last memory format x = torch.nn.functional.group_norm( x, num_groups=self.num_groups, - weight=self.weight.to(x.dtype), - bias=self.bias.to(x.dtype), + weight=weight, + bias=bias, eps=self.eps, ) + if self.fused_act: + x = self.act_fn(x) else: # Use custom GroupNorm implementation that supports channels last # memory layout for inference - dtype = x.dtype x = x.float() x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) @@ -295,12 +408,33 @@ def forward(self, x): x = (x - mean) * (var + self.eps).rsqrt() x = rearrange(x, "b g c h w -> b (g c) h w") - weight = rearrange(self.weight, "c -> 1 c 1 1") - bias = rearrange(self.bias, "c -> 1 c 1 1") + weight = rearrange(weight, "c -> 1 c 1 1") + bias = rearrange(bias, "c -> 1 c 1 1") x = x * weight + bias - x = x.type(dtype) + if self.fused_act: + x = self.act_fn(x) return x + + def get_activation_function(self): + """ + Get activation function given string input + """ + + activation_map = { + "silu": silu, + "relu": relu, + "leaky_relu": leaky_relu, + "sigmoid": sigmoid, + "tanh": tanh, + "gelu": gelu, + "elu": elu, + } + + act_fn = activation_map.get(self.act, None) + if act_fn is None: + raise ValueError(f"Unknown activation function: {self.act}") + return act_fn class AttentionOp(torch.autograd.Function): @@ -333,6 +467,7 @@ def backward(ctx, dw): dim=2, input_dtype=torch.float32, ) + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( q.dtype ) / np.sqrt(k.shape[1]) @@ -385,6 +520,17 @@ class UNetBlock(torch.nn.Module): init_attn : dict, optional Initialization parameters specific to attention mechanism layers. Defaults to 'init' if not provided. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -406,6 +552,11 @@ def __init__( init: Dict[str, Any] = dict(), init_zero: Dict[str, Any] = dict(init_weight=0), init_attn: Any = None, + use_apex_gn: bool = False, + act: str = "silu", + fused_conv_bias: bool = False, + profile_mode: bool = False, + amp_mode: bool = False, ): super().__init__() @@ -423,7 +574,16 @@ def __init__( self.skip_scale = skip_scale self.adaptive_scale = adaptive_scale - self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.profile_mode = profile_mode + self.amp_mode = amp_mode + self.norm0 = GroupNorm( + num_channels=in_channels, + eps=eps, + use_apex_gn=use_apex_gn, + fused_act=True, + act=act, + amp_mode=amp_mode, + ) self.conv0 = Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -431,21 +591,45 @@ def __init__( up=up, down=down, resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init, ) self.affine = Linear( in_features=emb_channels, out_features=out_channels * (2 if adaptive_scale else 1), + amp_mode=amp_mode, **init, ) - self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + if self.adaptive_scale: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) + else: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + act=act, + fused_act=True, + amp_mode=amp_mode, + ) self.conv1 = Conv2d( - in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + in_channels=out_channels, + out_channels=out_channels, + kernel=3, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init_zero, ) self.skip = None if out_channels != in_channels or up or down: kernel = 1 if resample_proj or out_channels != in_channels else 0 + fused_conv_bias = fused_conv_bias if kernel != 0 else False self.skip = Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -453,55 +637,75 @@ def __init__( up=up, down=down, resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init, ) if self.num_heads: - self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.norm2 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) self.qkv = Conv2d( in_channels=out_channels, out_channels=out_channels * 3, kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **(init_attn if init_attn is not None else init), ) self.proj = Conv2d( in_channels=out_channels, out_channels=out_channels, kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init_zero, ) def forward(self, x, emb): - torch.cuda.nvtx.range_push("UNetBlock") - orig = x - x = self.conv0(silu(self.norm0(x))) - params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) - if self.adaptive_scale: - scale, shift = params.chunk(chunks=2, dim=1) - x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) - else: - x = silu(self.norm1(x.add_(params))) - - x = self.conv1( - torch.nn.functional.dropout(x, p=self.dropout, training=self.training) - ) - x = x.add_(self.skip(orig) if self.skip is not None else orig) - x = x * self.skip_scale - - if self.num_heads: - q, k, v = ( - self.qkv(self.norm2(x)) - .reshape( - x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 - ) - .unbind(2) + with nvtx.annotate( + message="UNetBlock", color="purple" + ) if self.profile_mode else contextlib.nullcontext(): + orig = x + x = self.conv0(self.norm0(x)) + params = self.affine(emb).unsqueeze(2).unsqueeze(3) + if not self.amp_mode: + if params.dtype != x.dtype: + params = params.to(x.dtype) + + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = self.norm1(x.add_(params)) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) ) - w = AttentionOp.apply(q, k) - a = torch.einsum("nqk,nck->ncq", w, v) - x = self.proj(a.reshape(*x.shape)).add_(x) + x = x.add_(self.skip(orig) if self.skip is not None else orig) x = x * self.skip_scale - torch.cuda.nvtx.range_pop() - return x + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(3) + ) + # w = AttentionOp.apply(q, k) + # a = torch.einsum("nqk,nck->ncq", w, v) + # Compute attention in one step + with amp.autocast(enabled=self.amp_mode): + attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.proj(attn.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + + return x class PositionalEmbedding(torch.nn.Module): @@ -517,16 +721,23 @@ class PositionalEmbedding(torch.nn.Module): Maximum number of positions for the embeddings, by default 10000. endpoint : bool, optional If True, the embedding considers the endpoint. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( - self, num_channels: int, max_positions: int = 10000, endpoint: bool = False + self, + num_channels: int, + max_positions: int = 10000, + endpoint: bool = False, + amp_mode: bool = False, ): super().__init__() self.num_channels = num_channels self.max_positions = max_positions self.endpoint = endpoint + self.amp_mode = amp_mode def forward(self, x): freqs = torch.arange( @@ -534,7 +745,10 @@ def forward(self, x): ) freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) freqs = (1 / self.max_positions) ** freqs - x = x.ger(freqs.to(x.dtype)) + if not self.amp_mode: + if freqs.dtype != x.dtype: + freqs = freqs.to(x.dtype) + x = x.ger(freqs) x = torch.cat([x.cos(), x.sin()], dim=1) return x @@ -556,13 +770,21 @@ class FourierEmbedding(torch.nn.Module): scale : int, optional A scale factor applied to the random frequencies, controlling their range and thereby the frequency of oscillations in the embedding space. By default 16. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ - def __init__(self, num_channels: int, scale: int = 16): + def __init__(self, num_channels: int, scale: int = 16, amp_mode: bool = False): super().__init__() self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + self.amp_mode = amp_mode def forward(self, x): - x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + freqs = self.freqs + if not self.amp_mode: + if x.dtype != self.freqs.dtype: + freqs = self.freqs.to(x.dtype) + + x = x.ger((2 * np.pi * freqs)) x = torch.cat([x.cos(), x.sin()], dim=1) return x diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index c66b6b6..74496a5 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -22,19 +22,12 @@ import importlib import warnings from dataclasses import dataclass -from typing import List, Union +from typing import List, Literal, Tuple, Union import numpy as np -import nvtx import torch import torch.nn as nn -from .song_unet import ( - SongUNet, # noqa: F401 for globals -) -from .dhariwal_unet import ( - DhariwalUNet, # noqa: F401 for globals -) from .meta import ModelMetaData network_module = importlib.import_module("hirad.models") @@ -694,12 +687,11 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]): """ return torch.as_tensor(sigma) - @dataclass -class EDMPrecondSRMetaData(ModelMetaData): +class EDMPrecondSuperResolutionMetaData(ModelMetaData): """EDMPrecondSR meta data""" - name: str = "EDMPrecondSR" + name: str = "EDMPrecondSuperResolution" # Optimization jit: bool = False cuda_graphs: bool = False @@ -715,33 +707,40 @@ class EDMPrecondSRMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecondSR(nn.Module): +class EDMPrecondSuperResolution(nn.Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) for super-resolution tasks + Diffusion-Based Generative Models" (EDM). + + This is a variant of `EDMPrecond` that is specifically designed for super-resolution + tasks. It wraps a neural network that predicts the denoised high-resolution image + given a noisy high-resolution image, and additional conditioning that includes a + low-resolution image, and a noise level. Parameters ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. + img_resolution : Union[int, Tuple[int, int]] + Spatial resolution `(H, W)` of the image. If a single int is provided, + the image is assumed to be square. img_in_channels : int - Number of input color channels. + Number of input channels in the low-resolution input image. img_out_channels : int - Number of output color channels. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float + Number of output channels in the high-resolution output image. + use_fp16 : bool, optional + Whether to use half-precision floating point (FP16) for model execution, + by default False. + model_type : str, optional + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. + sigma_data : float, optional + Expected standard deviation of the training data, by default 0.5. + sigma_min : float, optional Minimum supported noise level, by default 0.0. - sigma_max : float + sigma_max : float, optional Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "SongUNetPosEmbd". **model_kwargs : dict - Keyword arguments for the underlying model. + Keyword arguments passed to the underlying model `__init__` method. Note ---- @@ -757,28 +756,26 @@ class EDMPrecondSR(nn.Module): def __init__( self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + sigma_data: float = 0.5, sigma_min=0.0, sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - scale_cond_input=True, - **model_kwargs, + **model_kwargs: dict, ): super().__init__() #meta=EDMPrecondSRMetaData self.img_resolution = img_resolution - self.img_channels = img_channels # TODO: this is not used, remove it self.img_in_channels = img_in_channels self.img_out_channels = img_out_channels self.use_fp16 = use_fp16 + self.sigma_data = sigma_data self.sigma_min = sigma_min self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.scale_cond_input = scale_cond_input model_class = getattr(network_module, model_type) self.model = model_class( @@ -787,39 +784,73 @@ def __init__( out_channels=img_out_channels, **model_kwargs, ) # TODO needs better handling - self.scaling_fn = self._get_scaling_fn() - - def _get_scaling_fn(self): - if self.scale_cond_input: - warnings.warn( - "scale_cond_input=True does not properly scale the conditional input. " - "(see https://github.com/NVIDIA/modulus/issues/229). " - "This setup will be deprecated. " - "Please set scale_cond_input=False.", - DeprecationWarning, - ) - return self._legacy_scaling_fn - else: - return self._scaling_fn + self.scaling_fn = self._scaling_fn @staticmethod - def _scaling_fn(x, img_lr, c_in): - return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + def _scaling_fn( + x: torch.Tensor, img_lr: torch.Tensor, c_in: torch.Tensor + ) -> torch.Tensor: + """ + Scale input tensors by first scaling the high-resolution tensor and then + concatenating with the low-resolution tensor. - @staticmethod - def _legacy_scaling_fn(x, img_lr, c_in): - return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution image of shape (B, C_lr, H, W). + c_in : torch.Tensor + Scaling factor of shape (B, 1, 1, 1). + + Returns + ------- + torch.Tensor + Scaled and concatenated tensor of shape (B, C_in+C_out, H, W). + """ + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) - @nvtx.annotate(message="EDMPrecondSR", color="orange") def forward( self, - x, - img_lr, - sigma, - force_fp32=False, - **model_kwargs, - ): - # Concatenate input channels + x: torch.Tensor, + img_lr: torch.Tensor, + sigma: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the EDMPrecondSuperResolution model wrapper. + + This method applies the EDM preconditioning to compute the denoised image + from a noisy high-resolution image and low-resolution conditioning image. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). The number of + channels `C_hr` should be equal to `img_out_channels`. + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). The number + of channels `C_lr` should be equal to `img_in_channels`. + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ x = x.to(torch.float32) sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) dtype = ( @@ -855,13 +886,190 @@ def forward( return D_x @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): + def round_sigma(sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: """ Convert a given sigma value(s) to a tensor representation. - See EDMPrecond.round_sigma + + Parameters + ---------- + sigma : Union[float, List, torch.Tensor] + Sigma value(s) to convert. + + Returns + ------- + torch.Tensor + Tensor representation of sigma values. + + See Also + -------- + EDMPrecond.round_sigma """ return EDMPrecond.round_sigma(sigma) + @property + def amp_mode(self): + """ + Return the *amp_mode* flag of the wrapped model or *None*. + """ + return getattr(self.model, "amp_mode", None) + + @amp_mode.setter + def amp_mode(self, value: bool): + """ + Propagate *amp_mode* to the model and all its sub-modules. + """ + + if not isinstance(value, bool): + raise TypeError("amp_mode must be a boolean value.") + + if hasattr(self.model, "amp_mode"): + self.model.amp_mode = value + + for sub_module in self.model.modules(): + if hasattr(sub_module, "amp_mode"): + sub_module.amp_mode = value + +# NOTE: This is a deprecated version of the EDMPrecondSuperResolution model. +# This was used to maintain backwards compatibility and allow loading old models. +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(EDMPrecondSuperResolution): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, #deprecated + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, #deprecated + **model_kwargs, + ): + warnings.warn( + "EDMPrecondSR is deprecated and will be removed in a future version. " + "Please use EDMPrecondSuperResolution instead.", + DeprecationWarning, + stacklevel=2, + ) + + if scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + + super().__init__( + img_resolution=img_resolution, + img_in_channels=img_in_channels, + img_out_channels=img_out_channels, + use_fp16=use_fp16, + sigma_min=sigma_min, + sigma_max=sigma_max, + sigma_data=sigma_data, + model_type=model_type, + **model_kwargs, + ) + + # Store deprecated parameters for backward compatibility + self.img_channels = img_channels + self.scale_cond_input = scale_cond_input + + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + """ + Forward pass of the EDMPrecondSR model wrapper. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + """ + return super().forward( + x=x, img_lr=img_lr, sigma=sigma, force_fp32=force_fp32, **model_kwargs + ) class VEPrecond_dfsr(nn.Module): """ @@ -912,7 +1120,8 @@ def __init__( self.img_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 - self.model = globals()[model_type]( + model_class = getattr(network_module, model_type) + self.model = model_class( img_resolution=img_resolution, in_channels=self.img_channels, out_channels=img_channels, @@ -1011,7 +1220,8 @@ def __init__( self.img_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 - self.model = globals()[model_type]( + model_class = getattr(network_module, model_type) + self.model = model_class( img_resolution=img_resolution, in_channels=model_kwargs["model_channels"] * 2, out_channels=img_channels, diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index 6267dfc..a56f861 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -19,8 +19,9 @@ Diffusion-Based Generative Models". """ +import contextlib from dataclasses import dataclass -from typing import List, Union +from typing import Callable, List, Optional, Union import numpy as np import nvtx @@ -71,7 +72,8 @@ class SongUNet(nn.Module): Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -81,7 +83,7 @@ class SongUNet(nn.Module): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional Per-resolution multipliers for the number of channels. By default [1,2,2,2]. channel_mult_emb : int, optional @@ -93,29 +95,39 @@ class SongUNet(nn.Module): dropout : float, optional Dropout probability applied to intermediate activations. By default 0.10. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional - Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. , 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - checkpoint_level : int, optional (default=0) - How many layers should use gradient checkpointing, 0 is None - additive_pos_embed: bool = False, - Set to True to add a learned position embedding after the first conv (used in StormCast) + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. Reference ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B., 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456. @@ -156,6 +168,10 @@ def __init__( resample_filter: List[int] = [1, 1], checkpoint_level: int = 0, additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, ): valid_embedding_types = ["fourier", "positional", "zero"] if embedding_type not in valid_embedding_types: @@ -196,7 +212,14 @@ def __init__( init=init, init_zero=init_zero, init_attn=init_attn, + use_apex_gn=use_apex_gn, + act=act, + fused_conv_bias=True, + profile_mode=profile_mode, + amp_mode=amp_mode, ) + self.profile_mode = profile_mode + self.amp_mode = amp_mode # for compatibility with older versions that took only 1 dimension self.img_resolution = img_resolution @@ -220,12 +243,19 @@ def __init__( # Mapping. if self.embedding_type != "zero": self.map_noise = ( - PositionalEmbedding(num_channels=noise_channels, endpoint=True) + PositionalEmbedding( + num_channels=noise_channels, endpoint=True, amp_mode=amp_mode + ) if embedding_type == "positional" - else FourierEmbedding(num_channels=noise_channels) + else FourierEmbedding(num_channels=noise_channels, amp_mode=amp_mode) ) self.map_label = ( - Linear(in_features=label_dim, out_features=noise_channels, **init) + Linear( + in_features=label_dim, + out_features=noise_channels, + amp_mode=amp_mode, + **init, + ) if label_dim else None ) @@ -234,16 +264,23 @@ def __init__( in_features=augment_dim, out_features=noise_channels, bias=False, + amp_mode=amp_mode, **init, ) if augment_dim else None ) self.map_layer0 = Linear( - in_features=noise_channels, out_features=emb_channels, **init + in_features=noise_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, ) self.map_layer1 = Linear( - in_features=emb_channels, out_features=emb_channels, **init + in_features=emb_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, ) # Encoder. @@ -256,7 +293,12 @@ def __init__( cin = cout cout = model_channels self.enc[f"{res}x{res}_conv"] = Conv2d( - in_channels=cin, out_channels=cout, kernel=3, **init + in_channels=cin, + out_channels=cout, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, ) else: self.enc[f"{res}x{res}_down"] = UNetBlock( @@ -269,9 +311,15 @@ def __init__( kernel=0, down=True, resample_filter=resample_filter, + amp_mode=amp_mode, ) self.enc[f"{res}x{res}_aux_skip"] = Conv2d( - in_channels=caux, out_channels=cout, kernel=1, **init + in_channels=caux, + out_channels=cout, + kernel=1, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, ) if encoder_type == "residual": self.enc[f"{res}x{res}_aux_residual"] = Conv2d( @@ -281,6 +329,8 @@ def __init__( down=True, resample_filter=resample_filter, fused_resample=True, + fused_conv_bias=True, + amp_mode=amp_mode, **init, ) caux = cout @@ -325,107 +375,138 @@ def __init__( kernel=0, up=True, resample_filter=resample_filter, + amp_mode=amp_mode, ) self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( - num_channels=cout, eps=1e-6 + num_channels=cout, + eps=1e-6, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, ) self.dec[f"{res}x{res}_aux_conv"] = Conv2d( - in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + in_channels=cout, + out_channels=out_channels, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init_zero, ) - @nvtx.annotate(message="SongUNet", color="blue") def forward(self, x, noise_labels, class_labels, augment_labels=None): - if self.embedding_type != "zero": - # Mapping. - emb = self.map_noise(noise_labels) - emb = ( - emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) - ) # swap sin/cos - if self.map_label is not None: - tmp = class_labels - if self.training and self.label_dropout: - tmp = tmp * ( - torch.rand([x.shape[0], 1], device=x.device) - >= self.label_dropout - ).to(tmp.dtype) - emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) - if self.map_augment is not None and augment_labels is not None: - emb = emb + self.map_augment(augment_labels) - emb = silu(self.map_layer0(emb)) - emb = silu(self.map_layer1(emb)) - else: - emb = torch.zeros( - (noise_labels.shape[0], self.emb_channels), device=x.device - ) + with nvtx.annotate( + message="SongUNet", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if self.embedding_type != "zero": + # Mapping. + emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label( + tmp * np.sqrt(self.map_label.in_features) + ) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) - # Encoder. - skips = [] - aux = x - for name, block in self.enc.items(): - with nvtx.annotate(f"SongUNet encoder: {name}", color="blue"): - if "aux_down" in name: - aux = block(aux) - elif "aux_skip" in name: - x = skips[-1] = x + block(aux) - elif "aux_residual" in name: - x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) - elif "_conv" in name: - x = block(x) - if self.additive_pos_embed: - x = x + self.spatial_emb.to(dtype=x.dtype) - skips.append(x) - else: - # For UNetBlocks check if we should use gradient checkpointing - if isinstance(block, UNetBlock): - if x.shape[-1] > self.checkpoint_threshold: - x = checkpoint(block, x, emb, use_reentrant=False) - else: - x = block(x, emb) - else: + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + with nvtx.annotate( + f"SongUNet encoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: x = block(x) - skips.append(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + # For UNetBlocks check if we should use gradient checkpointing + if isinstance(block, UNetBlock): + if x.shape[-1] > self.checkpoint_threshold: + # self.checkpoint = checkpoint? + # else: self.checkpoint = lambda(block,x,emb:block(x,emb)) + x = checkpoint(block, x, emb, use_reentrant=False) + else: + # AssertionError: Only support NHWC layout. + x = block(x, emb) + else: + x = block(x) + skips.append(x) - # Decoder. - aux = None - tmp = None - for name, block in self.dec.items(): - with nvtx.annotate(f"SongUNet decoder: {name}", color="blue"): - if "aux_up" in name: - aux = block(aux) - elif "aux_norm" in name: - tmp = block(x) - elif "aux_conv" in name: - tmp = block(silu(tmp)) - aux = tmp if aux is None else tmp + aux - else: - if x.shape[1] != block.in_channels: - x = torch.cat([x, skips.pop()], dim=1) - # check for checkpointing on decoder blocks and up sampling blocks - if ( - x.shape[-1] > self.checkpoint_threshold and "_block" in name - ) or ( - x.shape[-1] > (self.checkpoint_threshold / 2) and "_up" in name - ): - x = checkpoint(block, x, emb, use_reentrant=False) + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + with nvtx.annotate( + f"SongUNet decoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux else: - x = block(x, emb) - return aux + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + # check for checkpointing on decoder blocks and up sampling blocks + if ( + x.shape[-1] > self.checkpoint_threshold and "_block" in name + ) or ( + x.shape[-1] > (self.checkpoint_threshold / 2) + and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + return aux class SongUNetPosEmbd(SongUNet): - """ - Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with - optional self-attention,embeddings, and encoder-decoder components. + """Extends SongUNet with positional embeddings. This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations. + This model adds positional embeddings to the base SongUNet architecture. The embeddings + can be selected using either a selector function or global indices, with the selector + approach being more computationally efficient. + + The model provides two methods for selecting positional embeddings: + + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. + Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -435,39 +516,63 @@ class SongUNetPosEmbd(SongUNet): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional - Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. channel_mult_emb : int, optional Multiplier for the dimensionality of the embedding vector. By default 4. num_blocks : int, optional Number of residual blocks per resolution. By default 4. attn_resolutions : List[int], optional - Resolutions at which self-attention layers are applied. By default [16]. + Resolutions at which self-attention layers are applied. By default [28]. dropout : float, optional Dropout probability applied to intermediate activations. By default 0.13. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. , 'skip' for skip connections. + By default'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - - - Reference - ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + lead_time_mode : bool, optional + A boolean flag indicating whether we are running SongUNet with lead time embedding. Defaults to False. + lead_time_channels : int, optional + Number of channels in the lead time embedding. These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. Note ----- @@ -476,13 +581,41 @@ class SongUNetPosEmbd(SongUNet): Example -------- - >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include both original input channels (2) + >>> # and the positional embedding channels (N_grid_channels=4 by default) + >>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2) >>> noise_labels = torch.randn([1]) >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings are + >>> # added automatically inside the forward method >>> input_image = torch.ones([1, 2, 16, 16]) >>> output_image = model(input_image, noise_labels, class_labels) >>> output_image.shape torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a global index to select all positional embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... global_index=global_index + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a custom embedding selector to select all positional embeddings + >>> def patch_embedding_selector(emb): + ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... embedding_selector=patch_embedding_selector + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) """ def __init__( @@ -507,6 +640,15 @@ def __init__( gridtype: str = "sinusoidal", N_grid_channels: int = 4, checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, + lead_time_mode: bool = False, + lead_time_channels: int = None, + lead_time_steps: int = 9, + prob_channels: List[int] = [], ): super().__init__( img_resolution, @@ -527,49 +669,286 @@ def __init__( decoder_type, resample_filter, checkpoint_level, + additive_pos_embed, + use_apex_gn, + act, + profile_mode, + amp_mode, ) self.gridtype = gridtype self.N_grid_channels = N_grid_channels - self.pos_embd = self._get_positional_embedding() + if self.gridtype == "learnable": + self.pos_embd = self._get_positional_embedding() + else: + self.register_buffer("pos_embd", self._get_positional_embedding().float()) + self.lead_time_mode = lead_time_mode + if self.lead_time_mode: + self.lead_time_channels = lead_time_channels + self.lead_time_steps = lead_time_steps + self.lt_embd = self._get_lead_time_embedding() + self.prob_channels = prob_channels + if self.prob_channels: + self.scalar = torch.nn.Parameter( + torch.ones((1, len(self.prob_channels), 1, 1)) + ) - @nvtx.annotate(message="SongUNet", color="blue") def forward( - self, x, noise_labels, class_labels, global_index=None, augment_labels=None + self, + x, + noise_labels, + class_labels, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, + augment_labels=None, + lead_time_label=None, ): - # append positional embedding to input conditioning - if self.pos_embd is not None: - selected_pos_embd = self.positional_embedding_indexing(x, global_index) - x = torch.cat((x, selected_pos_embd), dim=1) + with nvtx.annotate( + message="SongUNetPosEmbd", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if embedding_selector is not None and global_index is not None: + raise ValueError( + "Cannot provide both embedding_selector and global_index. " + "embedding_selector is the preferred approach for better efficiency." + ) + + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + # Append positional embedding to input conditioning + if self.pos_embd is not None: + # Select positional embeddings with a selector function + if embedding_selector is not None: + selected_pos_embd = self.positional_embedding_selector( + x, embedding_selector + ) + # Select positional embeddings using global indices (selects all + # embeddings if global_index is None) + else: + selected_pos_embd = self.positional_embedding_indexing( + x, global_index=global_index, lead_time_label=lead_time_label + ) + x = torch.cat((x, selected_pos_embd), dim=1) + + out = super().forward(x, noise_labels, class_labels, augment_labels) + + if self.lead_time_mode: + # if training mode, let crossEntropyLoss do softmax. The model outputs logits. + # if eval mode, the model outputs probability + all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + if self.prob_channels and (not self.training): + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar).softmax(dim=1), + ), + dim=1, + ) + elif self.prob_channels and self.training: + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar), + ), + dim=1, + ) + else: + out_final = out + return out_final + + return out + + def positional_embedding_indexing( + self, + x: torch.Tensor, + global_index: Optional[torch.Tensor] = None, + lead_time_label=None, + ) -> torch.Tensor: + """Select positional embeddings using global indices. - return super().forward(x, noise_labels, class_labels, augment_labels) + This method either uses global indices to select specific embeddings or expands + the embeddings for the full input when no indices are provided. + + Typically used in patch-based training, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W), used to determine batch size + and device. + global_index : Optional[torch.Tensor] + Optional tensor of indices for selecting embeddings. These should + correspond to the spatial indices of the batch elements in the + input tensor x. When provided, should have shape (P, 2, H, W) where + the second dimension contains y,x coordinates (indices of the + positional embedding grid). + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape: + - If global_index provided: (B, N_pe, H, W) + - If global_index is None: (B, N_pe, H_pe, W_pe) + where N_pe is the number of positional embedding channels, and H_pe + and W_pe are the height and width of the positional embedding grid. + + Example + ------- + >>> # Create global indices using patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> global_index = patching.global_index(batch_size=3) + >>> print(global_index.shape) + torch.Size([4, 2, 8, 8]) + + See Also + -------- + :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index` + For generating random patch indices. + :meth:`physicsnemo.utils.patching.GridPatching2D.global_index` + For generating deterministic grid-based patch indices. + See these methods for possible ways to generate the global_index parameter. + """ + # If no global indices are provided, select all embeddings and expand + # to match the batch size of the input + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) - def positional_embedding_indexing(self, x, global_index): if global_index is None: - selected_pos_embd = ( - self.pos_embd.to(x.dtype) - .to(x.device)[None] - .expand((x.shape[0], -1, -1, -1)) - ) + if self.lead_time_mode: + selected_pos_embd = [] + if self.pos_embd is not None: + selected_pos_embd.append( + self.pos_embd[None].expand((x.shape[0], -1, -1, -1)) + ) + if self.lt_embd is not None: + selected_pos_embd.append( + torch.reshape( + self.lt_embd[lead_time_label.int()], + ( + x.shape[0], + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ), + ) + ) + if len(selected_pos_embd) > 0: + selected_pos_embd = torch.cat(selected_pos_embd, dim=1) + else: + selected_pos_embd = self.pos_embd[None].expand( + (x.shape[0], -1, -1, -1) + ) # (B, N_pe, H, W) + else: - B = global_index.shape[0] - X = global_index.shape[2] - Y = global_index.shape[3] + P = global_index.shape[0] + B = x.shape[0] // P + H = global_index.shape[2] + W = global_index.shape[3] + global_index = torch.reshape( torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, X, Y) to (2, B*X*Y) - selected_pos_embd = self.pos_embd.to(x.device)[ + ) # (P, 2, X, Y) to (2, P*X*Y) + selected_pos_embd = self.pos_embd[ :, global_index[0], global_index[1] - ] # (N_pe, B*X*Y) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, X, Y)), - (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, X, Y) + ] # (N_pe, P*X*Y) + selected_pos_embd = torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)), + (1, 0, 2, 3), + ) # (P, N_pe, X, Y) + + selected_pos_embd = selected_pos_embd.repeat( + B, 1, 1, 1 + ) # (B*P, N_pe, X, Y) + + # Append positional and lead time embeddings to input conditioning + if self.lead_time_mode: + embeds = [] + if self.pos_embd is not None: + embeds.append(selected_pos_embd) # reuse code below + if self.lt_embd is not None: + lt_embds = self.lt_embd[ + lead_time_label.int() + ] # (B, self.lead_time_channels, self.img_shape_y, self.img_shape_x), + + selected_lt_pos_embd = lt_embds[ + :, :, global_index[0], global_index[1] + ] # (B, N_lt, P*X*Y) + selected_lt_pos_embd = torch.reshape( + torch.permute( + torch.reshape( + selected_lt_pos_embd, + (B, self.lead_time_channels, P, H, W), + ), + (0, 2, 1, 3, 4), + ).contiguous(), + (B * P, self.lead_time_channels, H, W), + ) # (B*P, N_pe, X, Y) + embeds.append(selected_lt_pos_embd) + + if len(embeds) > 0: + selected_pos_embd = torch.cat(embeds, dim=1) + return selected_pos_embd + + def positional_embedding_selector( + self, + x: torch.Tensor, + embedding_selector: Callable[[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Select positional embeddings using a selector function. + + Similar to positional_embedding_indexing, but uses a selector function + to select the embeddings. This method provides a more efficient way to + select embeddings for batches of data. + Typically used with patch-based processing, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W) only used to determine dtype and + device. + embedding_selector : Callable + Function that takes as input an embedding tensor of shape (N_pe, + H_pe, W_pe) and returns selected embeddings with shape (batch_size, N_pe, H, W). + Each selected embedding should correspond to the positional + information of each batch element in x. + For patch-based processing, typically this should be based on + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to + maintain consistency with patch extraction. + embeds : Optional[torch.Tensor] + Optional tensor for combined positional and lead time embeddings tensor + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape (B, N_pe, H, W) + where N_pe is the number of positional embedding channels. + + Example + ------- + >>> # Define a selector function with a patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> batch_size = 4 + >>> def embedding_selector(emb): + ... return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + >>> + + See Also + -------- + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` + For the base patching method typically used in embedding_selector. + """ + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + return embedding_selector(self.pos_embd) # (B, N_pe, H, W) def _get_positional_embedding(self): if self.N_grid_channels == 0: @@ -577,14 +956,16 @@ def _get_positional_embedding(self): elif self.gridtype == "learnable": grid = torch.nn.Parameter( torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) - ) + ) # (N_grid_channels, img_shape_y, img_shape_x) elif self.gridtype == "linear": if self.N_grid_channels != 2: raise ValueError("N_grid_channels must be set to 2 for gridtype linear") x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) grid_x, grid_y = np.meshgrid(y, x) - grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid = torch.from_numpy( + np.stack((grid_x, grid_y), axis=0) + ) # (2, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: # print('sinusuidal grid added ......') @@ -600,7 +981,7 @@ def _get_positional_embedding(self): np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 ) ) - ) + ) # (4, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: if self.N_grid_channels % 4 != 0: @@ -616,28 +997,50 @@ def _get_positional_embedding(self): for p_fn in [np.sin, np.cos]: grid_list.append(p_fn(grid_x * freq)) grid_list.append(p_fn(grid_y * freq)) - grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid = torch.from_numpy( + np.stack(grid_list, axis=0) + ) # (N_grid_channels, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "test" and self.N_grid_channels == 2: idx_x = torch.arange(self.img_shape_y) idx_y = torch.arange(self.img_shape_x) mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - grid = torch.stack((mesh_x, mesh_y), dim=0) + grid = torch.stack((mesh_x, mesh_y), dim=0) # (2, img_shape_y, img_shape_x) else: raise ValueError("Gridtype not supported.") return grid + + def _get_lead_time_embedding(self): + if (self.lead_time_steps is None) or (self.lead_time_channels is None): + return None + grid = torch.nn.Parameter( + torch.randn( + self.lead_time_steps, + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ) + ) # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x) + return grid -class SongUNetPosLtEmbd(SongUNet): +class SongUNetPosLtEmbd(SongUNetPosEmbd): """ - This model is adapated from SongUNetPosEmbd, with the incoporatation of lead-time aware - embedding for the GEFS-HRRR model. The lead-time embedding is activated by setting the - lead_time_channels and lead_time_steps parameters. + This model is adapted from SongUNetPosEmbd, with the incorporation of lead-time aware + embeddings. The lead-time embedding is activated by setting the + `lead_time_channels` and `lead_time_steps` parameters. + + Like SongUNetPosEmbd, this model provides two methods for selecting positional embeddings: + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -647,44 +1050,63 @@ class SongUNetPosLtEmbd(SongUNet): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional - Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. channel_mult_emb : int, optional Multiplier for the dimensionality of the embedding vector. By default 4. num_blocks : int, optional Number of residual blocks per resolution. By default 4. attn_resolutions : List[int], optional - Resolutions at which self-attention layers are applied. By default [16]. + Resolutions at which self-attention layers are applied. By default [28]. dropout : float, optional Dropout probability applied to intermediate activations. By default 0.13. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - lead_time_channels: int, optional - Length of lead time embedding vector - lead_time_steps: int, optional - Total number of lead times + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + lead_time_channels : int, optional + Number of channels in the lead time embedding. These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. - Reference - ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - Note ----- Equivalent to the original implementation by Song et al., available at @@ -692,13 +1114,54 @@ class SongUNetPosLtEmbd(SongUNet): Example -------- - >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include original input channels (2), + >>> # positional embedding channels (N_grid_channels=4 by default) and + >>> # lead time embedding channels (4) + >>> model = SongUNetPosLtEmbd( + ... img_resolution=16, in_channels=2+4+4, out_channels=2, + ... lead_time_channels=4, lead_time_steps=9 + ... ) >>> noise_labels = torch.randn([1]) >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings and + >>> # lead time embeddings are added automatically inside the forward method >>> input_image = torch.ones([1, 2, 16, 16]) - >>> output_image = model(input_image, noise_labels, class_labels) + >>> lead_time_label = torch.tensor([3]) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using global_index to select all the positional and lead time embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label, + ... global_index=global_index + ... ) >>> output_image.shape torch.Size([1, 2, 16, 16]) + + # NOTE: commented out doctest for embedding_selector due to compatibility issue + # >>> + # >>> # Using custom embedding selector to select all the positional and lead time embeddings + # >>> def patch_embedding_selector(emb): + # ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + # >>> output_image = model( + # ... input_image, noise_labels, class_labels, + # ... lead_time_label=lead_time_label, + # ... embedding_selector=patch_embedding_selector + # ... ) + # >>> output_image.shape + # torch.Size([1, 2, 16, 16]) + """ def __init__( @@ -726,6 +1189,11 @@ def __init__( lead_time_steps: int = 9, prob_channels: List[int] = [], checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, ): super().__init__( img_resolution, @@ -745,162 +1213,38 @@ def __init__( encoder_type, decoder_type, resample_filter, + gridtype, + N_grid_channels, checkpoint_level, + additive_pos_embed, + use_apex_gn, + act, + profile_mode, + amp_mode, + True, # Note: lead_time_mode=True is enforced here + lead_time_channels, + lead_time_steps, + prob_channels, ) - self.gridtype = gridtype - self.N_grid_channels = N_grid_channels - self.pos_embd = self._get_positional_embedding() - self.lead_time_channels = lead_time_channels - self.lead_time_steps = lead_time_steps - self.lt_embd = self._get_lead_time_embedding() - self.prob_channels = prob_channels - if self.prob_channels: - self.scalar = torch.nn.Parameter( - torch.ones((1, len(self.prob_channels), 1, 1)) - ) - - @nvtx.annotate(message="SongUNet", color="blue") def forward( self, x, noise_labels, class_labels, lead_time_label=None, - global_index=None, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, augment_labels=None, ): - # append positional embedding to input conditioning - embeds = [] - if self.pos_embd is not None: - embeds.append(self.pos_embd.to(x.device)) - if self.lt_embd is not None: - embeds.append( - torch.reshape( - self.lt_embd[lead_time_label.int()], - (self.lead_time_channels, self.img_shape_y, self.img_shape_x), - ).to(x.device) - ) - if len(embeds) > 0: - embeds = torch.cat(embeds, dim=0) - selected_pos_embd = self.positional_embedding_indexing( - x, embeds, global_index - ) - x = torch.cat((x, selected_pos_embd), dim=1) - out = super().forward(x, noise_labels, class_labels, augment_labels) - # if training mode, let crossEntropyLoss do softmax. The model outputs logits. - # if eval mode, the model outputs probability - all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] - scalar_channels = [ - item for item in all_channels if item not in self.prob_channels - ] - if self.prob_channels and (not self.training): - out_final = torch.cat( - ( - out[:, scalar_channels], - (out[:, self.prob_channels] * self.scalar).softmax(dim=1), - ), - dim=1, - ) - elif self.prob_channels and self.training: - out_final = torch.cat( - (out[:, scalar_channels], (out[:, self.prob_channels] * self.scalar)), - dim=1, - ) - else: - out_final = out - return out_final - - def positional_embedding_indexing(self, x, pos_embd, global_index): - if global_index is None: - selected_pos_embd = ( - pos_embd.to(x.dtype).to(x.device)[None].expand((x.shape[0], -1, -1, -1)) - ) - else: - B = global_index.shape[0] - X = global_index.shape[2] - Y = global_index.shape[3] - global_index = torch.reshape( - torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, X, Y) to (2, B*X*Y) - selected_pos_embd = pos_embd.to(x.device)[ - :, global_index[0], global_index[1] - ] # (N_pe, B*X*Y) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (pos_embd.shape[0], B, X, Y)), - (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, X, Y) - return selected_pos_embd - - def _get_positional_embedding(self): - if self.N_grid_channels == 0: - return None - elif self.gridtype == "learnable": - grid = torch.nn.Parameter( - torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) - ) - elif self.gridtype == "linear": - if self.N_grid_channels != 2: - raise ValueError("N_grid_channels must be set to 2 for gridtype linear") - x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) - y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) - grid_x, grid_y = np.meshgrid(y, x) - grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) - grid.requires_grad = False - elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: - # print('sinusuidal grid added ......') - x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) - x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) - y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) - y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) - grid_x1, grid_y1 = np.meshgrid(y1, x1) - grid_x2, grid_y2 = np.meshgrid(y2, x2) - grid = torch.squeeze( - torch.from_numpy( - np.expand_dims( - np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 - ) - ) - ) - grid.requires_grad = False - elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: - if self.N_grid_channels % 4 != 0: - raise ValueError("N_grid_channels must be a factor of 4") - num_freq = self.N_grid_channels // 4 - freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) - grid_list = [] - grid_x, grid_y = np.meshgrid( - np.linspace(0, 2 * np.pi, self.img_shape_x), - np.linspace(0, 2 * np.pi, self.img_shape_y), - ) - for freq in freq_bands: - for p_fn in [np.sin, np.cos]: - grid_list.append(p_fn(grid_x * freq)) - grid_list.append(p_fn(grid_y * freq)) - grid = torch.from_numpy(np.stack(grid_list, axis=0)) - grid.requires_grad = False - elif self.gridtype == "test" and self.N_grid_channels == 2: - idx_x = torch.arange(self.img_shape_y) - idx_y = torch.arange(self.img_shape_x) - mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - grid = torch.stack((mesh_x, mesh_y), dim=0) - else: - raise ValueError("Gridtype not supported.") - return grid - - def _get_lead_time_embedding(self): - if (self.lead_time_steps is None) or (self.lead_time_channels is None): - return None - grid = torch.nn.Parameter( - torch.randn( - self.lead_time_steps, - self.lead_time_channels, - self.img_shape_y, - self.img_shape_x, - ) + return super().forward( + x=x, + noise_labels=noise_labels, + class_labels=class_labels, + global_index=global_index, + embedding_selector=embedding_selector, + augment_labels=augment_labels, + lead_time_label=lead_time_label, ) - return grid + + # Nothing else is re-implemented, because everything is already in the parent SongUNetPosEmb \ No newline at end of file diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index 10079ec..e0a447a 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -16,6 +16,7 @@ import importlib from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Tuple, Union import torch import torch.nn as nn @@ -45,31 +46,35 @@ class MetaData(ModelMetaData): class UNet(nn.Module): # TODO a lot of redundancy, need to clean up """ - U-Net Wrapper for CorrDiff. + U-Net Wrapper for CorrDiff deterministic regression model. Parameters ----------- - img_resolution : int - The resolution of the input/output image. - img_channels : int - Number of color channels. + img_resolution : Union[int, Tuple[int, int]] + The resolution of the input/output image. If a single int is provided, + then the image is assumed to be square. img_in_channels : int - Number of input color channels. + Number of channels in the input image. img_out_channels : int - Number of output color channels. + Number of channels in the output image. use_fp16: bool, optional - Execute the underlying model at FP16 precision?, by default False. - sigma_min: float, optional - Minimum supported noise level, by default 0. - sigma_max: float, optional - Maximum supported noise level, by default float('inf'). - sigma_data: float, optional - Expected standard deviation of the training data, by default 0.5. + Execute the underlying model at FP16 precision, by default False. model_type: str, optional - Class name of the underlying model, by default 'DhariwalUNet'. + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. **model_kwargs : dict - Keyword arguments for the underlying model. + Keyword arguments passed to the underlying model `__init__` method. + + See Also + -------- + For information on model types and their usage: + :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models + :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + Please refer to the documentation of these classes for details on how to call + and use these models directly. References ---------- @@ -79,37 +84,66 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up arXiv preprint arXiv:2309.15214. """ + @classmethod + def _backward_compat_arg_mapper( + cls, version: str, args: Dict[str, Any] + ) -> Dict[str, Any]: + """Map arguments from older versions to current version format. + + Parameters + ---------- + version : str + Version of the checkpoint being loaded + args : Dict[str, Any] + Arguments dictionary from the checkpoint + + Returns + ------- + Dict[str, Any] + Updated arguments dictionary compatible with current version + """ + # Call parent class method first + args = super()._backward_compat_arg_mapper(version, args) + + if version == "0.1.0": + # In version 0.1.0, img_channels was unused + if "img_channels" in args: + _ = args.pop("img_channels") + + # Sigma parameters are also unused + if "sigma_min" in args: + _ = args.pop("sigma_min") + if "sigma_max" in args: + _ = args.pop("sigma_max") + if "sigma_data" in args: + _ = args.pop("sigma_data") + + return args + def __init__( self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, - sigma_min=0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - **model_kwargs, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + **model_kwargs: dict, ): super().__init__() #meta=MetaData - self.img_channels = img_channels - # for compatibility with older versions that took only 1 dimension if isinstance(img_resolution, int): self.img_shape_x = self.img_shape_y = img_resolution else: - self.img_shape_x = img_resolution[0] - self.img_shape_y = img_resolution[1] + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] self.img_in_channels = img_in_channels self.img_out_channels = img_out_channels self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data model_class = getattr(network_module, model_type) self.model = model_class( img_resolution=img_resolution, @@ -118,13 +152,47 @@ def __init__( **model_kwargs, ) - def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): + def forward( + self, + x: torch.Tensor, + img_lr: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the UNet wrapper model. + + This method concatenates the input tensor with the low-resolution conditioning tensor + and passes the result through the underlying model. + + Parameters + ---------- + x : torch.Tensor + The input tensor, typically zero-filled, of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Output tensor (prediction) of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ + # SR: concatenate input channels if img_lr is not None: x = torch.cat((x, img_lr), dim=1) - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) dtype = ( torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") @@ -133,29 +201,27 @@ def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): F_x = self.model( x.to(dtype), # (c_in * x).to(dtype), - torch.zeros( - sigma.numel(), dtype=sigma.dtype, device=sigma.device - ), # c_noise.flatten() + torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten() class_labels=None, **model_kwargs, ) if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead." ) - # skip connection - for SR there's size mismatch bwtween input and output + # skip connection D_x = F_x.to(torch.float32) return D_x - def round_sigma(self, sigma): + def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: """ Convert a given sigma value(s) to a tensor representation. Parameters ---------- - sigma : Union[float list, torch.Tensor] + sigma : Union[float, List, torch.Tensor] The sigma value(s) to convert. Returns @@ -164,8 +230,31 @@ def round_sigma(self, sigma): The tensor representation of the provided sigma value(s). """ return torch.as_tensor(sigma) + + @property + def amp_mode(self): + """ + Return the *amp_mode* flag of the underlying model if present. + """ + return getattr(self.model, "amp_mode", None) + + @amp_mode.setter + def amp_mode(self, value: bool): + """ + Update *amp_mode* on the wrapped model and its sub-modules. + """ + if not isinstance(value, bool): + raise TypeError("amp_mode must be a boolean value.") + + if hasattr(self.model, "amp_mode"): + self.model.amp_mode = value + # Recursively update sub-modules that define *amp_mode*. + for sub_module in self.model.modules(): + if hasattr(sub_module, "amp_mode"): + sub_module.amp_mode = value +# TODO: implement amp_mode property for StormCastUNet (same as UNet) class StormCastUNet(nn.Module): """ U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. @@ -189,7 +278,7 @@ class StormCastUNet(nn.Module): sigma_data: float, optional Expected standard deviation of the training data, by default 0.5. model_type: str, optional - Class name of the underlying model, by default 'DhariwalUNet'. + Class name of the underlying model, by default 'SongUNet'. **model_kwargs : dict Keyword arguments for the underlying model. diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 559d800..794dd55 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -5,6 +5,8 @@ import hydra from omegaconf import DictConfig, OmegaConf import json +from contextlib import nullcontext +import nvtx import torch from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter @@ -17,12 +19,44 @@ set_patch_shape, compute_num_accumulation_rounds, \ is_time_for_periodic_task, handle_and_clip_gradients from hirad.utils.checkpoint import load_checkpoint, save_checkpoint -from hirad.models import UNet, EDMPrecondSR -from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE +from hirad.utils.patching import RandomPatching2D +from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR +from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE from hirad.datasets import init_train_valid_datasets_from_config from matplotlib import pyplot as plt +torch._dynamo.reset() +# Increase the cache size limit +torch._dynamo.config.cache_size_limit = 264 # Set to a higher value +torch._dynamo.config.verbose = True # Enable verbose logging +torch._dynamo.config.suppress_errors = False # Forces the error to show all details +torch._logging.set_logs(recompiles=True, graph_breaks=True) + +# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available +def cuda_profiler(): + if torch.cuda.is_available(): + return torch.cuda.profiler.profile() + else: + return nullcontext() + + +def cuda_profiler_start(): + if torch.cuda.is_available(): + torch.cuda.profiler.start() + + +def cuda_profiler_stop(): + if torch.cuda.is_available(): + torch.cuda.profiler.stop() + + +def profiler_emit_nvtx(): + if torch.cuda.is_available(): + return torch.autograd.profiler.emit_nvtx() + else: + return nullcontext() + @hydra.main(version_base=None, config_path="../conf", config_name="training") def main(cfg: DictConfig) -> None: # Initialize distributed environment for training @@ -63,7 +97,7 @@ def main(cfg: DictConfig) -> None: data_loader_kwargs = { "pin_memory": True, "num_workers": cfg.training.perf.dataloader_workers, - "prefetch_factor": 2, + "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None, } ( dataset, @@ -104,80 +138,64 @@ def main(cfg: DictConfig) -> None: else: patch_shape_x = None patch_shape_y = None + if ( + patch_shape_x + and patch_shape_y + and patch_shape_y >= img_shape[0] + and patch_shape_x >= img_shape[1] + ): + logger0.warning( + f"Patch shape {patch_shape_y}x{patch_shape_x} is larger than \ + the image shape {img_shape[0]}x{img_shape[1]}. Patching will not be used." + ) patch_shape = (patch_shape_y, patch_shape_x) - img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if patch_shape != img_shape: + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + # Utility to perform patches extraction and batching + patching = RandomPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + patch_num=getattr(cfg.training.hp, "patch_num", 1), + ) logger0.info("Patch-based training enabled") else: + patching = None logger0.info("Patch-based training disabled") # interpolate global channel if patch-based model is used - if img_shape[1] != patch_shape[1]: + if use_patching: img_in_channels += dataset_channels # Instantiate the model and move to device. - if cfg.model.name not in ( - "regression", - "lt_aware_ce_regression", - "diffusion", - "patched_diffusion", - "lt_aware_patched_diffusion", - ): - raise ValueError("Invalid model") model_args = { # default parameters for all networks "img_out_channels": img_out_channels, "img_resolution": list(img_shape), "use_fp16": fp16, + "checkpoint_level": songunet_checkpoint_level, } - standard_model_cfgs = { # default parameters for different network types - "regression": { - "img_channels": 4, - "N_grid_channels": 4, - "embedding_type": "zero", - "checkpoint_level": songunet_checkpoint_level, - }, - "lt_aware_ce_regression": { - "img_channels": 4, - "N_grid_channels": 4, - "embedding_type": "zero", - "lead_time_channels": 4, - "lead_time_steps": 9, - "prob_channels": prob_channels, - "checkpoint_level": songunet_checkpoint_level, - "model_type": "SongUNetPosLtEmbd", - }, - "diffusion": { - "img_channels": img_out_channels, - "gridtype": "sinusoidal", - "N_grid_channels": 4, - "checkpoint_level": songunet_checkpoint_level, - }, - "patched_diffusion": { - "img_channels": img_out_channels, - "gridtype": "learnable", - "N_grid_channels": 100, - "checkpoint_level": songunet_checkpoint_level, - }, - "lt_aware_patched_diffusion": { - "img_channels": img_out_channels, - "gridtype": "learnable", - "N_grid_channels": 100, - "lead_time_channels": 20, - "lead_time_steps": 9, - "checkpoint_level": songunet_checkpoint_level, - "model_type": "SongUNetPosLtEmbd", - }, - } - - - model_args.update(standard_model_cfgs[cfg.model.name]) - if cfg.model.name in ( - "diffusion", - "patched_diffusion", - "lt_aware_patched_diffusion", - ): - model_args["scale_cond_input"] = cfg.model.scale_cond_input + if cfg.model.name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels + if hasattr(cfg.model, "model_args"): # override defaults from config file model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + use_torch_compile = False + use_apex_gn = False + profile_mode = False + + if hasattr(cfg.training.perf, "torch_compile"): + use_torch_compile = cfg.training.perf.torch_compile + if hasattr(cfg.training.perf, "use_apex_gn"): + use_apex_gn = cfg.training.perf.use_apex_gn + model_args["use_apex_gn"] = use_apex_gn + + if hasattr(cfg.training.perf, "profile_mode"): + profile_mode = cfg.training.perf.profile_mode + model_args["profile_mode"] = profile_mode + + if enable_amp: + model_args["amp_mode"] = enable_amp + + if cfg.model.name == "regression": model = UNet( img_in_channels=img_in_channels + model_args["N_grid_channels"], @@ -193,7 +211,7 @@ def main(cfg: DictConfig) -> None: ) model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] elif cfg.model.name == "lt_aware_patched_diffusion": - model = EDMPrecondSR( + model = EDMPrecondSuperResolution( img_in_channels=img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"], @@ -201,7 +219,7 @@ def main(cfg: DictConfig) -> None: ) model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] else: # diffusion or patched diffusion - model = EDMPrecondSR( + model = EDMPrecondSuperResolution( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) @@ -209,6 +227,18 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) + # param_to_name = {} + # ppp = False + # for name, param in model.named_parameters(): + # pid = id(param) + # if pid in param_to_name: + # print(f"[SHARED PARAM] {name} == {param_to_name[pid]}") + # ppp = True + # break + # else: + # param_to_name[pid] = name + # print(f'There are shared parameters: {ppp}') + # TODO write summry from rank=0 possibly # summary(model, input_size=[(1,img_out_channels,*img_shape),(1,img_in_channels,*img_shape),(1,1)]) @@ -216,6 +246,18 @@ def main(cfg: DictConfig) -> None: with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: json.dump(model_args, f) + if use_apex_gn: + model.to(memory_format=torch.channels_last) + + # Check if regression model is used with patching + if ( + cfg.model.name in ["regression", "lt_aware_ce_regression"] + and patching is not None + ): + raise ValueError( + f"Regression model ({cfg.model.name}) cannot be used with patch-based training. " + ) + # Enable distributed data parallel if applicable if dist.world_size > 1: model = DistributedDataParallel( @@ -223,7 +265,9 @@ def main(cfg: DictConfig) -> None: device_ids=[dist.local_rank], broadcast_buffers=True, output_device=dist.device, - find_unused_parameters=dist.find_unused_parameters, + find_unused_parameters=True, # dist.find_unused_parameters, + bucket_cap_mb=35, + gradient_as_bucket_view=True, ) # Load the regression checkpoint if applicable #TODO test when training correction @@ -245,6 +289,12 @@ def main(cfg: DictConfig) -> None: with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) + regression_model_args.update({ + "use_apex_gn": use_apex_gn, + "profile_mode": profile_mode, + "amp_mode": enable_amp, + }) + regression_net = UNet(**regression_model_args) _ = load_checkpoint( @@ -253,22 +303,81 @@ def main(cfg: DictConfig) -> None: device=dist.device ) regression_net.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + regression_net.to(memory_format=torch.channels_last) logger0.success("Loaded the pre-trained regression model") + else: + regression_net = None + + # Compile the model and regression net if applicable + if use_torch_compile: + model = torch.compile(model) + if regression_net: + regression_net = torch.compile(regression_net) + + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - # Instantiate the loss function patch_num = getattr(cfg.training.hp, "patch_num", 1) + max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) + + # calculate patch per iter + if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) # Ensure at least 1 patch per iter + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + print( + f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" + ) + else: + patch_nums_iter = [patch_num] + + # Set patch gradient accumulation only for patched diffusion models + if cfg.model.name in { + "patched_diffusion", + "lt_aware_patched_diffusion", + }: + if len(patch_nums_iter) > 1: + if not patching: + logger0.info( + "Patching is not enabled: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = False + else: + use_patch_grad_acc = True + else: + use_patch_grad_acc = False + # Automatically disable patch gradient accumulation for non-patched models + else: + logger0.info( + "Training a non-patched model: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = None + + + # Instantiate the loss function if cfg.model.name in ( "diffusion", "patched_diffusion", "lt_aware_patched_diffusion", ): - loss_fn = ResLoss( + loss_fn = ResidualLoss( regression_net=regression_net, - img_shape_x=img_shape[1], - img_shape_y=img_shape[0], - patch_shape_x=patch_shape[1], - patch_shape_y=patch_shape[0], - patch_num=patch_num, hr_mean_conditioning=cfg.model.hr_mean_conditioning, ) elif cfg.model.name == "regression": @@ -278,23 +387,17 @@ def main(cfg: DictConfig) -> None: # Instantiate the optimizer optimizer = torch.optim.Adam( - params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8 + params=model.parameters(), + lr=cfg.training.hp.lr, + betas=[0.9, 0.999], + eps=1e-8, + fused=True, ) # Record the current time to measure the duration of subsequent operations. start_time = time.time() - # Compute the number of required gradient accumulation rounds - # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size - batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( - cfg.training.hp.total_batch_size, - cfg.training.hp.batch_size_per_gpu, - dist.world_size, - ) - batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu - logger0.info(f"Using {num_accumulation_rounds} gradient accumulation {"rounds" if num_accumulation_rounds>1 else "round"}.") - logger0.info(f"Batch size per gpu: {batch_size_per_gpu}") - ## Resume training from previous checkpoints if exists + # Load optimizer checkpoint if it exists if dist.world_size > 1: torch.distributed.barrier() try: @@ -317,188 +420,308 @@ def main(cfg: DictConfig) -> None: # init variables to monitor running mean of average loss since last periodic average_loss_running_mean = 0 n_average_loss_running_mean = 1 - - while not done: - tick_start_nimg = cur_nimg - tick_start_time = time.time() - # Compute & accumulate gradients - optimizer.zero_grad(set_to_none=True) - loss_accum = 0 - for _ in range(num_accumulation_rounds): - img_clean, img_lr, labels, *lead_time_label = next(dataset_iterator) # what are labels and lead_time_label - img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() - img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() - labels = labels.to(dist.device).contiguous() - loss_fn_kwargs = { - "net": model, - "img_clean": img_clean, - "img_lr": img_lr, - "labels": labels, - "augment_pipe": None, - } - if lead_time_label: - lead_time_label = lead_time_label[0].to(dist.device).contiguous() - loss_fn_kwargs.update({"lead_time_label": lead_time_label}) - else: - lead_time_label = None - with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): - loss = loss_fn(**loss_fn_kwargs) - loss = loss.sum() / batch_size_per_gpu - loss_accum += loss / num_accumulation_rounds - loss.backward() + start_nimg = cur_nimg + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # enable profiler: + with cuda_profiler(): + with profiler_emit_nvtx(): + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + + if cur_nimg - start_nimg == 24 * cfg.training.hp.total_batch_size: + logger0.info(f"Starting Profiler at {cur_nimg}") + cuda_profiler_start() + + if cur_nimg - start_nimg == 25 * cfg.training.hp.total_batch_size: + logger0.info(f"Stopping Profiler at {cur_nimg}") + cuda_profiler_stop() + + with nvtx.annotate("Training iteration", color="green"): + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for n_i in range(num_accumulation_rounds): + with nvtx.annotate( + f"accumulation round {n_i}", color="Magenta" + ): + with nvtx.annotate("loading data", color="green"): + img_clean, img_lr, *lead_time_label = next( + dataset_iterator + ) + if use_apex_gn: + img_clean = img_clean.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr = img_lr.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean = ( + img_clean.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr = ( + img_lr.to(dist.device) + .to(input_dtype) + .contiguous() + ) + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_fn_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + + if lead_time_label: + lead_time_label = ( + lead_time_label[0].to(dist.device).contiguous() + ) + loss_fn_kwargs.update( + {"lead_time_label": lead_time_label} + ) + else: + lead_time_label = None + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update({"patching": patching}) + with nvtx.annotate(f"loss forward", color="green"): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss = loss_fn(**loss_fn_kwargs) + + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + with nvtx.annotate(f"loss backward", color="yellow"): + loss.backward() - loss_sum = torch.tensor([loss_accum], device=dist.device) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM) - average_loss = (loss_sum / dist.world_size).cpu().item() - - # update running mean of average loss since last periodic task - average_loss_running_mean += ( - average_loss - average_loss_running_mean - ) / n_average_loss_running_mean - n_average_loss_running_mean += 1 - - if dist.rank == 0: - writer.add_scalar("training_loss", average_loss, cur_nimg) - writer.add_scalar( - "training_loss_running_mean", average_loss_running_mean, cur_nimg - ) + with nvtx.annotate(f"loss aggregate", color="green"): + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 - - # Update weights. - lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate - for g in optimizer.param_groups: - if lr_rampup > 0: - g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) - if cur_nimg >= lr_rampup: - g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) - current_lr = g["lr"] - if dist.rank == 0: - writer.add_scalar("learning_rate", current_lr, cur_nimg) - handle_and_clip_gradients( - model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold - ) - optimizer.step() - - cur_nimg += cfg.training.hp.total_batch_size - done = cur_nimg >= cfg.training.hp.training_duration - - # Validation - if validation_dataset_iterator is not None: - valid_loss_accum = 0 - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.validation_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - ): - with torch.no_grad(): - for _ in range(cfg.training.io.validation_steps): - img_clean_valid, img_lr_valid, labels_valid = next( - validation_dataset_iterator - ) - - img_clean_valid = ( - img_clean_valid.to(dist.device) - .to(torch.float32) - .contiguous() - ) - img_lr_valid = ( - img_lr_valid.to(dist.device).to(torch.float32).contiguous() - ) - labels_valid = labels_valid.to(dist.device).contiguous() - loss_valid = loss_fn( - net=model, - img_clean=img_clean_valid, - img_lr=img_lr_valid, - labels=labels_valid, - augment_pipe=None, - ) - loss_valid = ( - (loss_valid.sum() / batch_size_per_gpu).cpu().item() - ) - valid_loss_accum += ( - loss_valid / cfg.training.io.validation_steps - ) - valid_loss_sum = torch.tensor( - [valid_loss_accum], device=dist.device - ) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce( - valid_loss_sum, op=torch.distributed.ReduceOp.SUM - ) - average_valid_loss = valid_loss_sum / dist.world_size if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) writer.add_scalar( - "validation_loss", average_valid_loss, cur_nimg + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, ) - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" - ] - logger0.info(" ".join(fields)) - torch.cuda.reset_peak_memory_stats() - - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 - - # Save checkpoints - if dist.world_size > 1: - torch.distributed.barrier() - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.save_checkpoint_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - save_checkpoint( - path=checkpoint_dir, - model=model, - optimizer=optimizer, - epoch=cur_nimg, - ) + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + with nvtx.annotate("update weights", color="blue"): + + lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + with nvtx.annotate("optimizer step", color="blue"): + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + with nvtx.annotate("validation", color="red"): + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + ( + img_clean_valid, + img_lr_valid, + *lead_time_label_valid, + ) = next(validation_dataset_iterator) + + if use_apex_gn: + img_clean_valid = img_clean_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_valid = img_lr_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + + else: + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_valid_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0] + .to(dist.device) + .contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update( + {"patching": patching} + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss_valid = loss_fn(**loss_valid_kwargs) + + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu) + .cpu() + .item() + ) + valid_loss_accum += ( + loss_valid + / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + save_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + epoch=cur_nimg, + ) # Done. logger0.info("Training Completed.") diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py index 9fcea1d..e502875 100644 --- a/src/hirad/utils/deterministic_sampler.py +++ b/src/hirad/utils/deterministic_sampler.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Literal, Optional import numpy as np import nvtx @@ -26,33 +27,142 @@ @nvtx.annotate(message="deterministic_sampler", color="red") def deterministic_sampler( - net, - latents, - img_lr, - img_shape=None, - class_labels=None, - randn_like=torch.randn_like, - num_steps=18, - sigma_min=None, - sigma_max=None, - rho=7, - solver="heun", - discretization="edm", - schedule="linear", - scaling="none", - epsilon_s=1e-3, - C_1=0.001, - C_2=0.008, - M=1000, - alpha=1, - S_churn=0, - S_min=0, - S_max=float("inf"), - S_noise=1, -): + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + randn_like: Callable = torch.randn_like, + num_steps: int = 18, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + rho: float = 7.0, + solver: Literal["heun", "euler"] = "heun", + discretization: Literal["vp", "ve", "iddpm", "edm"] = "edm", + schedule: Literal["vp", "ve", "linear"] = "linear", + scaling: Literal["vp", "none"] = "none", + epsilon_s: float = 1e-3, + C_1: float = 0.001, + C_2: float = 0.008, + M: int = 1000, + alpha: float = 1.0, + S_churn: int = 0, + S_min: float = 0.0, + S_max: float = float("inf"), + S_noise: float = 1.0, +) -> torch.Tensor: """ - Generalized sampler, representing the superset of all sampling methods discussed - in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + Generalized sampler, representing the superset of all sampling methods + discussed in the paper "Elucidating the Design Space of Diffusion-Based + Generative Models" (EDM). + - https://arxiv.org/abs/2206.00364 + + This function integrates an ODE (probability flow) or SDE over multiple + time-steps to generate samples from the diffusion model provided by the + argument 'net'. It can be used to combine multiple choices to + design a custom sampler, including multiple integration solver, + discretization method, noise schedule, and so on. + + Parameters: + ----------- + net : torch.nn.Module + The diffusion model to use in the sampling process. + latents : torch.Tensor + The latent random noise used as the initial condition for the + stochastic ODE. + img_lr : torch.Tensor + Low-resolution input image for conditioning the diffusion process. + Passed as a keywork argument to the model 'net'. + class_labels : Optional[torch.Tensor] + Labels of the classes used as input to a class-conditionned + diffusion model. Passed as a keyword argument to the model 'net'. + If provided, it must be a tensor containing integer values. + Defaults to None, in which case it is ignored. + randn_like: Callable + Random Number Generator to generate random noise that is added + during the stochastic sampling. Must have the same signature as + torch.randn_like and return torch.Tensor. Defaults to + torch.randn_like. + num_steps : Optional[int] + Number of time-steps for the stochastic ODE integration. Defaults + to 18. + sigma_min : Optional[float] + Minimum noise level for the diffusion process. 'sigma_min', + 'sigma_max', and 'rho' are used to compute the time-step + discretization, based on the choice of discretization. For the + default choice ("discretization='heun'"), the noise level schedule + is computed as: + :math:`\sigma_i = (\sigma_{max}^{1/\rho} + i / (num_steps - 1) * (\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}))^{rho}`. + For other choices of 'discretization', see details in the EDM + paper. Defaults to None, in which case defaults values depending + of the specified discretization are used. + sigma_max : Optional[float] + Maximum noise level for the diffusion process. See sigma_min for + details. Defaults to None, in which case defaults values depending + of the specified discretization are used. + rho : float, optional + Exponent used in the noise schedule. See sigma_min for details. + Only used when 'discretization' is 'heun'. Values in the range [5, + 10] produce better images. Lower values lead to truncation errors + equalized over all time steps. Defaults to 7. + solver : Literal["heun", "euler"] + The numerical method used to integrate the stochastic ODE. "euler" + is 1st order solver, which is faster but produces lower-quality + images. "heun" is 2nd order, more expensive, but produces + higher-quality images. Defaults to "heun". + discretization : Literal["vp", "ve", "iddpm", "edm"] + The method to discretize time-steps :math:`t_i` in the + diffusion process. See the EDM papper for details. Defaults to + "edm". + schedule : Literal["vp", "ve", "linear"] + The type of noise level schedule. Defaults to "linear". If + schedule='ve', then :math:`\sigma(t) = \sqrt{t}`. If + schedule='linear', then :math:`\sigma(t) = t`. If schedule='vp', + see EDM paper for details. Defaults to "linear". + scaling : Literal["vp", "none"] + The type of time-dependent signal scaling :math:`s(t)`, such that + :math:`x = s(t) \hat{x}`. See EDM paper for details on the 'vp' + scaling. Defaults to 'none', in which case :math:`s(t)=1`. + epsilon_s : float, optional + Parameter to compute both the noise level schedule and the + time-step discetization. Only used when discretization='vp' or + schedule='vp'. Ignored in other cases. Defaults to 1e-3. + C_1 : float, optional + Parameters to compute the time-step discetization. Only used when + discretization='iddpm'. Defaults to 0.001. + C_2 : float, optional + Same as for C_1. Only used when discretization='iddpm'. Defaults to + 0.008. + M : int, optional + Same as for C_1 and C_2. Only used when discretization='iddpm'. + Defaults to 1000. + alpha : float, optional + Controls (i.e. multiplies) the step size :math:`t_{i+1} - + \hat{t}_i` in the stochastic sampler, where :math:`\hat{t}_i` is + the temporarily increased noise level. Defaults to 1.0, which is + the recommended value. + S_churn : int, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Larger values of S_churn lead to larger values + of :math:`\hat{t}_i`, which in turn lead to injecting more + stochasticity in the SDE by Defaults to 0, which means no + stochasticity is injected. + S_min : float, optional + S_min and S_max control the time-step range obver which + stochasticty is injected in the SDE. Stochasticity is injected + through `\hat{t}_i` for time-steps :math:`t_i` such that + :math:`S_{min} \leq t_i \leq S_{max}`. Defaults to 0.0. + S_max : float, optional + See S_min. Defaults to float("inf"). + S_noise : float, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Added signal noise is proportinal to + :math:`\epsilon_i` where `\epsilon_i ~ N(0, S_{noise}^2)`. Defaults + to 1.0. + + Returns + ------- + torch.Tensor: + Generated batch of samples. Same shape as the input 'latents'. """ # conditioning diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py index dcbb127..347457c 100644 --- a/src/hirad/utils/function_utils.py +++ b/src/hirad/utils/function_utils.py @@ -29,7 +29,7 @@ import sys import types import warnings -from typing import Any, List, Tuple, Union +from typing import Any, Iterator, List, Tuple, Union import cftime import numpy as np @@ -553,14 +553,37 @@ def decorator(*args, **kwargs): # indefinitely, shuffling items as it goes. -class InfiniteSampler(torch.utils.data.Sampler): # pragma: no cover - """ - Sampler for torch.utils.data.DataLoader that loops over the dataset - indefinitely, shuffling items as it goes. +class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover + """Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely. + + This sampler yields indices indefinitely, optionally shuffling items as it goes. + It can also perform distributed sampling when rank and num_replicas are specified. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + The dataset to sample from + rank : int, default=0 + The rank of the current process within num_replicas processes + num_replicas : int, default=1 + The number of processes participating in distributed sampling + shuffle : bool, default=True + Whether to shuffle the indices + seed : int, default=0 + Random seed for reproducibility when shuffling + window_size : float, default=0.5 + Fraction of dataset to use as window for shuffling. Must be between 0 and 1. + A larger window means more thorough shuffling but slower iteration. """ def __init__( - self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 + self, + dataset: torch.utils.data.Dataset, + rank: int = 0, + num_replicas: int = 1, + shuffle: bool = True, + seed: int = 0, + window_size: float = 0.5, ): if not len(dataset) > 0: raise ValueError("Dataset must contain at least one item") @@ -578,7 +601,7 @@ def __init__( self.seed = seed self.window_size = window_size - def __iter__(self): + def __iter__(self) -> Iterator[int]: order = np.arange(len(self.dataset)) rnd = None window = 0 diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index ace05ba..8665536 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import datetime +from typing import Optional import cftime import nvtx @@ -23,6 +24,9 @@ from .function_utils import StackedRandomGenerator, time_range +from .stochastic_sampler import stochastic_sampler +from .deterministic_sampler import deterministic_sampler + ############################################################################ # CorrDiff Generation Utilities # ############################################################################ @@ -31,35 +35,56 @@ def regression_step( net: torch.nn.Module, img_lr: torch.Tensor, - labels: torch.Tensor, latents_shape: torch.Size, - lead_time_label: torch.Tensor = None, + lead_time_label: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Given a low-res input, performs a regression step to produce ensemble mean. - This function performs the regression on a single instance and then replicates - the results across the batch dimension. - - Args: - net (torch.nn.Module): U-Net model for regression. - img_lr (torch.Tensor): Low-resolution input. - latents_shape (torch.Size): Shape of the latent representation. Typically - (batch_size, out_channels, image_shape_x, image_shape_y). - - - Returns: - torch.Tensor: Predicted output at the next time step. + Perform a regression step to produce ensemble mean prediction. + + This function takes a low-resolution input and performs a regression step to produce + an ensemble mean prediction. It processes a single instance and then replicates + the results across the batch dimension if needed. + + Parameters + ---------- + net : torch.nn.Module + U-Net model for regression. + img_lr : torch.Tensor + Low-resolution input to the network with shape (1, channels, height, width). + Must have a batch dimension of 1. + latents_shape : torch.Size + Shape of the latent representation with format + (batch_size, out_channels, image_shape_y, image_shape_x). + lead_time_label : Optional[torch.Tensor], optional + Lead time label tensor for lead time conditioning, + with shape (1, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Predicted ensemble mean at the next time step with shape matching latents_shape. + + Raises + ------ + ValueError + If img_lr has a batch size greater than 1. """ # Create a tensor of zeros with the given shape and move it to the appropriate device x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) - t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device)#.reshape((1,1,1,1)) + + # Safety check: avoid silently ignoring batch elements in img_lr + if img_lr.shape[0] > 1: + raise ValueError( + f"Expected img_lr to have a batch size of 1, " + f"but found {img_lr.shape[0]}." + ) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat[0:1], img_lr, t_hat, labels, lead_time_label=lead_time_label) + x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label) else: - x = net(x_hat[0:1], img_lr, t_hat, labels) + x = net(x=x_hat[0:1], img_lr=img_lr) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -68,48 +93,85 @@ def regression_step( return x -def diffusion_step( # TODO generalize the module and add defaults +def diffusion_step( net: torch.nn.Module, sampler_fn: callable, - seed_batch_size: int, img_shape: tuple, img_out_channels: int, rank_batches: list, img_lr: torch.Tensor, rank: int, device: torch.device, - hr_mean: torch.Tensor = None, + mean_hr: torch.Tensor = None, lead_time_label: torch.Tensor = None, ) -> torch.Tensor: """ Generate images using diffusion techniques as described in the relevant paper. - Args: - net (torch.nn.Module): The diffusion model network. - sampler_fn (callable): Function used to sample images from the diffusion model. - seed_batch_size (int): Number of seeds per batch. - img_shape (tuple): Shape of the images, (height, width). - img_out_channels (int): Number of output channels for the image. - rank_batches (list): List of batches of seeds to process. - img_lr (torch.Tensor): Low-resolution input image. - rank (int): Rank of the current process for distributed processing. - device (torch.device): Device to perform computations. - mean_hr (torch.Tensor, optional): High-resolution mean tensor, to be used as an additional input. By default None. - - Returns: - torch.Tensor: Generated images concatenated across batches. + This function applies a diffusion model to generate high-resolution images based on + low-resolution inputs. It supports optional conditioning on high-resolution mean + predictions and lead time labels. + + For each low-resolution sample in `img_lr`, the function generates multiple + high-resolution samples, with different random seeds, specified in `rank_batches`. + The function then concatenates these high-resolution samples across the batch dimension. + + Parameters + ---------- + net : torch.nn.Module + The diffusion model network. + sampler_fn : callable + Function used to sample images from the diffusion model. + img_shape : tuple + Shape of the images, (height, width). + img_out_channels : int + Number of output channels for the image. + rank_batches : list + List of batches of seeds to process. + img_lr : torch.Tensor + Low-resolution input image with shape (seed_batch_size, channels_lr, height, width). + rank : int, optional + Rank of the current process for distributed processing. + device : torch.device, optional + Device to perform computations. + mean_hr : torch.Tensor, optional + High-resolution mean tensor to be used as an additional input, + with shape (1, channels_hr, height, width). Default is None. + lead_time_label : torch.Tensor, optional + Lead time label tensor for temporal conditioning, + with shape (batch_size, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Generated images concatenated across batches with shape + (seed_batch_size * len(rank_batches), out_channels, height, width). """ - img_lr = img_lr #.to(memory_format=torch.channels_last) + # Check img_lr dimensions match expected shape + if img_lr.shape[2:] != img_shape: + raise ValueError( + f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + + # Check mean_hr dimensions if provided + if mean_hr is not None: + if mean_hr.shape[2:] != img_shape: + raise ValueError( + f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + if mean_hr.shape[0] != 1: + raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}") + + img_lr = img_lr.to(memory_format=torch.channels_last) # Handling of the high-res mean additional_args = {} - if hr_mean is not None: - additional_args["mean_hr"] = hr_mean + if mean_hr is not None: + additional_args["mean_hr"] = mean_hr if lead_time_label is not None: additional_args["lead_time_label"] = lead_time_label - additional_args["img_shape"] = img_shape # Loop over batches all_images = [] @@ -123,7 +185,7 @@ def diffusion_step( # TODO generalize the module and add defaults rnd = StackedRandomGenerator(device, batch_seeds) latents = rnd.randn( [ - seed_batch_size, + img_lr.shape[0], img_out_channels, img_shape[0], img_shape[1], @@ -139,6 +201,9 @@ def diffusion_step( # TODO generalize the module and add defaults return torch.cat(all_images) +def generate(): + pass + ############################################################################ # CorrDiff writer utilities # ############################################################################ diff --git a/src/hirad/utils/patching.py b/src/hirad/utils/patching.py new file mode 100644 index 0000000..6f4bc4d --- /dev/null +++ b/src/hirad/utils/patching.py @@ -0,0 +1,767 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import random +import warnings +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import Tensor + +""" +This module defines utilities, including classes and functions, for domain +decomposition. +""" + + +class BasePatching2D(ABC): + """ + Abstract base class for 2D image patching operations. + + This class provides a foundation for implementing various image patching + strategies. + It handles basic validation and provides abstract methods that must be + implemented by subclasses. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int] + ) -> None: + # Check that img_shape and patch_shape are 2D + if len(img_shape) != 2: + raise ValueError(f"img_shape must be 2D, got {len(img_shape)}D") + if len(patch_shape) != 2: + raise ValueError(f"patch_shape must be 2D, got {len(patch_shape)}D") + + # Make sure patches fit within the image + if any(p > i for p, i in zip(patch_shape, img_shape)): + warnings.warn( + f"Patch shape {patch_shape} is larger than " + f"image shape {img_shape}. " + f"Patches will be cropped to fit within the image." + ) + self.img_shape = img_shape + self.patch_shape = tuple(min(p, i) for p, i in zip(patch_shape, img_shape)) + + @abstractmethod + def apply(self, input: Tensor, **kwargs) -> Tensor: + """ + Apply the patching operation to the input tensor. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + **kwargs : dict + Additional keyword arguments specific to the patching + implementation. + + Returns + ------- + Tensor + Patched tensor, shape depends on specific implementation. + """ + pass + + def fuse(self, input: Tensor, **kwargs) -> Tensor: + """ + Fuse patches back into a complete image. + + Parameters + ---------- + input : Tensor + Input tensor containing patches. + **kwargs : dict + Additional keyword arguments specific to the fusion implementation. + + Returns + ------- + Tensor + Fused tensor, shape depends on specific implementation. + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + """ + raise NotImplementedError("'fuse' method must be implemented in subclasses.") + + def global_index( + self, batch_size: int, device: Union[torch.device, str] = "cpu" + ) -> Tensor: + """ + Returns a tensor containing the global indices for each patch. + + Global indices correspond to (y, x) global grid coordinates of each + element within the original image (before patching). It is typically + used to keep track of the original position of each patch in the + original image. + + Parameters + ---------- + batch_size : int + The size of the batch of images to patch. + device : Union[torch.device, str] + Proper device to initialize global_index on. Default to `cpu` + + Returns + ------- + Tensor + A tensor of shape (self.patch_num, 2, patch_shape_y, + patch_shape_x). `global_index[:, 0, :, :]` contains the + y-coordinate (height), and `global_index[:, 1, :, :]` contains the + x-coordinate (width). + """ + Ny = torch.arange(self.img_shape[0], device=device).int() + Nx = torch.arange(self.img_shape[1], device=device).int() + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0).unsqueeze(0) + global_index = self.apply(grid).long() + return global_index + + +class RandomPatching2D(BasePatching2D): + """ + Class for randomly extracting patches from 2D images. + + This class provides utilities to randomly extract patches from images + represented as 4D tensors. It maintains a list of random patch indices + that can be reset as needed. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + patch_num : int + The number of patches to extract. + + Attributes + ---------- + patch_indices : List[Tuple[int, int]] + The indices of the patches to extract from the images. These indices + correspond to the (y, x) coordinates of the lower left corner of each + patch. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.GridPatching2D` + Alternative patching strategy using deterministic patch locations. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int], patch_num: int + ) -> None: + """ + Initialize the RandomPatching2D object with the provided image shape, + patch shape, and number of patches to extract. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, + img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) + to extract. + patch_num : int + The number of patches to extract. + + Returns + ------- + None + """ + super().__init__(img_shape, patch_shape) + self._patch_num = patch_num + # Generate the indices of the patches to extract + self.reset_patch_indices() + + @property + def patch_num(self) -> int: + """ + Get the number of patches to extract. + + Returns + ------- + int + The number of patches to extract. + """ + return self._patch_num + + def set_patch_num(self, value: int) -> None: + """ + Set the number of patches to extract and reset patch indices. + This is the only way to modify the patch_num value. + + Parameters + ---------- + value : int + The new number of patches to extract. + """ + self._patch_num = value + self.reset_patch_indices() + + def reset_patch_indices(self) -> None: + """ + Generate new random indices for the patches to extract. These are the + starting indices of the patches to extract (upper left corner). + + Returns + ------- + None + """ + self.patch_indices = [ + ( + random.randint(0, self.img_shape[0] - self.patch_shape[0]), + random.randint(0, self.img_shape[1] - self.patch_shape[1]), + ) + for _ in range(self.patch_num) + ] + return + + def get_patch_indices(self) -> List[Tuple[int, int]]: + """ + Get the current list of patch starting indices. + + These are the upper-left coordinates of each extracted patch + from the full image. + + Returns + ------- + List[Tuple[int, int]] + A list of (row, column) tuples representing patch starting positions. + """ + return self.patch_indices + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Applies the patching operation by extracting patches specified by + `self.patch_indices` from the `input` Tensor. Extracted patches are + batched along the first dimension of the output. The layout of the + output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. + + Arguments + --------- + input : Tensor + The input tensor representing the full image with shape + (batch_size, channels_in, img_shape_y, img_shape_x). + additional_input : Optional[Tensor], optional + If provided, it is concatenated to each patch along `dim=1`. + Must have same batch size as `input`. Bilinear interpolation + is used to interpolate `additional_input` onto a 2D grid of shape + (patch_shape_y, patch_shape_x). + + Returns + ------- + Tensor + A tensor of shape (batch_size * self.patch_num, channels [+ + additional_channels], patch_shape_y, patch_shape_x). If + `additional_input` is provided, its channels are concatenated + along the channel dimension. + """ + B = input.shape[0] + out = torch.zeros( + B * self.patch_num, + ( + input.shape[1] + + (additional_input.shape[1] if additional_input is not None else 0) + ), + self.patch_shape[0], + self.patch_shape[1], + device=input.device, + ) + out = out.to( + memory_format=torch.channels_last + if input.is_contiguous(memory_format=torch.channels_last) + else torch.contiguous_format + ) + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + + for i, (py, px) in enumerate(self.patch_indices): + if additional_input is not None: + out[B * i : B * (i + 1),] = torch.cat( + ( + input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ], + add_input_interp, + ), + dim=1, + ) + else: + out[B * i : B * (i + 1),] = input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ] + return out + + +class GridPatching2D(BasePatching2D): + """ + Class for deterministically extracting patches from 2D images in a grid pattern. + + This class provides utilities to extract patches from images in a + deterministic manner, with configurable overlap and boundary pixels. + The patches are extracted in a grid-like pattern covering the entire image. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + overlap_pix : int, optional + Number of pixels to overlap between adjacent patches, by default 0. + boundary_pix : int, optional + Number of pixels to crop as boundary from each patch, by default 0. + + Attributes + ---------- + patch_num : int + Total number of patches that will be extracted from the image, + calculated as patch_num_x * patch_num_y. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.RandomPatching2D` + Alternative patching strategy using random patch locations. + """ + + def __init__( + self, + img_shape: Tuple[int, int], + patch_shape: Tuple[int, int], + overlap_pix: int = 0, + boundary_pix: int = 0, + ): + super().__init__(img_shape, patch_shape) + self.overlap_pix = overlap_pix + self.boundary_pix = boundary_pix + patch_num_x = math.ceil( + img_shape[1] / (patch_shape[1] - overlap_pix - boundary_pix) + ) + patch_num_y = math.ceil( + img_shape[0] / (patch_shape[0] - overlap_pix - boundary_pix) + ) + self.patch_num = patch_num_x * patch_num_y + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Apply deterministic patching to the input tensor. + + Splits the input tensor into patches in a grid-like pattern. Can + optionally concatenate additional interpolated data to each patch. + Extracted patches are batched along the first dimension of the output. + The layout of the output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. The patches can be reconstructed back into the original image + using the fuse method. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + additional_input : Optional[Tensor], optional + Additional data to concatenate to each patch. Will be interpolated + to match patch dimensions. Shape must be (batch_size, + additional_channels, H, W), by default None. + + Returns + ------- + Tensor + Tensor containing patches with shape (batch_size * patch_num, + channels [+ additional_channels], patch_shape_y, patch_shape_x). + If additional_input is provided, its channels are concatenated + along the channel dimension. + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The underlying function used to perform the patching operation. + """ + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + else: + add_input_interp = None + out = image_batching( + input=input, + patch_shape_y=self.patch_shape[0], + patch_shape_x=self.patch_shape[1], + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + input_interp=add_input_interp, + ) + return out + + def fuse(self, input: Tensor, batch_size: int) -> Tensor: + """ + Fuse patches back into a complete image. + + Reconstructs the original image by stitching together patches, + accounting for overlapping regions and boundary pixels. In overlapping + regions, values are averaged. + + Parameters + ---------- + input : Tensor + Input tensor containing patches with shape (batch_size * patch_num, + channels, patch_shape_y, patch_shape_x). + batch_size : int + The original batch size before patching. + + Returns + ------- + Tensor + Reconstructed image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_fuse` + The underlying function used to perform the fusion operation. + """ + out = image_fuse( + input=input, + img_shape_y=self.img_shape[0], + img_shape_x=self.img_shape[1], + batch_size=batch_size, + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + ) + return out + + +def image_batching( + input: Tensor, + patch_shape_y: int, + patch_shape_x: int, + overlap_pix: int, + boundary_pix: int, + input_interp: Optional[Tensor] = None, +) -> Tensor: + """ + Splits a full image into a batch of patched images. + + This function takes a full image and splits it into patches, adding padding + where necessary. It can also concatenate additional interpolated data to + each patch if provided. + + Parameters + ---------- + input : Tensor + The input tensor representing the full image with shape (batch_size, + channels, img_shape_y, img_shape_x). + patch_shape_y : int + The height (y-dimension) of each image patch. + patch_shape_x : int + The width (x-dimension) of each image patch. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + input_interp : Optional[Tensor], optional + Optional additional data to concatenate to each patch with shape + (batch_size, interp_channels, patch_shape_y, patch_shape_x). + By default None. + + Returns + ------- + Tensor + A tensor containing the image patches, with shape (total_patches * + batch_size, channels [+ interp_channels], patch_shape_x, + patch_shape_y). + """ + # Infer sizes from input image + batch_size, _, img_shape_y, img_shape_x = input.shape + + # Safety check: make sure patch_shapes are large enough to accommodate + # overlaps and boundaries pixels + if (patch_shape_x - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + if (patch_shape_y - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + # Safety check: validate input_interp dimensions if provided + if input_interp is not None: + if input_interp.shape[0] != batch_size: + raise ValueError( + f"input_interp batch size ({input_interp.shape[0]}) must match " + f"input batch size ({batch_size})" + ) + if (input_interp.shape[2] != patch_shape_y) or ( + input_interp.shape[3] != patch_shape_x + ): + raise ValueError( + f"input_interp patch shape ({input_interp.shape[2]}, {input_interp.shape[3]}) " + f"must match specified patch shape ({patch_shape_y}, {patch_shape_x})" + ) + + # Safety check: make sure patch_shape is large enough in comparison to + # overlap_pix and boundary_pix. Otherwise, number of patches extracted by + # unfold differs from the expected number of patches. + if patch_shape_x <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_x ({patch_shape_x}) must verify " + f"patch_shape_x ({patch_shape_x}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + if patch_shape_y <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_y ({patch_shape_y}) must verify " + f"patch_shape_y ({patch_shape_y}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + image_padding = torch.nn.ReflectionPad2d( + (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) + input_padded = image_padding(input) + patch_num = patch_num_x * patch_num_y + x_unfold = torch.nn.functional.unfold( + input=input_padded.view(_cast_type(input_padded)), # Cast to float + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ).to(input_padded.dtype) + x_unfold = rearrange( + x_unfold, + "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + if input_interp is not None: + input_interp_repeated = rearrange( + torch.repeat_interleave( + input=input_interp, + repeats=patch_num, + dim=0, + output_size=x_unfold.shape[0], + ), + "(b p) c h w -> (p b) c h w", + p=patch_num, + ) + return torch.cat((x_unfold, input_interp_repeated), dim=1) + else: + return x_unfold + + +def image_fuse( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, +) -> Tensor: + """ + Reconstructs a full image from a batch of patched images. Reverts the patching + operation performed by image_batching(). + + This function takes a batch of image patches and reconstructs the full + image by stitching the patches together. The function accounts for + overlapping and boundary pixels, ensuring that overlapping areas are + averaged. + + Parameters + ---------- + input : Tensor + The input tensor containing the image patches with shape (patch_num * batch_size, channels, patch_shape_y, patch_shape_x). + img_shape_y : int + The height (y-dimension) of the original full image. + img_shape_x : int + The width (x-dimension) of the original full image. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + + Returns + ------- + Tensor + The reconstructed full image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The function this reverses, which splits images into patches. + """ + + # Infer sizes from input image shape + patch_shape_y, patch_shape_x = input.shape[2], input.shape[3] + + # Calculate the number of patches in each dimension + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + + # Calculate the shape of the input after padding + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + # Calculate the shape of the padding to add to input + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + pad = (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + + # Count local overlaps between patches + input_ones = torch.ones( + (batch_size, input.shape[1], padded_shape_y, padded_shape_x), + device=input.device, + ) + overlap_count = torch.nn.functional.unfold( + input=input_ones, + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + overlap_count = torch.nn.functional.fold( + input=overlap_count, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Reshape input to make it 3D to apply fold + x = rearrange( + input, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + # Stitch patches together (by summing over overlapping patches) + x_folded = torch.nn.functional.fold( + input=x, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Remove padding + x_no_padding = x_folded[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + overlap_count_no_padding = overlap_count[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + + # Normalize by overlap count + return x_no_padding / overlap_count_no_padding + + +def _cast_type(input: Tensor) -> torch.dtype: + """Return float type based on input tensor type. + + Parameters + ---------- + input : Tensor + Input tensor to determine float type from + + Returns + ------- + torch.dtype + Float type corresponding to input tensor type for int32/64, + otherwise returns original dtype + """ + if input.dtype == torch.int32: + return torch.float32 + elif input.dtype == torch.int64: + return torch.float64 + else: + return input.dtype diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py index ac5c13b..198fde4 100644 --- a/src/hirad/utils/stochastic_sampler.py +++ b/src/hirad/utils/stochastic_sampler.py @@ -15,290 +15,23 @@ # limitations under the License. -import math -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch from torch import Tensor - -def image_batching( - input: Tensor, - img_shape_y: int, - img_shape_x: int, - patch_shape_y: int, - patch_shape_x: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, - input_interp: Optional[Tensor] = None, -) -> Tensor: - """ - Splits a full image into a batch of patched images. - - This function takes a full image and splits it into patches, adding padding where necessary. - It can also concatenate additional interpolated data to each patch if provided. - - Parameters - ---------- - input : Tensor - The input tensor representing the full image with shape (batch_size, channels, img_shape_x, img_shape_y). - img_shape_x : int - The width (x-dimension) of the original full image. - img_shape_y : int - The height (y-dimension) of the original full image. - patch_shape_x : int - The width (x-dimension) of each image patch. - patch_shape_y : int - The height (y-dimension) of each image patch. - batch_size : int - The original batch size before patching. - overlap_pix : int - The number of overlapping pixels between adjacent patches. - boundary_pix : int - The number of pixels to crop as a boundary from each patch. - input_interp : Optional[Tensor], optional - Optional additional data to concatenate to each patch with shape (batch_size, interp_channels, patch_shape_x, patch_shape_y). - By default None. - - Returns - ------- - Tensor - A tensor containing the image patches, with shape (total_patches * batch_size, channels [+ interp_channels], patch_shape_x, patch_shape_y). - """ - patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) - patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) - padded_shape_x = ( - (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) - + patch_shape_x - + boundary_pix - ) - padded_shape_y = ( - (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) - + patch_shape_y - + boundary_pix - ) - pad_x_right = padded_shape_x - img_shape_x - boundary_pix - pad_y_right = padded_shape_y - img_shape_y - boundary_pix - input_padded = torch.zeros( - input.shape[0], input.shape[1], padded_shape_y, padded_shape_x - ).to(input.device) - image_padding = torch.nn.ReflectionPad2d( - (boundary_pix, pad_x_right, boundary_pix, pad_y_right) - ).to( - input.device - ) # (padding_left,padding_right,padding_top,padding_bottom) - input_padded = image_padding(input) - patch_num = patch_num_x * patch_num_y - if input_interp is not None: - output = torch.zeros( - patch_num * batch_size, - input.shape[1] + input_interp.shape[1], - patch_shape_y, - patch_shape_x, - ).to(input.device) - else: - output = torch.zeros( - patch_num * batch_size, input.shape[1], patch_shape_y, patch_shape_x - ).to(input.device) - for x_index in range(patch_num_x): - for y_index in range(patch_num_y): - x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) - y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) - if input_interp is not None: - output[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - ] = torch.cat( - ( - input_padded[ - :, - :, - y_start : y_start + patch_shape_y, - x_start : x_start + patch_shape_x, - ], - input_interp, - ), - dim=1, - ) - else: - output[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - ] = input_padded[ - :, - :, - y_start : y_start + patch_shape_y, - x_start : x_start + patch_shape_x, - ] - return output - - -def image_fuse( - input: Tensor, - img_shape_y: int, - img_shape_x: int, - patch_shape_y: int, - patch_shape_x: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, -) -> Tensor: - """ - Reconstructs a full image from a batch of patched images. - - This function takes a batch of image patches and reconstructs the full image - by stitching the patches together. The function accounts for overlapping and - boundary pixels, ensuring that overlapping areas are averaged. - - Parameters - ---------- - input : Tensor - The input tensor containing the image patches with shape (total_patches * batch_size, channels, patch_shape_x, patch_shape_y). - img_shape_x : int - The width (x-dimension) of the original full image. - img_shape_y : int - The height (y-dimension) of the original full image. - patch_shape_x : int - The width (x-dimension) of each image patch. - patch_shape_y : int - The height (y-dimension) of each image patch. - batch_size : int - The original batch size before patching. - overlap_pix : int - The number of overlapping pixels between adjacent patches. - boundary_pix : int - The number of pixels to crop as a boundary from each patch. - - Returns - ------- - Tensor - The reconstructed full image tensor with shape (batch_size, channels, img_shape_x, img_shape_y). - - """ - patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) - patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) - padded_shape_x = ( - (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) - + patch_shape_x - + boundary_pix - ) - padded_shape_y = ( - (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) - + patch_shape_y - + boundary_pix - ) - pad_x_right = padded_shape_x - img_shape_x - boundary_pix - pad_y_right = padded_shape_y - img_shape_y - boundary_pix - residual_x = patch_shape_x - pad_x_right # residual pixels in the last patch - residual_y = patch_shape_y - pad_y_right # residual pixels in the last patch - output = torch.zeros( - batch_size, input.shape[1], img_shape_y, img_shape_x, device=input.device - ) - one_map = torch.ones(1, 1, input.shape[2], input.shape[3], device=input.device) - count_map = torch.zeros( - 1, 1, img_shape_y, img_shape_x, device=input.device - ) # to count the overlapping times - for x_index in range(patch_num_x): - for y_index in range(patch_num_y): - x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) - y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) - if (x_index == patch_num_x - 1) and (y_index != patch_num_y - 1): - output[ - :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - count_map[ - :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: - ] += one_map[ - :, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - elif (y_index == patch_num_y - 1) and ((x_index != patch_num_x - 1)): - output[ - :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - count_map[ - :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix - ] += one_map[ - :, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - elif x_index == patch_num_x - 1 and y_index == patch_num_y - 1: - output[:, :, y_start:, x_start:] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - count_map[:, :, y_start:, x_start:] += one_map[ - :, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - else: - output[ - :, - :, - y_start : y_start + patch_shape_y - 2 * boundary_pix, - x_start : x_start + patch_shape_x - 2 * boundary_pix, - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - count_map[ - :, - :, - y_start : y_start + patch_shape_y - 2 * boundary_pix, - x_start : x_start + patch_shape_x - 2 * boundary_pix, - ] += one_map[ - :, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - return output / count_map +from hirad.utils.patching import GridPatching2D def stochastic_sampler( - net: Any, - latents: Tensor, - img_lr: Tensor, + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, class_labels: Optional[Tensor] = None, randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - img_shape: tuple[int,int] = (448,448), - patch_shape_x: int = 448, - patch_shape_y: int = 448, - overlap_pix: int = 4, - boundary_pix: int = 2, - mean_hr: Optional[Tensor] = None, - lead_time_label: Optional[Tensor] = None, + patching: Optional[GridPatching2D] = None, + mean_hr: Optional[torch.Tensor] = None, + lead_time_label: Optional[torch.Tensor] = None, num_steps: int = 18, sigma_min: float = 0.002, sigma_max: float = 800, @@ -307,33 +40,63 @@ def stochastic_sampler( S_min: float = 0, S_max: float = float("inf"), S_noise: float = 1, -) -> Tensor: +) -> torch.Tensor: """ - Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution and patch-based diffusion. + Proposed EDM sampler (Algorithm 2) with minor changes to enable + super-resolution and patch-based diffusion. Parameters ---------- - net : Any - The neural network model that generates denoised images from noisy inputs. + net : torch.nn.Module + The neural network model that generates denoised images from noisy + inputs. + Expected signature: `net(x, x_lr, t_hat, class_labels, + lead_time_label=lead_time_label, embedding_selector=embedding_selector)`, + where: + x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) + x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) + t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar + class_labels (torch.Tensor, optional): Optional class labels + lead_time_label (torch.Tensor, optional): Optional lead time labels + embedding_selector (callable, optional): Function to select + positional embeddings. Used for patch-based diffusion. + Returns: + torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) + + Required attributes: + sigma_min (float): Minimum supported noise level for the model + sigma_max (float): Maximum supported noise level for the model + round_sigma (callable): Method to convert sigma values to tensor representation latents : Tensor - The latent variables (e.g., noise) used as the initial input for the sampler. + The latent variables (e.g., noise) used as the initial input for the + sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). img_lr : Tensor - Low-resolution input image for conditioning the super-resolution process. + Low-resolution input image for conditioning the super-resolution + process. Must have shape (batch_size, C_lr, img_lr_ shape_y, + img_lr_shape_x). class_labels : Optional[Tensor], optional - Class labels for conditional generation, if required by the model. By default None. + Class labels for conditional generation, if required by the model. By + default None. randn_like : Callable[[Tensor], Tensor] - Function to generate random noise with the same shape as the input tensor. + Function to generate random noise with the same shape as the input + tensor. By default torch.randn_like. - img_shape : int - The height and width of the full image (assumed to be square). By default 448. - patch_shape : int - The height and width of each patch (assumed to be square). By default 448. - overlap_pix : int - Number of overlapping pixels between adjacent patches. By default 4. - boundary_pix : int - Number of pixels to be cropped as a boundary from each patch. By default 2. + patching : Optional[GridPatching2D], optional + A patching utility for patch-based diffusion. Implements methods to + extract patches from an image and batch the patches along `dim=0`. + Should also implement a `fuse` method to reconstruct the original image + from a batch of patches. See + :class:`physicsnemo.utils.patching.GridPatching2D` for details. By + default None, in which case non-patched diffusion is used. mean_hr : Optional[Tensor], optional - Optional tensor containing mean high-resolution images for conditioning. By default None. + Optional tensor containing mean high-resolution images for + conditioning. Must have same height and width as `img_lr`, with shape + (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x) where the batch dimension + B_hr can be either 1, either equal to batch_size, or can be omitted. If + B_hr = 1 or is omitted, `mean_hr` will be expanded to match the shape + of `img_lr`. By default None. + lead_time_label : Optional[Tensor], optional + Optional lead time labels. By default None. num_steps : int Number of time steps for the sampler. By default 18. sigma_min : float @@ -343,7 +106,8 @@ def stochastic_sampler( rho : float Exponent used in the time step discretization. By default 7. S_churn : float - Churn parameter controlling the level of noise added in each step. By default 0. + Churn parameter controlling the level of noise added in each step. By + default 0. S_min : float Minimum time step for applying churn. By default 0. S_max : float @@ -354,20 +118,40 @@ def stochastic_sampler( Returns ------- Tensor - The final denoised image produced by the sampler. + The final denoised image produced by the sampler. Same shape as + `latents`: (batch_size, C_out, img_shape_y, img_shape_x). + + See Also + -------- + :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model + wrapper that provides preconditioning for super-resolution diffusion + models and implements the required interface for this sampler. """ # Adjust noise levels based on what's supported by the network. - "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." + # Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution. sigma_min = max(sigma_min, net.sigma_min) sigma_max = min(sigma_max, net.sigma_max) - # if isinstance(img_shape, tuple): - # img_shape_y, img_shape_x = img_shape - # else: - # img_shape_x = img_shape_y = img_shape - img_shape_x, img_shape_y = img_shape - patch_shape_x = min(img_shape_x, patch_shape_x) - patch_shape_y = min(img_shape_y, patch_shape_y) + + if patching is not None and not isinstance(patching, GridPatching2D): + raise ValueError("patching must be an instance of GridPatching2D.") + + # Safety check: if patching is used then img_lr and latents must have same + # height and width, otherwise there is mismatch in the number + # of patches extracted to form the final batch_size. + if patching: + if img_lr.shape[-2:] != latents.shape[-2:]: + raise ValueError( + f"img_lr and latents must have the same height and width, " + f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. " + ) + # img_lr and latents must also have the same batch_size, otherwise mismatch + # when processed by the network + if img_lr.shape[0] != latents.shape[0]: + raise ValueError( + f"img_lr and latents must have the same batch size, but found " + f"{img_lr.shape[0]} vs {latents.shape[0]}." + ) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) @@ -381,46 +165,32 @@ def stochastic_sampler( [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] ) # t_N = 0 - b = latents.shape[0] - Nx = torch.arange(img_shape_x) - Ny = torch.arange(img_shape_y) - grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) + batch_size = img_lr.shape[0] # conditioning = [mean_hr, img_lr, global_lr, pos_embd] - batch_size = img_lr.shape[0] x_lr = img_lr if mean_hr is not None: + if mean_hr.shape[-2:] != img_lr.shape[-2:]: + raise ValueError( + f"mean_hr and img_lr must have the same height and width, " + f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}." + ) x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) - global_index = None # input and position padding + patching - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - input_interp = torch.nn.functional.interpolate( - img_lr, (patch_shape_x, patch_shape_y), mode="bilinear" - ) - x_lr = image_batching( - x_lr, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - input_interp, - ) - global_index = image_batching( - grid.float(), - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ).int() + if patching: + # Patched conditioning [x_lr, mean_hr] + # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) + x_lr = patching.apply(input=x_lr, additional_input=img_lr) + + # Function to select the correct positional embedding for each patch + def patch_embedding_selector(emb): + # emb: (N_pe, image_shape_y, image_shape_x) + # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) + return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + + else: + patch_embedding_selector = None # Main sampling loop. x_next = latents.to(torch.float64) * t_steps[0] @@ -432,26 +202,14 @@ def stochastic_sampler( x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) - # Euler step. Perform patching operation on score tensor if patch-based generation is used - # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr + # Euler step. Perform patching operation on score tensor if patch-based + # generation is used denoised = net(x_hat, t_hat, + # class_labels,lead_time_label=lead_time_label).to(torch.float64) - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - x_hat_batch = image_batching( - x_hat, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - else: - x_hat_batch = x_hat - x_hat_batch = x_hat_batch.to(latents.device) + x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to( + latents.device + ) x_lr = x_lr.to(latents.device) - if global_index is not None: - global_index = global_index.to(latents.device) if lead_time_label is not None: denoised = net( @@ -460,7 +218,7 @@ def stochastic_sampler( t_hat, class_labels, lead_time_label=lead_time_label, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) else: # print("Sizes") @@ -474,40 +232,24 @@ def stochastic_sampler( x_lr, t_hat, class_labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) - denoised = image_fuse( - denoised, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - x_next_batch = image_batching( - x_next, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - else: - x_next_batch = x_next - # ask about this fix - x_next_batch = x_next_batch.to(latents.device) + # Patched input + # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x) + x_next_batch = (patching.apply(input=x_next) if patching else x_next).to( + latents.device + ) + if lead_time_label is not None: denoised = net( x_next_batch, @@ -515,7 +257,7 @@ def stochastic_sampler( t_next, class_labels, lead_time_label=lead_time_label, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) else: denoised = net( @@ -523,19 +265,13 @@ def stochastic_sampler( x_lr, t_next, class_labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - denoised = image_fuse( - denoised, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) + d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) return x_next diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index d4529ac..218d6f1 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -17,6 +17,7 @@ import torch import numpy as np from omegaconf import ListConfig +import warnings def set_patch_shape(img_shape, patch_shape): @@ -26,12 +27,21 @@ def set_patch_shape(img_shape, patch_shape): patch_shape_x = img_shape_x if (patch_shape_y is None) or (patch_shape_y > img_shape_y): patch_shape_y = img_shape_y - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x == img_shape_x and patch_shape_y == img_shape_y: + use_patching = False + else: + use_patching = True + if use_patching: if patch_shape_x != patch_shape_y: + warnings.warn( + f"You are using rectangular patches " + f"of shape {(patch_shape_y, patch_shape_x)}, " + f"which are an experimental feature." + ) raise NotImplementedError("Rectangular patch not supported yet") if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: raise ValueError("Patch shape needs to be a multiple of 32") - return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) def set_seed(rank): From 5db5e47088b534b5ddc0362439fa6699f8361263 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 26 May 2025 18:22:20 +0200 Subject: [PATCH 048/189] small config fix --- src/hirad/conf/training_era_cosmo_diffusion.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 4271e44..0a069e9 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -15,5 +15,7 @@ defaults: # Model - model/era_cosmo_diffusion + - model_size/normal + # Training - training/era_cosmo_diffusion \ No newline at end of file From e8bd5cd7b3b9ae61f000f5fd676cda0780526f7d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 27 May 2025 12:09:35 +0200 Subject: [PATCH 049/189] fix generation on distributed --- src/hirad/inference/generate.py | 4 +- src/hirad/training/train.py | 83 ++++++++++++++------------------- 2 files changed, 38 insertions(+), 49 deletions(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 8fed809..ec385dc 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -269,9 +269,11 @@ def generate_fn(image_lr, lead_time_label): ) if dist.rank == 0: + if cfg.generation.inference_mode != "regression": + return torch.cat(gathered_tensors), image_reg[0:1,::] return torch.cat(gathered_tensors) else: - return None + return None, None else: #TODO do this for multi-gpu setting above too if cfg.generation.inference_mode != "regression": diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 794dd55..39b3653 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -227,21 +227,6 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) - # param_to_name = {} - # ppp = False - # for name, param in model.named_parameters(): - # pid = id(param) - # if pid in param_to_name: - # print(f"[SHARED PARAM] {name} == {param_to_name[pid]}") - # ppp = True - # break - # else: - # param_to_name[pid] = name - # print(f'There are shared parameters: {ppp}') - - # TODO write summry from rank=0 possibly - # summary(model, input_size=[(1,img_out_channels,*img_shape),(1,img_in_channels,*img_shape),(1,1)]) - if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: json.dump(model_args, f) @@ -572,6 +557,41 @@ def main(cfg: DictConfig) -> None: cur_nimg += cfg.training.hp.total_batch_size done = cur_nimg >= cfg.training.hp.training_duration + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + logger0.info(img_clean.shape) + logger0.info(img_lr.shape) + with nvtx.annotate("validation", color="red"): # Validation if validation_dataset_iterator is not None: @@ -671,39 +691,6 @@ def main(cfg: DictConfig) -> None: "validation_loss", average_valid_loss, cur_nimg ) - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - if torch.cuda.is_available(): - fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" - ] - torch.cuda.reset_peak_memory_stats() - logger0.info(" ".join(fields)) - # Save checkpoints if dist.world_size > 1: From 692dfe25d91af282d4f0b72f3a05237720d1fb76 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 27 May 2025 12:18:08 +0200 Subject: [PATCH 050/189] delete unnecessary logging --- src/hirad/training/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 39b3653..3a2fe2e 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -589,8 +589,6 @@ def main(cfg: DictConfig) -> None: ] torch.cuda.reset_peak_memory_stats() logger0.info(" ".join(fields)) - logger0.info(img_clean.shape) - logger0.info(img_lr.shape) with nvtx.annotate("validation", color="red"): # Validation From f16b71d5fae2e528fa7fd72171ec6f069262dbd1 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 5 Jun 2025 15:23:40 +0200 Subject: [PATCH 051/189] add patch_num fix to loss calcualtion --- src/hirad/conf/dataset/era_cosmo.yaml | 3 +- .../conf/training/era_cosmo_diffusion.yaml | 18 ++++---- .../conf/training/era_cosmo_regression.yaml | 17 ++++--- src/hirad/datasets/dataset.py | 2 +- src/hirad/inference/generate.py | 46 +++++++++++++------ src/hirad/training/train.py | 13 ++++-- 6 files changed, 64 insertions(+), 35 deletions(-) diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index 63d7361..b1e21e6 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,2 +1,3 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_overfit \ No newline at end of file +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full +validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full/validation \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index f8d19e6..07cbb03 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,8 +1,8 @@ # Hyperparameters hp: - training_duration: 16 + training_duration: 5000000 # Training duration based on the number of processed samples - total_batch_size: 4 + total_batch_size: 128 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU @@ -14,28 +14,30 @@ hp: # LR decay rate lr_rampup: 0 # Rampup for learning rate, in number of samples + lr_decay_rate: 5e5 + # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples. # Performance perf: fp_optimizations: amp-bf16 # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 8 + dataloader_workers: 10 # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint # I/O io: - regression_checkpoint_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + regression_checkpoint_path: /capstor/scratch/cscs/boeschf/HiRAD-Gen/outputs_full/regression/checkpoints_regression/ # Where to load the regression checkpoint - print_progress_freq: 128 + print_progress_freq: 5000 # How often to print progress - save_checkpoint_freq: 5000 + save_checkpoint_freq: 250000 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 5000 + validation_freq: 25000 # how often to record the validation loss, measured in number of processed samples - validation_steps: 10 + validation_steps: 4 # how many loss evaluations are used to compute the validation loss per checkpoint # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 76bdc4e..98c6c24 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,13 +1,12 @@ # Hyperparameters hp: - training_duration: 8 + training_duration: 500000 # Training duration based on the number of processed samples - total_batch_size: 4 + total_batch_size: 64 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU - lr: 0.001 - #0.0002 + lr: 0.0002 # Learning rate grad_clip_threshold: null # no gradient clipping for defualt non-patch-based training @@ -15,22 +14,26 @@ hp: # LR decay rate lr_rampup: 0 # Rampup for learning rate, in number of samples + lr_decay_rate: 5e5 + # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples. # Performance perf: fp_optimizations: amp-bf16 # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 4 + dataloader_workers: 10 # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint + # torch_compile: True + # use_apex_gn: True # I/O io: - print_progress_freq: 128 + print_progress_freq: 1024 # How often to print progress - save_checkpoint_freq: 5000 + save_checkpoint_freq: 25000 # How often to save the checkpoints, measured in number of processed samples validation_freq: 5000 # how often to record the validation loss, measured in number of processed samples diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 7ba8833..c09a4f2 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -58,7 +58,7 @@ def init_train_valid_datasets_from_config( """ config = copy.deepcopy(dataset_cfg) - if 'validation_path': + if 'validation_path' in config: del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( config, dataloader_cfg, batch_size=batch_size, seed=seed diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index ec385dc..3cc1eb4 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -271,7 +271,7 @@ def generate_fn(image_lr, lead_time_label): if dist.rank == 0: if cfg.generation.inference_mode != "regression": return torch.cat(gathered_tensors), image_reg[0:1,::] - return torch.cat(gathered_tensors) + return torch.cat(gathered_tensors), None else: return None, None else: @@ -408,16 +408,31 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + freqs = {} power = {} for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) + + _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg')) + _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg')) + _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg')) + if mean_pred is not None: + _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-mean_prediction.jpg')) + _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) + if mean_pred is not None: + _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) + if mean_pred is not None: + plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) freqs['baseline'] = b_freq @@ -429,6 +444,10 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) freqs['prediction'] = p_freq power['prediction'] = p_power + if mean_pred is not None: + mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0) + freqs['mean_prediction'] = mp_freq + power['mean_prediction'] = mp_power plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) # def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): @@ -459,19 +478,18 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, # if mean_pred is not None: # _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) -# def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): +def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): -# """Plot observed or interpolated data in a scatter plot.""" -# # TODO: Refactor this somehow, it's not really generalizing well across variables. -# fig = plt.figure() -# fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -# p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) -# ax.coastlines() -# ax.gridlines(draw_labels=True) -# plt.colorbar(p, label="K", orientation="horizontal") -# plt.savefig(filename) -# plt.close('all') + """Plot observed or interpolated data in a scatter plot.""" + # TODO: Refactor this somehow, it's not really generalizing well across variables. + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="K", orientation="horizontal") + plt.savefig(filename) + plt.close('all') if __name__ == "__main__": - main() - \ No newline at end of file + main() \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 3a2fe2e..1b43a6e 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -494,8 +494,12 @@ def main(cfg: DictConfig) -> None: ): loss = loss_fn(**loss_fn_kwargs) - loss = loss.sum() / batch_size_per_gpu - loss_accum += loss / num_accumulation_rounds + loss_accum += ( + loss + / num_accumulation_rounds + / len(patch_nums_iter) + ) + loss_accum += loss / num_accumulation_rounds with nvtx.annotate(f"loss backward", color="yellow"): loss.backward() @@ -544,7 +548,7 @@ def main(cfg: DictConfig) -> None: if lr_rampup > 0: g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) if cur_nimg >= lr_rampup: - g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // cfg.training.hp.lr_decay_rate) current_lr = g["lr"] if dist.rank == 0: writer.add_scalar("learning_rate", current_lr, cur_nimg) @@ -658,7 +662,7 @@ def main(cfg: DictConfig) -> None: for patch_num_per_iter in patch_nums_iter: if patching is not None: patching.set_patch_num(patch_num_per_iter) - loss_fn_kwargs.update( + loss_valid_kwargs.update( {"patching": patching} ) with torch.autocast( @@ -674,6 +678,7 @@ def main(cfg: DictConfig) -> None: valid_loss_accum += ( loss_valid / cfg.training.io.validation_steps + / len(patch_nums_iter) ) valid_loss_sum = torch.tensor( [valid_loss_accum], device=dist.device From 4337ba5fbd095baf4f0613b2cd94079321db7e52 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 12 Jun 2025 10:19:54 +0200 Subject: [PATCH 052/189] add log scale for precipitation plotting --- src/hirad/eval/__init__.py | 2 ++ src/hirad/inference/generate.py | 45 +++++++++++++-------------------- 2 files changed, 20 insertions(+), 27 deletions(-) create mode 100644 src/hirad/eval/__init__.py diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py new file mode 100644 index 0000000..a3228ee --- /dev/null +++ b/src/hirad/eval/__init__.py @@ -0,0 +1,2 @@ +from .metrics import compute_mae, average_power_spectrum +from .plotting import plot_error_projection, plot_power_spectra \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 3cc1eb4..35f856f 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -396,6 +396,7 @@ def elapsed_time(self, _): f.close() logger0.info("Generation Completed.") + def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): os.makedirs(output_path, exist_ok=True) @@ -417,6 +418,13 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) + if channel.name=="tp": + target[idx,::] = prepare_precipitaiton(target[idx,:,:]) + prediction[idx,::] = prepare_precipitaiton(prediction[idx,:,:]) + baseline[input_channel_idx,:,:] = prepare_precipitaiton(baseline[input_channel_idx]) + if mean_pred is not None: + mean_pred[idx,::] = prepare_precipitaiton(mean_pred[idx,::]) + _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg')) _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg')) _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg')) @@ -450,33 +458,16 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, power['mean_prediction'] = mp_power plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) -# def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): -# longitudes = dataset.longitude() -# latitudes = dataset.latitude() -# input_channels = dataset.input_channels() -# output_channels = dataset.output_channels() -# image_pred = image_pred.numpy() -# image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) -# if image_pred.shape[0]>1: -# image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) -# image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) -# image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) -# image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) -# image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) -# if mean_pred is not None: -# mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) -# os.makedirs(output_path, exist_ok=True) -# for idx, channel in enumerate(output_channels): -# input_channel_idx = input_channels.index(channel) -# _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) -# _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) -# _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) -# if image_pred.shape[0]>1: -# _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) -# _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) -# _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) -# if mean_pred is not None: -# _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) + +def prepare_precipitaiton(precip_array): + precip_array = np.clip(precip_array, 0, None) + epsilon = 1e-2 + precip_array = precip_array + epsilon + precip_array = np.log(precip_array) + # log_min, log_max = precip_array.min(), precip_array.max() + # precip_array = (precip_array-log_min)/(log_max-log_min) + return precip_array + def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): From 2ebf7e14bf1e21bd402fa20cce8af3189b447965 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 13 Jun 2025 12:52:54 +0200 Subject: [PATCH 053/189] add missing loss sum --- src/hirad/training/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 1b43a6e..12b6942 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -11,7 +11,7 @@ from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter from torch.nn.parallel import DistributedDataParallel -from torchinfo import summary +# from torchinfo import summary from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper @@ -494,6 +494,7 @@ def main(cfg: DictConfig) -> None: ): loss = loss_fn(**loss_fn_kwargs) + loss = loss.sum() / batch_size_per_gpu loss_accum += ( loss / num_accumulation_rounds From c63ca0659772fa1d87d79ba24170048d3a15ed23 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 13 Jun 2025 14:37:32 +0200 Subject: [PATCH 054/189] fix small things for merging --- src/hirad/conf/model_size/mini.yaml | 2 ++ src/hirad/conf/model_size/normal.yaml | 1 + src/hirad/conf/sampler/deterministic.yaml | 8 ++++++-- src/hirad/conf/sampler/stochastic.yaml | 6 ++++-- src/hirad/datasets/__init__.py | 2 +- src/hirad/distributed/__init__.py | 2 +- src/hirad/eval/__init__.py | 2 +- src/hirad/losses/__init__.py | 2 +- 8 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/hirad/conf/model_size/mini.yaml b/src/hirad/conf/model_size/mini.yaml index 2eb8f8a..a5847f5 100644 --- a/src/hirad/conf/model_size/mini.yaml +++ b/src/hirad/conf/model_size/mini.yaml @@ -16,6 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Smaller model size (10 million parameters), lower learning capacity, should be used only for testing or for small datasets and small grid size. +# Learning capacity is reduced and final models are not recommmended to be used in production. model_args: # Base multiplier for the number of channels across the network. diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml index b81fe15..96c29fb 100644 --- a/src/hirad/conf/model_size/normal.yaml +++ b/src/hirad/conf/model_size/normal.yaml @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Normal model size (80 million parameters) should be used by default for full datasets and higher grid size. model_args: # Base multiplier for the number of channels across the network. diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml index 35bc0f6..856906b 100644 --- a/src/hirad/conf/sampler/deterministic.yaml +++ b/src/hirad/conf/sampler/deterministic.yaml @@ -1,4 +1,8 @@ +# Deterministic sampler is generally faster than stochastic sampler, but should theoretically produce worse results. +# Deterministic sampler is not implemented correctly in this codebase and shouldn't be used. + type: deterministic -num_steps: 9 +num_steps: 9 # Number of denoising steps -solver: euler \ No newline at end of file +solver: euler + # ODE solver type: euler is the simplest solver \ No newline at end of file diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml index 2481cd3..808270c 100644 --- a/src/hirad/conf/sampler/stochastic.yaml +++ b/src/hirad/conf/sampler/stochastic.yaml @@ -1,3 +1,5 @@ +# Stochastic sampler is slower, but should give better results than deterministic sampler. + type: stochastic -# boundary_pix: 2 -# overlap_pix: 4 \ No newline at end of file +# boundary_pix: 2 # set for patched diffusion +# overlap_pix: 4 # set for patched diffusion \ No newline at end of file diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py index 706284e..53e791e 100644 --- a/src/hirad/datasets/__init__.py +++ b/src/hirad/datasets/__init__.py @@ -1,3 +1,3 @@ from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config from .era5_cosmo import ERA5_COSMO -from .base import DownscalingDataset \ No newline at end of file +from .base import DownscalingDataset diff --git a/src/hirad/distributed/__init__.py b/src/hirad/distributed/__init__.py index 0da01f3..d953a85 100644 --- a/src/hirad/distributed/__init__.py +++ b/src/hirad/distributed/__init__.py @@ -1 +1 @@ -from .manager import DistributedManager \ No newline at end of file +from .manager import DistributedManager diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py index a3228ee..13d9eb3 100644 --- a/src/hirad/eval/__init__.py +++ b/src/hirad/eval/__init__.py @@ -1,2 +1,2 @@ from .metrics import compute_mae, average_power_spectrum -from .plotting import plot_error_projection, plot_power_spectra \ No newline at end of file +from .plotting import plot_error_projection, plot_power_spectra diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py index 868ffdf..1494a54 100644 --- a/src/hirad/losses/__init__.py +++ b/src/hirad/losses/__init__.py @@ -1 +1 @@ -from .loss import ResidualLoss, RegressionLoss, RegressionLossCE \ No newline at end of file +from .loss import ResidualLoss, RegressionLoss, RegressionLossCE From 4a3a352b7df0c1dfd7fa5c9722fd545fd9385a76 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 16 Jun 2025 14:13:06 +0200 Subject: [PATCH 055/189] add crps code --- src/hirad/eval/crps.py | 348 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 src/hirad/eval/crps.py diff --git a/src/hirad/eval/crps.py b/src/hirad/eval/crps.py new file mode 100644 index 0000000..2ccdb42 --- /dev/null +++ b/src/hirad/eval/crps.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import numpy as np +import torch + +from .histogram import cdf as cdf_function + +Tensor = torch.Tensor + + +@torch.jit.script +def _kernel_crps_implementation(pred: Tensor, obs: Tensor, biased: bool) -> Tensor: + """An O(m log m) implementation of the kernel CRPS formulas""" + skill = torch.abs(pred - obs[..., None]).mean(-1) + pred, _ = torch.sort(pred) + + # derivation of fast implementation of spread-portion of CRPS formula when x is sorted + # sum_(i,j=1)^m |x_i - x_j| = sum_(i j) |x_i - x_j| + # = 2 sum_(i <= j) |x_i -x_j| + # = 2 sum_(i <= j) (x_j - x_i) + # = 2 sum_(i <= j) x_j - 2 sum_(i <= j) x_i + # = 2 sum_(j=1)^m j x_j - 2 sum (m - i + 1) x_i + # = 2 sum_(i=1)^m (2i - m - 1) x_i + m = pred.size(-1) + i = torch.arange(1, m + 1, device=pred.device, dtype=pred.dtype) + denom = m * m if biased else m * (m - 1) + factor = (2 * i - m - 1) / denom + spread = torch.sum(factor * pred, dim=-1) + return skill - spread + + +def kcrps(pred: Tensor, obs: Tensor, dim: int = 0, biased: bool = True): + """Estimate the CRPS from a finite ensemble + + Computes the local Continuous Ranked Probability Score (CRPS) by using + the kernel version of CRPS. The cost is O(m log m). + + Creates a map of CRPS and does not accumulate over lat/lon regions. + Approximates: + .. math:: + CRPS(X, y) = E[X - y] - 0.5 E[X-X'] + + with + .. math:: + sum_i=1^m |X_i - y| / m - 1/(2m^2) sum_i,j=1^m |x_i - x_j| + + Parameters + ---------- + pred : Tensor + Tensor containing the ensemble predictions. The ensemble dimension + is assumed to be the leading dimension unless 'dim' is specified. + obs : Union[Tensor, np.ndarray] + Tensor or array containing an observation over which the CRPS is computed + with respect to. + dim : int, optional + The dimension over which to compute the CRPS, assumed to be 0. + biased : + When False, uses the unbiased estimators described in (Zamo and Naveau, 2018):: + + E|X-y|/m - 1/(2m(m-1)) sum_(i,j=1)|x_i - x_j| + + Unlike ``crps`` this is fair for finite ensembles. Non-fair ``crps`` favors less + dispersive ensembles since it is biased high by E|X- X'|/ m where m is the + ensemble size. + + Returns + ------- + Tensor + Map of CRPS + """ + pred = torch.movedim(pred, dim, -1) + return _kernel_crps_implementation(pred, obs, biased=biased) + + +def _crps_gaussian(mean: Tensor, std: Tensor, obs: Union[Tensor, np.ndarray]) -> Tensor: + """ + Computes the local Continuous Ranked Probability Score (CRPS) + using assuming that the forecast distribution is normal. + + Creates a map of CRPS and does not accumulate over lat/lon regions. + + Computes: + + .. math:: + + CRPS(mean, std, y) = std * [ \\frac{1}{\\sqrt{\\pi}}} - 2 \\phi ( \\frac{x-mean}{std} ) - + ( \\frac{x-mean}{std} ) * (2 \\Phi(\\frac{x-mean}{std}) - 1) ] + + where \\phi and \\Phi are the normal gaussian pdf/cdf respectively. + + Parameters + ---------- + mean : Tensor + Tensor of mean of forecast distribution. + std : Tensor + Tensor of standard deviation of forecast distribution. + obs : Union[Tensor, np.ndarray] + Tensor or array containing an observation over which the CRPS is computed + with respect to. Broadcasting dimensions must be compatible with the non-zeroth + dimensions of bins and cdf. + + Returns + ------- + Tensor + Map of CRPS + """ + if isinstance(obs, np.ndarray): + obs = torch.from_numpy(obs).to(mean.device) + # Check shape compatibility + if mean.shape != std.shape: + raise ValueError( + "Mean and standard deviation must have" + + "compatible shapes but found" + + str(mean.shape) + + " and " + + str(std.shape) + + "." + ) + if mean.shape != obs.shape: + raise ValueError( + "Mean and obs must have" + + "compatible shapes but found" + + str(mean.shape) + + " and " + + str(obs.shape) + + "." + ) + + d = (obs - mean) / std + phi = torch.exp(-0.5 * d**2) / torch.sqrt(torch.as_tensor(2 * torch.pi)) + + # Note, simplified expression below is not exactly Gaussian CDF + Phi = torch.erf(d / torch.sqrt(torch.as_tensor(2.0))) + + return std * (2 * phi + d * Phi - 1.0 / torch.sqrt(torch.as_tensor(torch.pi))) + + +def _crps_from_cdf( + bin_edges: Tensor, cdf: Tensor, obs: Union[Tensor, np.ndarray] +) -> Tensor: + """Computes the local Continuous Ranked Probability Score (CRPS) + using a cumulative distribution function. + + Creates a map of CRPS and does not accumulate over lat/lon regions. + + Computes: + + .. math:: + + CRPS(X, y) = \\int[ (F(x) - 1[x - y])^2 ] dx + + where F is the empirical cdf of X. + + Parameters + ---------- + bins_edges : Tensor + Tensor [N+1, ...] containing bin edges. The leading dimension must represent the + N+1 bin edges. + cdf : Tensor + Tensor [N, ...] containing a cdf, defined over bins. The non-zeroth dimensions + of bins and cdf must be compatible. + obs : Union[Tensor, np.ndarray] + Tensor or array containing an observation over which the CRPS is computed + with respect to. Broadcasting dimensions must be compatible with the non-zeroth + dimensions of bins and cdf. + + Returns + ------- + Tensor + Map of CRPS + """ + if isinstance(obs, np.ndarray): + obs = torch.from_numpy(obs).to(cdf.device) + if bin_edges.shape[1:] != cdf.shape[1:]: + raise ValueError( + "Expected bins and cdf to have compatible non-zeroth dimensions but have shapes" + + str(bin_edges.shape[1:]) + + " and " + + str(cdf.shape[1:]) + + "." + ) + if bin_edges.shape[1:] != obs.shape: + raise ValueError( + "Expected bins and observations to have compatible broadcasting dimensions but have shapes" + + str(bin_edges.shape[1:]) + + " and " + + str(obs.shape) + + "." + ) + if bin_edges.shape[0] != cdf.shape[0] + 1: + raise ValueError( + "Expected zeroth dimension of cdf to be equal to the zeroth dimension of bins + 1 but have shapes" + + str(bin_edges.shape[0]) + + " and " + + str(cdf.shape[0]) + + "+1." + ) + dbins = bin_edges[1, ...] - bin_edges[0, ...] + bin_mids = 0.5 * (bin_edges[1:] + bin_edges[:-1]) + obs = torch.ge(bin_mids, obs).int() + return torch.sum(torch.abs(cdf - obs) ** 2 * dbins, dim=0) + + +def _crps_from_counts( + bin_edges: Tensor, counts: Tensor, obs: Union[Tensor, np.ndarray] +) -> Tensor: + """Computes the local Continuous Ranked Probability Score (CRPS) + using a histogram of counts. + + Creates a map of CRPS and does not accumulate over lat/lon regions. + + Computes: + + .. math:: + + CRPS(X, y) = int[ (F(x) - 1[x - y])^2 ] dx + + where F is the empirical cdf of X. + + Parameters + ---------- + bins_edges : Tensor + Tensor [N+1, ...] containing bin edges. The leading dimension must represent the + N+1 bin edges. + counts : Tensor + Tensor [N, ...] containing counts, defined over bins. The non-zeroth dimensions + of bins and counts must be compatible. + obs : Union[Tensor, np.ndarray] + Tensor or array containing an observation over which the CRPS is computed + with respect to. Broadcasting dimensions must be compatible with the non-zeroth + dimensions of bins and counts. + + Returns + ------- + Tensor + Map of CRPS + """ + if isinstance(obs, np.ndarray): + obs = torch.from_numpy(obs).to(counts.device) + if bin_edges.shape[1:] != counts.shape[1:]: + raise ValueError( + "Expected bins and cdf to have compatible non-zeroth dimensions but have shapes" + + str(bin_edges.shape[1:]) + + " and " + + str(counts.shape[1:]) + + "." + ) + if bin_edges.shape[1:] != obs.shape: + raise ValueError( + "Expected bins and observations to have compatible broadcasting dimensions but have shapes" + + str(bin_edges.shape[1:]) + + " and " + + str(obs.shape) + + "." + ) + if bin_edges.shape[0] != counts.shape[0] + 1: + raise ValueError( + "Expected zeroth dimension of cdf to be equal to the zeroth dimension of bins + 1 but have shapes" + + str(bin_edges.shape[0]) + + " and " + + str(counts.shape[0]) + + "+1." + ) + cdf_hat = torch.cumsum(counts / torch.sum(counts, dim=0), dim=0) + return _crps_from_cdf(bin_edges, cdf_hat, obs) + + +def crps( + pred: Tensor, obs: Union[Tensor, np.ndarray], dim: int = 0, method: str = "kernel" +) -> Tensor: + """ + Computes the local Continuous Ranked Probability Score (CRPS). + + Creates a map of CRPS and does not accumulate over any other dimensions (e.g., lat/lon regions). + + Parameters + ---------- + pred : Tensor + Tensor containing the ensemble predictions. + obs : Union[Tensor, np.ndarray] + Tensor or array containing an observation over which the CRPS is computed + with respect to. + dim : int, Optional + Dimension with which to calculate the CRPS over, the ensemble dimension. + Assumed to be zero. + method: str, Optional + The method to calculate the crps. Can either be "kernel", "sort" or "histogram". + + The "kernel" method implements + .. math:: + CRPS(x, y) = E[X-y] - 0.5*E[X-X'] + + This method scales as O(n^2) where n is the number of ensemble members and + can potentially induce large memory consumption as the algorithm attempts + to vectorize over this O(n^2) operation. + + The "sort" method compute the exact CRPS using the CDF method + .. math:: + CRPS(x, y) = int [F(x) - 1(x-y)]^2 dx + + where F is the empirical CDF and 1(x-y) = 1 if x > y. + + This method is more memory efficient than the kernel method, and uses O(n + log n) compute instead of O(n^2), where n is the number of ensemble members. + + The "histogram" method computes an approximate CRPS using the CDF method + .. math:: + CRPS(x, y) = int [F(x) - 1(x-y)]^2 dx + + where F is the empirical CDF, estimated via a histogram of the samples. The + number of bins used is the lesser of the square root of the number of samples + and 100. For more control over the implementation of this method consider using + `cdf_function` to construct a cdf and `_crps_from_cdf` to compute CRPS. + + Returns + ------- + Tensor + Map of CRPS + """ + if method not in ["kernel", "sort", "histogram"]: + raise ValueError("Method must either be 'kernel', 'sort' or 'histogram'.") + + n = pred.shape[dim] + obs = torch.as_tensor(obs, device=pred.device, dtype=pred.dtype) + if method in ["kernel", "sort"]: + return kcrps(pred, obs, dim=dim) + else: + pred = pred.unsqueeze(0).transpose(0, dim + 1).squeeze(dim + 1) + number_of_bins = max(int(np.sqrt(n)), 100) + bin_edges, cdf = cdf_function(pred, bins=number_of_bins) + _crps = _crps_from_cdf(bin_edges, cdf, obs) + return _crps \ No newline at end of file From eb408f01e0ded58d4d8ca2df228bcf409b0e0fea Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 17 Jun 2025 16:14:41 +0200 Subject: [PATCH 056/189] restructure utils --- src/hirad/datasets/__init__.py | 2 +- src/hirad/datasets/dataset.py | 19 + src/hirad/inference/generate.py | 12 +- src/hirad/models/__init__.py | 1 + src/hirad/models/layers.py | 2 +- src/hirad/utils/capture.py | 513 ---------------------- src/hirad/utils/function_utils.py | 656 +---------------------------- src/hirad/utils/generate_utils.py | 22 - src/hirad/utils/inference_utils.py | 114 ----- src/hirad/utils/model_utils.py | 66 --- src/hirad/utils/train_helpers.py | 6 - 11 files changed, 44 insertions(+), 1369 deletions(-) delete mode 100644 src/hirad/utils/capture.py delete mode 100644 src/hirad/utils/generate_utils.py delete mode 100644 src/hirad/utils/model_utils.py diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py index 53e791e..2fb3d5d 100644 --- a/src/hirad/datasets/__init__.py +++ b/src/hirad/datasets/__init__.py @@ -1,3 +1,3 @@ -from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config +from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config, get_dataset_and_sampler_inference from .era5_cosmo import ERA5_COSMO from .base import DownscalingDataset diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index c09a4f2..1d36bc2 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -111,3 +111,22 @@ def init_dataset_from_config( ) return (dataset_obj, dataset_iterator) + + +def get_dataset_and_sampler_inference(dataset_cfg, times, has_lead_time=False): + """ + Get a dataset and sampler for generation. + """ + (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) + # if has_lead_time: + # plot_times = times + # else: + # plot_times = [ + # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + # for time in times + # ] + all_times = dataset.time() + time_indices = [all_times.index(t) for t in times] + sampler = time_indices + + return dataset, sampler diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 35f856f..b60d6d0 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -15,26 +15,20 @@ import cartopy.crs as ccrs from matplotlib import pyplot as plt -from einops import rearrange from torch.distributed import gather - -from hydra.utils import to_absolute_path from hirad.models import EDMPrecondSuperResolution, UNet from hirad.utils.patching import GridPatching2D from hirad.utils.stochastic_sampler import stochastic_sampler from hirad.utils.deterministic_sampler import deterministic_sampler from hirad.utils.inference_utils import ( - get_time_from_range, regression_step, diffusion_step, ) +from hirad.utils.function_utils import get_time_from_range from hirad.utils.checkpoint import load_checkpoint - -from hirad.utils.generate_utils import ( - get_dataset_and_sampler -) +from hirad.datasets import get_dataset_and_sampler_inference from hirad.utils.train_helpers import set_patch_shape @@ -81,7 +75,7 @@ def main(cfg: DictConfig) -> None: has_lead_time = cfg.generation["has_lead_time"] else: has_lead_time = False - dataset, sampler = get_dataset_and_sampler( + dataset, sampler = get_dataset_and_sampler_inference( dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time ) img_shape = dataset.image_shape() diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index b00a477..c8fb306 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,3 +1,4 @@ +from .utils import weight_init from .layers import ( Linear, Conv2d, diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index d7e63d7..96bb37f 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -30,7 +30,7 @@ from einops import rearrange from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh -from hirad.utils.model_utils import weight_init +from hirad.models import weight_init _is_apex_available = False if torch.cuda.is_available(): diff --git a/src/hirad/utils/capture.py b/src/hirad/utils/capture.py deleted file mode 100644 index 9c38d5a..0000000 --- a/src/hirad/utils/capture.py +++ /dev/null @@ -1,513 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import logging -import os -import time -from contextlib import nullcontext -from logging import Logger -from typing import Any, Callable, Dict, NewType, Optional, Union - -import torch - -from hirad.distributed import DistributedManager - -float16 = NewType("float16", torch.float16) -bfloat16 = NewType("bfloat16", torch.bfloat16) -optim = NewType("optim", torch.optim) - - -class _StaticCapture(object): - """Base class for StaticCapture decorator. - - This class should not be used, rather StaticCaptureTraining and StaticCaptureEvaluate - should be used instead for training and evaluation functions. - """ - - # Grad scaler and checkpoint class variables use for checkpoint saving and loading - # Since an instance of Static capture does not exist for checkpoint functions - # one must use class functions to access state dicts - _amp_scalers = {} - _amp_scaler_checkpoints = {} - _logger = logging.getLogger("capture") - - def __new__(cls, *args, **kwargs): - obj = super(_StaticCapture, cls).__new__(cls) - obj.amp_scalers = cls._amp_scalers - obj.amp_scaler_checkpoints = cls._amp_scaler_checkpoints - obj.logger = cls._logger - return obj - - def __init__( - self, - model: "physicsnemo.Module", - optim: Optional[optim] = None, - logger: Optional[Logger] = None, - use_graphs: bool = True, - use_autocast: bool = True, - use_gradscaler: bool = True, - compile: bool = False, - cuda_graph_warmup: int = 11, - amp_type: Union[float16, bfloat16] = torch.float16, - gradient_clip_norm: Optional[float] = None, - label: Optional[str] = None, - ): - self.logger = logger if logger else self.logger - # Checkpoint label (used for gradscaler) - self.label = label if label else f"scaler_{len(self.amp_scalers.keys())}" - - # DDP fix - if not isinstance(model, physicsnemo.models.Module) and hasattr( - model, "module" - ): - model = model.module - - if not isinstance(model, physicsnemo.models.Module): - self.logger.error("Model not a PhysicsNeMo Module!") - raise ValueError("Model not a PhysicsNeMo Module!") - if compile: - model = torch.compile(model) - - self.model = model - - self.optim = optim - self.eval = False - self.no_grad = False - self.gradient_clip_norm = gradient_clip_norm - - # Set up toggles for optimizations - if not (amp_type == torch.float16 or amp_type == torch.bfloat16): - raise ValueError("AMP type must be torch.float16 or torch.bfloat16") - # CUDA device - if "cuda" in str(self.model.device): - # CUDA graphs - if use_graphs and not self.model.meta.cuda_graphs: - self.logger.warning( - f"Model {model.meta.name} does not support CUDA graphs, turning off" - ) - use_graphs = False - self.cuda_graphs_enabled = use_graphs - - # AMP GPU - if not self.model.meta.amp_gpu: - self.logger.warning( - f"Model {model.meta.name} does not support AMP on GPUs, turning off" - ) - use_autocast = False - use_gradscaler = False - self.use_gradscaler = use_gradscaler - self.use_autocast = use_autocast - - self.amp_device = "cuda" - # Check if bfloat16 is suppored on the GPU - if amp_type == torch.bfloat16 and not torch.cuda.is_bf16_supported(): - self.logger.warning( - "Current CUDA device does not support bfloat16, falling back to float16" - ) - amp_type = torch.float16 - self.amp_dtype = amp_type - # Gradient Scaler - scaler_enabled = self.use_gradscaler and amp_type == torch.float16 - self.scaler = self._init_amp_scaler(scaler_enabled, self.logger) - - self.replay_stream = torch.cuda.Stream(self.model.device) - # CPU device - else: - self.cuda_graphs_enabled = False - # AMP CPU - if use_autocast and not self.model.meta.amp_cpu: - self.logger.warning( - f"Model {model.meta.name} does not support AMP on CPUs, turning off" - ) - use_autocast = False - - self.use_autocast = use_autocast - self.amp_device = "cpu" - # Only float16 is supported on CPUs - # https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior - if amp_type == torch.float16 and use_autocast: - self.logger.warning( - "torch.float16 not supported for CPU AMP, switching to torch.bfloat16" - ) - amp_type = torch.bfloat16 - self.amp_dtype = torch.bfloat16 - # Gradient Scaler (not enabled) - self.scaler = self._init_amp_scaler(False, self.logger) - self.replay_stream = None - - if self.cuda_graphs_enabled: - self.graph = torch.cuda.CUDAGraph() - - self.output = None - self.iteration = 0 - self.cuda_graph_warmup = cuda_graph_warmup # Default for DDP = 11 - - def __call__(self, fn: Callable) -> Callable: - self.function = fn - - @functools.wraps(fn) - def decorated(*args: Any, **kwds: Any) -> Any: - """Training step decorator function""" - - with torch.no_grad() if self.no_grad else nullcontext(): - if self.cuda_graphs_enabled: - self._cuda_graph_forward(*args, **kwds) - else: - self._zero_grads() - self.output = self._amp_forward(*args, **kwds) - - if not self.eval: - # Update model parameters - self.scaler.step(self.optim) - self.scaler.update() - - return self.output - - return decorated - - def _cuda_graph_forward(self, *args: Any, **kwargs: Any) -> Any: - """Forward training step with CUDA graphs - - Returns - ------- - Any - Output of neural network forward - """ - # Graph warm up - if self.iteration < self.cuda_graph_warmup: - self.replay_stream.wait_stream(torch.cuda.current_stream()) - self._zero_grads() - with torch.cuda.stream(self.replay_stream): - output = self._amp_forward(*args, **kwargs) - self.output = output.detach() - torch.cuda.current_stream().wait_stream(self.replay_stream) - # CUDA Graphs - else: - # Graph record - if self.iteration == self.cuda_graph_warmup: - self.logger.warning(f"Recording graph of '{self.function.__name__}'") - self._zero_grads() - torch.cuda.synchronize() - if DistributedManager().distributed: - torch.distributed.barrier() - # TODO: temporary workaround till this issue is fixed: - # https://github.com/pytorch/pytorch/pull/104487#issuecomment-1638665876 - delay = os.environ.get("PHYSICSNEMO_CUDA_GRAPH_CAPTURE_DELAY", "10") - time.sleep(int(delay)) - with torch.cuda.graph(self.graph): - output = self._amp_forward(*args, **kwargs) - self.output = output.detach() - # Graph replay - self.graph.replay() - - self.iteration += 1 - return self.output - - def _zero_grads(self): - """Zero gradients - - Default to `set_to_none` since this will in general have lower memory - footprint, and can modestly improve performance. - - Note - ---- - Zeroing gradients can potentially cause an invalid CUDA memory access in another - graph. However if your graph involves gradients, you much set your gradients to none. - If there is already a graph recorded that includes these gradients, this will error. - Use the `NoGrad` version of capture to avoid this issue for inferencers / validators. - """ - # Skip zeroing if no grad is being used - if self.no_grad: - return - - try: - self.optim.zero_grad(set_to_none=True) - except Exception: - if self.optim: - self.optim.zero_grad() - # For apex optim support and eval mode (need to reset model grads) - self.model.zero_grad(set_to_none=True) - - def _amp_forward(self, *args, **kwargs) -> Any: - """Compute loss and gradients (if training) with AMP - - Returns - ------- - Any - Output of neural network forward - """ - with torch.autocast( - self.amp_device, enabled=self.use_autocast, dtype=self.amp_dtype - ): - output = self.function(*args, **kwargs) - - if not self.eval: - # In training mode output should be the loss - self.scaler.scale(output).backward() - if self.gradient_clip_norm is not None: - self.scaler.unscale_(self.optim) - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.gradient_clip_norm - ) - - return output - - def _init_amp_scaler( - self, scaler_enabled: bool, logger: Logger - ) -> torch.cuda.amp.GradScaler: - # Create gradient scaler - scaler = torch.cuda.amp.GradScaler(enabled=scaler_enabled) - # Store scaler in class variable - self.amp_scalers[self.label] = scaler - logging.debug(f"Created gradient scaler {self.label}") - - # If our checkpoint dictionary has weights for this scaler lets load - if self.label in self.amp_scaler_checkpoints: - try: - scaler.load_state_dict(self.amp_scaler_checkpoints[self.label]) - del self.amp_scaler_checkpoints[self.label] - self.logger.info(f"Loaded grad scaler state dictionary {self.label}.") - except Exception as e: - self.logger.error( - f"Failed to load grad scaler {self.label} state dict from saved " - + "checkpoints. Did you switch the ordering of declared static captures?" - ) - raise ValueError(e) - return scaler - - @classmethod - def state_dict(cls) -> Dict[str, Any]: - """Class method for accsessing the StaticCapture state dictionary. - Use this in a training checkpoint function. - - Returns - ------- - Dict[str, Any] - Dictionary of states to save for file - """ - scaler_states = {} - for key, value in cls._amp_scalers.items(): - scaler_states[key] = value.state_dict() - - return scaler_states - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any]) -> None: - """Class method for loading a StaticCapture state dictionary. - Use this in a training checkpoint function. - - Returns - ------- - Dict[str, Any] - Dictionary of states to save for file - """ - for key, value in state_dict.items(): - # If scaler has been created already load the weights - if key in cls._amp_scalers: - try: - cls._amp_scalers[key].load_state_dict(value) - cls._logger.info(f"Loaded grad scaler state dictionary {key}.") - except Exception as e: - cls._logger.error( - f"Failed to load grad scaler state dict with id {key}." - + " Something went wrong!" - ) - raise ValueError(e) - # Otherwise store in checkpoints for later use - else: - cls._amp_scaler_checkpoints[key] = value - - @classmethod - def reset_state(cls): - cls._amp_scalers = {} - cls._amp_scaler_checkpoints = {} - - -class StaticCaptureTraining(_StaticCapture): - """A performance optimization decorator for PyTorch training functions. - - This class should be initialized as a decorator on a function that computes the - forward pass of the neural network and loss function. The user should only call the - defind training step function. This will apply optimizations including: AMP and - Cuda Graphs. - - Parameters - ---------- - model : physicsnemo.models.Module - PhysicsNeMo Model - optim : torch.optim - Optimizer - logger : Optional[Logger], optional - PhysicsNeMo Launch Logger, by default None - use_graphs : bool, optional - Toggle CUDA graphs if supported by model, by default True - use_amp : bool, optional - Toggle AMP if supported by mode, by default True - cuda_graph_warmup : int, optional - Number of warmup steps for cuda graphs, by default 11 - amp_type : Union[float16, bfloat16], optional - Auto casting type for AMP, by default torch.float16 - gradient_clip_norm : Optional[float], optional - Threshold for gradient clipping - label : Optional[str], optional - Static capture checkpoint label, by default None - - Raises - ------ - ValueError - If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. - - Example - ------- - >>> # Create model - >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) - >>> input = torch.rand(8, 2) - >>> output = torch.rand(8, 2) - >>> # Create optimizer - >>> optim = torch.optim.Adam(model.parameters(), lr=0.001) - >>> # Create training step function with optimization wrapper - >>> @StaticCaptureTraining(model=model, optim=optim) - ... def training_step(model, invar, outvar): - ... predvar = model(invar) - ... loss = torch.sum(torch.pow(predvar - outvar, 2)) - ... return loss - ... - >>> # Sample training loop - >>> for i in range(3): - ... loss = training_step(model, input, output) - ... - - Note - ---- - Static captures must be checkpointed when training using the `state_dict()` if AMP - is being used with gradient scaler. By default, this requires static captures to be - instantiated in the same order as when they were checkpointed. The label parameter - can be used to relax/circumvent this ordering requirement. - - Note - ---- - Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA - memory access errors on some systems. Prioritize capturing training graphs when this - occurs. - """ - - def __init__( - self, - model: "physicsnemo.Module", - optim: torch.optim, - logger: Optional[Logger] = None, - use_graphs: bool = True, - use_amp: bool = True, - compile: bool = False, - cuda_graph_warmup: int = 11, - amp_type: Union[float16, bfloat16] = torch.float16, - gradient_clip_norm: Optional[float] = None, - label: Optional[str] = None, - ): - super().__init__( - model, - optim, - logger, - use_graphs, - use_amp, - use_amp, - compile, - cuda_graph_warmup, - amp_type, - gradient_clip_norm, - label, - ) - - -class StaticCaptureEvaluateNoGrad(_StaticCapture): - - """An performance optimization decorator for PyTorch no grad evaluation. - - This class should be initialized as a decorator on a function that computes run the - forward pass of the model that does not require gradient calculations. This is the - recommended method to use for inference and validation methods. - - Parameters - ---------- - model : physicsnemo.models.Module - PhysicsNeMo Model - logger : Optional[Logger], optional - PhysicsNeMo Launch Logger, by default None - use_graphs : bool, optional - Toggle CUDA graphs if supported by model, by default True - use_amp : bool, optional - Toggle AMP if supported by mode, by default True - cuda_graph_warmup : int, optional - Number of warmup steps for cuda graphs, by default 11 - amp_type : Union[float16, bfloat16], optional - Auto casting type for AMP, by default torch.float16 - label : Optional[str], optional - Static capture checkpoint label, by default None - - Raises - ------ - ValueError - If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. - - Example - ------- - >>> # Create model - >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) - >>> input = torch.rand(8, 2) - >>> # Create evaluate function with optimization wrapper - >>> @StaticCaptureEvaluateNoGrad(model=model) - ... def eval_step(model, invar): - ... predvar = model(invar) - ... return predvar - ... - >>> output = eval_step(model, input) - >>> output.size() - torch.Size([8, 2]) - - Note - ---- - Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA - memory access errors on some systems. Prioritize capturing training graphs when this - occurs. - """ - - def __init__( - self, - model: "physicsnemo.Module", - logger: Optional[Logger] = None, - use_graphs: bool = True, - use_amp: bool = True, - compile: bool = False, - cuda_graph_warmup: int = 11, - amp_type: Union[float16, bfloat16] = torch.float16, - label: Optional[str] = None, - ): - super().__init__( - model, - None, - logger, - use_graphs, - use_amp, - compile, - False, - cuda_graph_warmup, - amp_type, - None, - label, - ) - self.eval = True # No optimizer/scaler calls - self.no_grad = True # No grad context and no grad zeroing diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py index 347457c..2da05db 100644 --- a/src/hirad/utils/function_utils.py +++ b/src/hirad/utils/function_utils.py @@ -17,19 +17,8 @@ """Miscellaneous utility classes and functions.""" -import contextlib -import ctypes import datetime -import fnmatch -import importlib -import inspect -import os -import re -import shutil -import sys -import types -import warnings -from typing import Any, Iterator, List, Tuple, Union +from typing import Iterator import cftime import numpy as np @@ -38,25 +27,6 @@ # ruff: noqa: E722 PERF203 S110 E713 S324 -class EasyDict(dict): # pragma: no cover - """ - Convenience class that behaves like a dict but allows access with the attribute - syntax. - """ - - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError: - raise AttributeError(name) - - def __setattr__(self, name: str, value: Any) -> None: - self[name] = value - - def __delattr__(self, name: str) -> None: - del self[name] - - class StackedRandomGenerator: # pragma: no cover """ Wrapper for torch.Generator that allows specifying a different random seed @@ -96,32 +66,8 @@ def randint(self, *args, size, **kwargs): ) -def parse_int_list(s): # pragma: no cover - """ - Parse a comma separated list of numbers or ranges and return a list of ints. - Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] - """ - if isinstance(s, list): - return s - ranges = [] - range_re = re.compile(r"^(\d+)-(\d+)$") - for p in s.split(","): - m = range_re.match(p) - if m: - ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1)) - else: - ranges.append(int(p)) - return ranges - - # Small util functions # ------------------------------------------------------------------------------------- -def convert_datetime_to_cftime( - time: datetime.datetime, cls=cftime.DatetimeGregorian -) -> cftime.DatetimeGregorian: - """Convert a Python datetime object to a cftime DatetimeGregorian object.""" - return cls(time.year, time.month, time.day, time.hour, time.minute, time.second) - def time_range( start_time: datetime.datetime, @@ -135,417 +81,30 @@ def time_range( yield t t += step +def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): + """Generates a list of times within a given range. -def format_time(seconds: Union[int, float]) -> str: # pragma: no cover - """Convert the seconds to human readable string with days, hours, minutes and seconds.""" - s = int(np.rint(seconds)) - - if s < 60: - return "{0}s".format(s) - elif s < 60 * 60: - return "{0}m {1:02}s".format(s // 60, s % 60) - elif s < 24 * 60 * 60: - return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) - else: - return "{0}d {1:02}h {2:02}m".format( - s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 - ) - - -def format_time_brief(seconds: Union[int, float]) -> str: # pragma: no cover - """Convert the seconds to human readable string with days, hours, minutes and seconds.""" - s = int(np.rint(seconds)) - - if s < 60: - return "{0}s".format(s) - elif s < 60 * 60: - return "{0}m {1:02}s".format(s // 60, s % 60) - elif s < 24 * 60 * 60: - return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) - else: - return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) - - -def tuple_product(t: Tuple) -> Any: # pragma: no cover - """Calculate the product of the tuple elements.""" - result = 1 - - for v in t: - result *= v - - return result - - -_str_to_ctype = { - "uint8": ctypes.c_ubyte, - "uint16": ctypes.c_uint16, - "uint32": ctypes.c_uint32, - "uint64": ctypes.c_uint64, - "int8": ctypes.c_byte, - "int16": ctypes.c_int16, - "int32": ctypes.c_int32, - "int64": ctypes.c_int64, - "float32": ctypes.c_float, - "float64": ctypes.c_double, -} - - -def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: # pragma: no cover - """ - Given a type name string (or an object having a __name__ attribute), return - matching Numpy and ctypes types that have the same size in bytes. - """ - type_str = None - - if isinstance(type_obj, str): - type_str = type_obj - elif hasattr(type_obj, "__name__"): - type_str = type_obj.__name__ - elif hasattr(type_obj, "name"): - type_str = type_obj.name - else: - raise RuntimeError("Cannot infer type name from input") - - if type_str not in _str_to_ctype.keys(): - raise ValueError("Unknown type name: " + type_str) - - my_dtype = np.dtype(type_str) - my_ctype = _str_to_ctype[type_str] - - if my_dtype.itemsize != ctypes.sizeof(my_ctype): - raise ValueError( - "Numpy and ctypes types for '{}' have different sizes!".format(type_str) - ) - - return my_dtype, my_ctype - - -# Functionality to import modules/objects by name, and call functions by name -# ------------------------------------------------------------------------------------- - - -def get_module_from_obj_name( - obj_name: str, -) -> Tuple[types.ModuleType, str]: # pragma: no cover - """ - Searches for the underlying module behind the name to some python object. - Returns the module and the object name (original name with module part removed). - """ - - # allow convenience shorthands, substitute them by full names - obj_name = re.sub("^np.", "numpy.", obj_name) - obj_name = re.sub("^tf.", "tensorflow.", obj_name) - - # list alternatives for (module_name, local_obj_name) - parts = obj_name.split(".") - name_pairs = [ - (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) - ] - - # try each alternative in turn - for module_name, local_obj_name in name_pairs: - try: - module = importlib.import_module(module_name) # may raise ImportError - get_obj_from_module(module, local_obj_name) # may raise AttributeError - return module, local_obj_name - except: - pass - - # maybe some of the modules themselves contain errors? - for module_name, _local_obj_name in name_pairs: - try: - importlib.import_module(module_name) # may raise ImportError - except ImportError: - if not str(sys.exc_info()[1]).startswith( - "No module named '" + module_name + "'" - ): - raise - - # maybe the requested attribute is missing? - for module_name, local_obj_name in name_pairs: - try: - module = importlib.import_module(module_name) # may raise ImportError - get_obj_from_module(module, local_obj_name) # may raise AttributeError - except ImportError: - pass - - # we are out of luck, but we have no idea why - raise ImportError(obj_name) - - -def get_obj_from_module( - module: types.ModuleType, obj_name: str -) -> Any: # pragma: no cover - """ - Traverses the object name and returns the last (rightmost) python object. - """ - if obj_name == "": - return module - obj = module - for part in obj_name.split("."): - obj = getattr(obj, part) - return obj - - -def get_obj_by_name(name: str) -> Any: # pragma: no cover - """ - Finds the python object with the given name. - """ - module, obj_name = get_module_from_obj_name(name) - return get_obj_from_module(module, obj_name) - - -def call_func_by_name( - *args, func_name: str = None, **kwargs -) -> Any: # pragma: no cover - """ - Finds the python object with the given name and calls it as a function. - """ - if func_name is None: - raise ValueError("func_name must be specified") - func_obj = get_obj_by_name(func_name) - if not callable(func_obj): - raise ValueError(func_name + " is not callable") - return func_obj(*args, **kwargs) - - -def construct_class_by_name( - *args, class_name: str = None, **kwargs -) -> Any: # pragma: no cover - """ - Finds the python class with the given name and constructs it with the given - arguments. - """ - return call_func_by_name(*args, func_name=class_name, **kwargs) - - -def get_module_dir_by_obj_name(obj_name: str) -> str: # pragma: no cover - """ - Get the directory path of the module containing the given object name. - """ - module, _ = get_module_from_obj_name(obj_name) - return os.path.dirname(inspect.getfile(module)) - - -def is_top_level_function(obj: Any) -> bool: # pragma: no cover - """ - Determine whether the given object is a top-level function, i.e., defined at module - scope using 'def'. - """ - return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ - - -def get_top_level_function_name(obj: Any) -> str: # pragma: no cover - """ - Return the fully-qualified name of a top-level function. - """ - if not is_top_level_function(obj): - raise ValueError("Object is not a top-level function") - module = obj.__module__ - if module == "__main__": - module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] - return module + "." + obj.__name__ - - -# File system helpers -# ------------------------------------------------------------------------------------------ - - -def list_dir_recursively_with_ignore( - dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False -) -> List[Tuple[str, str]]: # pragma: no cover - """ - List all files recursively in a given directory while ignoring given file and - directory names. Returns list of tuples containing both absolute and relative paths. - """ - if not os.path.isdir(dir_path): - raise RuntimeError(f"Directory does not exist: {dir_path}") - base_name = os.path.basename(os.path.normpath(dir_path)) - - if ignores is None: - ignores = [] - - result = [] - - for root, dirs, files in os.walk(dir_path, topdown=True): - for ignore_ in ignores: - dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] - - # dirs need to be edited in-place - for d in dirs_to_remove: - dirs.remove(d) - - files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] - - absolute_paths = [os.path.join(root, f) for f in files] - relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] - - if add_base_to_relative: - relative_paths = [os.path.join(base_name, p) for p in relative_paths] - - if len(absolute_paths) != len(relative_paths): - raise ValueError("Number of absolute and relative paths do not match") - result += zip(absolute_paths, relative_paths) + Args: + times_range: A list containing start time, end time, and optional interval (hours). + time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). - return result - - -def copy_files_and_create_dirs( - files: List[Tuple[str, str]] -) -> None: # pragma: no cover - """ - Takes in a list of tuples of (src, dst) paths and copies files. - Will create all necessary directories. + Returns: + A list of times within the specified range. """ - for file in files: - target_dir_name = os.path.dirname(file[1]) - - # will create all intermediate-level directories - if not os.path.exists(target_dir_name): - os.makedirs(target_dir_name) - - shutil.copyfile(file[0], file[1]) - -# ---------------------------------------------------------------------------- -# Cached construction of constant tensors. Avoids CPU=>GPU copy when the -# same constant is used multiple times. - -_constant_cache = dict() - - -def constant( - value, shape=None, dtype=None, device=None, memory_format=None -): # pragma: no cover - """Cached construction of constant tensors""" - value = np.asarray(value) - if shape is not None: - shape = tuple(shape) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") - if memory_format is None: - memory_format = torch.contiguous_format - - key = ( - value.shape, - value.dtype, - value.tobytes(), - shape, - dtype, - device, - memory_format, + start_time = datetime.datetime.strptime(times_range[0], time_format) + end_time = datetime.datetime.strptime(times_range[1], time_format) + interval = ( + datetime.timedelta(hours=times_range[2]) + if len(times_range) > 2 + else datetime.timedelta(hours=1) ) - tensor = _constant_cache.get(key, None) - if tensor is None: - tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) - if shape is not None: - tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) - tensor = tensor.contiguous(memory_format=memory_format) - _constant_cache[key] = tensor - return tensor - - -# ---------------------------------------------------------------------------- -# Replace NaN/Inf with specified numerical values. - -try: - nan_to_num = torch.nan_to_num # 1.8.0a0 -except AttributeError: - - def nan_to_num( - input, nan=0.0, posinf=None, neginf=None, *, out=None - ): # pylint: disable=redefined-builtin # pragma: no cover - """Replace NaN/Inf with specified numerical values""" - if not isinstance(input, torch.Tensor): - raise TypeError("input should be a Tensor") - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - if nan != 0: - raise ValueError("nan_to_num only supports nan=0") - return torch.clamp( - input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out - ) - - -# ---------------------------------------------------------------------------- -# Symbolic assert. - -try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access -except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 - -# ---------------------------------------------------------------------------- -# Context manager to temporarily suppress known warnings in torch.jit.trace(). -# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 - -@contextlib.contextmanager -def suppress_tracer_warnings(): # pragma: no cover - """ - Context manager to temporarily suppress known warnings in torch.jit.trace(). - Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 - """ - flt = ("ignore", None, torch.jit.TracerWarning, None, 0) - warnings.filters.insert(0, flt) - yield - warnings.filters.remove(flt) - - -# ---------------------------------------------------------------------------- -# Assert that the shape of a tensor matches the given list of integers. -# None indicates that the size of a dimension is allowed to vary. -# Performs symbolic assertion when used in torch.jit.trace(). - - -def assert_shape(tensor, ref_shape): # pragma: no cover - """ - Assert that the shape of a tensor matches the given list of integers. - None indicates that the size of a dimension is allowed to vary. - Performs symbolic assertion when used in torch.jit.trace(). - """ - if tensor.ndim != len(ref_shape): - raise AssertionError( - f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" - ) - for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): - if ref_size is None: - pass - elif isinstance(ref_size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert( - torch.equal(torch.as_tensor(size), ref_size), - f"Wrong size for dimension {idx}", - ) - elif isinstance(size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert( - torch.equal(size, torch.as_tensor(ref_size)), - f"Wrong size for dimension {idx}: expected {ref_size}", - ) - elif size != ref_size: - raise AssertionError( - f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" - ) - - -# ---------------------------------------------------------------------------- -# Function decorator that calls torch.autograd.profiler.record_function(). - - -def profiled_function(fn): # pragma: no cover - """Function decorator that calls torch.autograd.profiler.record_function().""" - - def decorator(*args, **kwargs): - with torch.autograd.profiler.record_function(fn.__name__): - return fn(*args, **kwargs) - - decorator.__name__ = fn.__name__ - return decorator + times = [ + t.strftime(time_format) + for t in time_range(start_time, end_time, interval, inclusive=True) + ] + return times # ---------------------------------------------------------------------------- @@ -619,180 +178,3 @@ def __iter__(self) -> Iterator[int]: j = (i - rnd.randint(window)) % order.size order[i], order[j] = order[j], order[i] idx += 1 - - -# ---------------------------------------------------------------------------- -# Utilities for operating with torch.nn.Module parameters and buffers. - - -def params_and_buffers(module): # pragma: no cover - """Get parameters and buffers of a nn.Module""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - return list(module.parameters()) + list(module.buffers()) - - -def named_params_and_buffers(module): # pragma: no cover - """Get named parameters and buffers of a nn.Module""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - return list(module.named_parameters()) + list(module.named_buffers()) - - -@torch.no_grad() -def copy_params_and_buffers( - src_module, dst_module, require_all=False -): # pragma: no cover - """Copy parameters and buffers from a source module to target module""" - if not isinstance(src_module, torch.nn.Module): - raise TypeError("src_module must be a torch.nn.Module instance") - if not isinstance(dst_module, torch.nn.Module): - raise TypeError("dst_module must be a torch.nn.Module instance") - src_tensors = dict(named_params_and_buffers(src_module)) - for name, tensor in named_params_and_buffers(dst_module): - if not ((name in src_tensors) or (not require_all)): - raise ValueError(f"Missing source tensor for {name}") - if name in src_tensors: - tensor.copy_(src_tensors[name]) - - -# ---------------------------------------------------------------------------- -# Context manager for easily enabling/disabling DistributedDataParallel -# synchronization. - - -@contextlib.contextmanager -def ddp_sync(module, sync): # pragma: no cover - """ - Context manager for easily enabling/disabling DistributedDataParallel - synchronization. - """ - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): - yield - else: - with module.no_sync(): - yield - - -# ---------------------------------------------------------------------------- -# Check DistributedDataParallel consistency across processes. - - -def check_ddp_consistency(module, ignore_regex=None): # pragma: no cover - """Check DistributedDataParallel consistency across processes.""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - for name, tensor in named_params_and_buffers(module): - fullname = type(module).__name__ + "." + name - if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): - continue - tensor = tensor.detach() - if tensor.is_floating_point(): - tensor = nan_to_num(tensor) - other = tensor.clone() - torch.distributed.broadcast(tensor=other, src=0) - if not (tensor == other).all(): - raise RuntimeError(f"DDP consistency check failed for {fullname}") - - -# ---------------------------------------------------------------------------- -# Print summary table of module hierarchy. - - -def print_module_summary( - module, inputs, max_nesting=3, skip_redundant=True -): # pragma: no cover - """Print summary table of module hierarchy.""" - if not isinstance(module, torch.nn.Module): - raise TypeError("module must be a torch.nn.Module instance") - if isinstance(module, torch.jit.ScriptModule): - raise TypeError("module must not be a torch.jit.ScriptModule instance") - if not isinstance(inputs, (tuple, list)): - raise TypeError("inputs must be a tuple or list") - - # Register hooks. - entries = [] - nesting = [0] - - def pre_hook(_mod, _inputs): - nesting[0] += 1 - - def post_hook(mod, _inputs, outputs): - nesting[0] -= 1 - if nesting[0] <= max_nesting: - outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] - outputs = [t for t in outputs if isinstance(t, torch.Tensor)] - entries.append(EasyDict(mod=mod, outputs=outputs)) - - hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] - hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] - - # Run module. - outputs = module(*inputs) - for hook in hooks: - hook.remove() - - # Identify unique outputs, parameters, and buffers. - tensors_seen = set() - for e in entries: - e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] - e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] - e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] - tensors_seen |= { - id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs - } - - # Filter out redundant entries. - if skip_redundant: - entries = [ - e - for e in entries - if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) - ] - - # Construct table. - rows = [ - [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] - ] - rows += [["---"] * len(rows[0])] - param_total = 0 - buffer_total = 0 - submodule_names = {mod: name for name, mod in module.named_modules()} - for e in entries: - name = "" if e.mod is module else submodule_names[e.mod] - param_size = sum(t.numel() for t in e.unique_params) - buffer_size = sum(t.numel() for t in e.unique_buffers) - output_shapes = [str(list(t.shape)) for t in e.outputs] - output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] - rows += [ - [ - name + (":0" if len(e.outputs) >= 2 else ""), - str(param_size) if param_size else "-", - str(buffer_size) if buffer_size else "-", - (output_shapes + ["-"])[0], - (output_dtypes + ["-"])[0], - ] - ] - for idx in range(1, len(e.outputs)): - rows += [ - [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] - ] - param_total += param_size - buffer_total += buffer_size - rows += [["---"] * len(rows[0])] - rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] - - # Print table. - widths = [max(len(cell) for cell in column) for column in zip(*rows)] - for row in rows: - print( - " ".join( - cell + " " * (width - len(cell)) for cell, width in zip(row, widths) - ) - ) - return outputs - - -# ---------------------------------------------------------------------------- diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py deleted file mode 100644 index 43f83b6..0000000 --- a/src/hirad/utils/generate_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -import datetime -from hirad.datasets import init_dataset_from_config -from .function_utils import convert_datetime_to_cftime - - -def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): - """ - Get a dataset and sampler for generation. - """ - (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) - # if has_lead_time: - # plot_times = times - # else: - # plot_times = [ - # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") - # for time in times - # ] - all_times = dataset.time() - time_indices = [all_times.index(t) for t in times] - sampler = time_indices - - return dataset, sampler \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 8665536..ee4b55a 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -203,117 +203,3 @@ def diffusion_step( def generate(): pass - -############################################################################ -# CorrDiff writer utilities # -############################################################################ - - -class NetCDFWriter: - """NetCDF Writer""" - - def __init__( - self, f, lat, lon, input_channels, output_channels, has_lead_time=False - ): - self._f = f - self.has_lead_time = has_lead_time - # create unlimited dimensions - f.createDimension("time") - f.createDimension("ensemble") - - if lat.shape != lon.shape: - raise ValueError("lat and lon must have the same shape") - ny, nx = lat.shape - - # create lat/lon grid - f.createDimension("x", nx) - f.createDimension("y", ny) - - v = f.createVariable("lat", "f", dimensions=("y", "x")) - # NOTE rethink this for datasets whose samples don't have constant lat-lon. - v[:] = lat - v.standard_name = "latitude" - v.units = "degrees_north" - - v = f.createVariable("lon", "f", dimensions=("y", "x")) - v[:] = lon - v.standard_name = "longitude" - v.units = "degrees_east" - - # create time dimension - if has_lead_time: - v = f.createVariable("time", "str", ("time")) - else: - v = f.createVariable("time", "i8", ("time")) - v.calendar = "standard" - v.units = "hours since 1990-01-01 00:00:00" - - self.truth_group = f.createGroup("truth") - self.prediction_group = f.createGroup("prediction") - self.input_group = f.createGroup("input") - - for variable in output_channels: - name = variable.name + variable.level - self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x")) - self.prediction_group.createVariable( - name, "f", dimensions=("ensemble", "time", "y", "x") - ) - - # setup input data in netCDF - - for variable in input_channels: - name = variable.name + variable.level - self.input_group.createVariable(name, "f", dimensions=("time", "y", "x")) - - def write_input(self, channel_name, time_index, val): - """Write input data to NetCDF file.""" - self.input_group[channel_name][time_index] = val - - def write_truth(self, channel_name, time_index, val): - """Write ground truth data to NetCDF file.""" - self.truth_group[channel_name][time_index] = val - - def write_prediction(self, channel_name, time_index, ensemble_index, val): - """Write prediction data to NetCDF file.""" - self.prediction_group[channel_name][ensemble_index, time_index] = val - - def write_time(self, time_index, time): - """Write time information to NetCDF file.""" - if self.has_lead_time: - self._f["time"][time_index] = time - else: - time_v = self._f["time"] - self._f["time"][time_index] = cftime.date2num( - time, time_v.units, time_v.calendar - ) - - -############################################################################ -# CorrDiff time utilities # -############################################################################ - - -def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): - """Generates a list of times within a given range. - - Args: - times_range: A list containing start time, end time, and optional interval (hours). - time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). - - Returns: - A list of times within the specified range. - """ - - start_time = datetime.datetime.strptime(times_range[0], time_format) - end_time = datetime.datetime.strptime(times_range[1], time_format) - interval = ( - datetime.timedelta(hours=times_range[2]) - if len(times_range) > 2 - else datetime.timedelta(hours=1) - ) - - times = [ - t.strftime(time_format) - for t in time_range(start_time, end_time, interval, inclusive=True) - ] - return times diff --git a/src/hirad/utils/model_utils.py b/src/hirad/utils/model_utils.py deleted file mode 100644 index e1cde9d..0000000 --- a/src/hirad/utils/model_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -import torch - - -def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): - """ - Unified routine for initializing weights and biases. - This function provides a unified interface for various weight initialization - strategies like Xavier (Glorot) and Kaiming (He) initializations. - - Parameters - ---------- - shape : tuple - The shape of the tensor to initialize. It could represent weights or biases - of a layer in a neural network. - mode : str - The mode/type of initialization to use. Supported values are: - - "xavier_uniform": Xavier (Glorot) uniform initialization. - - "xavier_normal": Xavier (Glorot) normal initialization. - - "kaiming_uniform": Kaiming (He) uniform initialization. - - "kaiming_normal": Kaiming (He) normal initialization. - fan_in : int - The number of input units in the weight tensor. For convolutional layers, - this typically represents the number of input channels times the kernel height - times the kernel width. - fan_out : int - The number of output units in the weight tensor. For convolutional layers, - this typically represents the number of output channels times the kernel height - times the kernel width. - - Returns - ------- - torch.Tensor - The initialized tensor based on the specified mode. - - Raises - ------ - ValueError - If the provided `mode` is not one of the supported initialization modes. - """ - if mode == "xavier_uniform": - return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) - if mode == "xavier_normal": - return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) - if mode == "kaiming_uniform": - return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) - if mode == "kaiming_normal": - return np.sqrt(1 / fan_in) * torch.randn(*shape) - raise ValueError(f'Invalid init mode "{mode}"') diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index 218d6f1..dc1a5b9 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -16,7 +16,6 @@ import torch import numpy as np -from omegaconf import ListConfig import warnings @@ -100,11 +99,6 @@ def handle_and_clip_gradients(model, grad_clip_threshold=None): torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) -def parse_model_args(args): - """Convert ListConfig values in args to tuples.""" - return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} - - def is_time_for_periodic_task( cur_nimg, freq, done, batch_size, rank, rank_0_only=False ): From 054bdf495758299a9aa72015c691ab5700cfdd49 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 17 Jun 2025 16:57:23 +0200 Subject: [PATCH 057/189] restructure inference utils and samplers --- src/hirad/inference/__init__.py | 2 ++ .../{utils => inference}/deterministic_sampler.py | 0 src/hirad/inference/generate.py | 4 ++-- src/hirad/{utils => inference}/stochastic_sampler.py | 0 src/hirad/utils/inference_utils.py | 11 +---------- 5 files changed, 5 insertions(+), 12 deletions(-) create mode 100644 src/hirad/inference/__init__.py rename src/hirad/{utils => inference}/deterministic_sampler.py (100%) rename src/hirad/{utils => inference}/stochastic_sampler.py (100%) diff --git a/src/hirad/inference/__init__.py b/src/hirad/inference/__init__.py new file mode 100644 index 0000000..95edb48 --- /dev/null +++ b/src/hirad/inference/__init__.py @@ -0,0 +1,2 @@ +from .deterministic_sampler import deterministic_sampler +from .stochastic_sampler import stochastic_sampler \ No newline at end of file diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/inference/deterministic_sampler.py similarity index 100% rename from src/hirad/utils/deterministic_sampler.py rename to src/hirad/inference/deterministic_sampler.py diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index b60d6d0..40d42bc 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -19,8 +19,8 @@ from hirad.models import EDMPrecondSuperResolution, UNet from hirad.utils.patching import GridPatching2D -from hirad.utils.stochastic_sampler import stochastic_sampler -from hirad.utils.deterministic_sampler import deterministic_sampler +from hirad.inference import stochastic_sampler +from hirad.inference import deterministic_sampler from hirad.utils.inference_utils import ( regression_step, diffusion_step, diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/inference/stochastic_sampler.py similarity index 100% rename from src/hirad/utils/stochastic_sampler.py rename to src/hirad/inference/stochastic_sampler.py diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index ee4b55a..1dcfe5e 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -14,18 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime from typing import Optional -import cftime import nvtx import torch import tqdm -from .function_utils import StackedRandomGenerator, time_range - -from .stochastic_sampler import stochastic_sampler -from .deterministic_sampler import deterministic_sampler +from .function_utils import StackedRandomGenerator ############################################################################ # CorrDiff Generation Utilities # @@ -199,7 +194,3 @@ def diffusion_step( ) all_images.append(images) return torch.cat(all_images) - - -def generate(): - pass From a02d5a7852e621563b7aad1fc51bacec51b90fea Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 1 Jul 2025 14:36:38 +0200 Subject: [PATCH 058/189] refactor generation --- pyproject.toml | 12 +- src/hirad/conf/dataset/era_cosmo.yaml | 4 +- src/hirad/conf/generate_era_cosmo.yaml | 4 +- src/hirad/conf/generation/era_cosmo.yaml | 8 +- src/hirad/conf/sampler/deterministic.yaml | 5 +- src/hirad/generate.sh | 34 +-- src/hirad/inference/__init__.py | 3 +- src/hirad/inference/generate.py | 253 +++------------------- src/hirad/inference/generator.py | 141 ++++++++++++ src/hirad/train_diffusion.sh | 36 +-- src/hirad/train_regression.sh | 41 ++-- src/hirad/training/train.py | 1 - src/hirad/utils/inference_utils.py | 91 ++++++++ 13 files changed, 334 insertions(+), 299 deletions(-) create mode 100644 src/hirad/inference/generator.py diff --git a/pyproject.toml b/pyproject.toml index 1477899..ff966f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,17 +13,7 @@ readme = "README.md" requires-python = ">=3.12" license = {file = "LICENSE"} -dependencies = [ - "cartopy>=0.24.1", - "cftime>=1.6.4", - "hydra-core>=1.3.2", - "matplotlib>=3.10.1", - "omegaconf>=2.3.0", - "tensorboard>=2.19.0", - "termcolor>=3.1.0", - "torchinfo>=1.8.0", - "treelib>=1.7.1" -] +dependencies = [] [tool.setuptools] package-dir = {"" = "src"} diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index b1e21e6..5d32f4e 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,3 +1,3 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full -validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full/validation \ No newline at end of file +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/train +validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index 5d7649d..90a5948 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: generation_full + name: generation_regression_valid run: - dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: /capstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index be4219d..2ccbb71 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -2,7 +2,7 @@ num_ensembles: 8 # Number of ensembles to generate per input seed_batch_size: 4 # Size of the batched inference -inference_mode: all +inference_mode: regression # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y @@ -17,7 +17,7 @@ hr_mean_conditioning: True # Sampling resolution times_range: null times: - - 20160101-0000 + - 20200926-1800 # - 20160101-0600 # - 20160101-1200 has_laed_time: False @@ -35,10 +35,10 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_refactoring/checkpoints_diffusion + res_ckpt_path: /capstor/scratch/cscs/pstamenk/diffusion_checkpoints # res_ckpt_path: null # Checkpoint filename for the diffusion model - reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_refactoring/checkpoints_regression + reg_ckpt_path: /capstor/scratch/cscs/pstamenk/regression_checkpoints # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml index 856906b..f65e738 100644 --- a/src/hirad/conf/sampler/deterministic.yaml +++ b/src/hirad/conf/sampler/deterministic.yaml @@ -2,7 +2,8 @@ # Deterministic sampler is not implemented correctly in this codebase and shouldn't be used. type: deterministic -num_steps: 9 +params: + num_steps: 9 # Number of denoising steps -solver: euler + solver: euler # ODE solver type: euler is the simplest solver \ No newline at end of file diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 87c8979..2607431 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -7,19 +7,16 @@ #SBATCH --nodes=1 #SBATCH --ntasks-per-node=1 #SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/regression_generation.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/regression_generation.err ### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a122 +#SBATCH -A c38 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -34,18 +31,21 @@ export MASTER_PORT=29500 echo "Master port: $MASTER_PORT" # Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# Use SLURM_NTASKS (number of processes to be launched by torchrun) -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute threads per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS -echo "Physical cores: $PHYSICAL_CORES" -echo "Local processes: $LOCAL_PROCS" -echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate +srun --container-writable --environment=modulus_env bash -c " + cd HiRAD-Gen + pip install -e . --no-dependencies + pip install Cartopy==0.22.0 python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/inference/__init__.py b/src/hirad/inference/__init__.py index 95edb48..1593b3a 100644 --- a/src/hirad/inference/__init__.py +++ b/src/hirad/inference/__init__.py @@ -1,2 +1,3 @@ from .deterministic_sampler import deterministic_sampler -from .stochastic_sampler import stochastic_sampler \ No newline at end of file +from .stochastic_sampler import stochastic_sampler +from .generator import Generator \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 40d42bc..77e78b2 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -4,27 +4,16 @@ from omegaconf import OmegaConf, DictConfig import torch import torch._dynamo -import nvtx import numpy as np import contextlib from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor -from functools import partial - -import cartopy.crs as ccrs -from matplotlib import pyplot as plt -from torch.distributed import gather from hirad.models import EDMPrecondSuperResolution, UNet -from hirad.utils.patching import GridPatching2D -from hirad.inference import stochastic_sampler -from hirad.inference import deterministic_sampler -from hirad.utils.inference_utils import ( - regression_step, - diffusion_step, -) +from hirad.inference import Generator +from hirad.utils.inference_utils import save_images from hirad.utils.function_utils import get_time_from_range from hirad.utils.checkpoint import load_checkpoint @@ -32,14 +21,13 @@ from hirad.utils.train_helpers import set_patch_shape -from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig) -> None: """Generate random dowscaled atmospheric states using the techniques described in the paper "Elucidating the Design Space of Diffusion-Based Generative Models". """ - torch.backends.cudnn.enabled = False + # torch.backends.cudnn.enabled = False # Initialize distributed manager DistributedManager.initialize() dist = DistributedManager() @@ -49,14 +37,6 @@ def main(cfg: DictConfig) -> None: logger = PythonLogger("generate") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) - # Handle the batch size - seeds = list(np.arange(cfg.generation.num_ensembles)) - num_batches = ( - (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1 - ) * dist.world_size - all_batches = torch.as_tensor(seeds).tensor_split(num_batches) - rank_batches = all_batches[dist.rank :: dist.world_size] - # Synchronize if dist.world_size > 1: torch.distributed.barrier() @@ -81,26 +61,6 @@ def main(cfg: DictConfig) -> None: img_shape = dataset.image_shape() img_out_channels = len(dataset.output_channels()) - # Parse the patch shape - if cfg.generation.patching: - patch_shape_x = cfg.generation.patch_shape_x - patch_shape_y = cfg.generation.patch_shape_y - else: - patch_shape_x, patch_shape_y = None, None - patch_shape = (patch_shape_y, patch_shape_x) - use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if use_patching: - patching = GridPatching2D( - img_shape=img_shape, - patch_shape=patch_shape, - boundary_pix=cfg.generation.boundary_pix, - overlap_pix=cfg.generation.overlap_pix, - ) - logger0.info("Patch-based training enabled") - else: - patching = None - logger0.info("Patch-based training disabled") - # Parse the inference mode if cfg.generation.inference_mode == "regression": load_net_reg, load_net_res = True, False @@ -178,101 +138,34 @@ def main(cfg: DictConfig) -> None: # Overhead of compiling regression network outweights any benefits if net_res: net_res = torch.compile(net_res, mode="reduce-overhead") - - # Partially instantiate the sampler based on the configs - if cfg.sampler.type == "deterministic": - if cfg.generation.hr_mean_conditioning: - raise NotImplementedError( - "High-res mean conditioning is not yet implemented for the deterministic sampler" - ) - sampler_fn = partial( - deterministic_sampler, - num_steps=cfg.sampler.num_steps, - # num_ensembles=cfg.generation.num_ensembles, - solver=cfg.sampler.solver, - ) - elif cfg.sampler.type == "stochastic": - sampler_fn = partial(stochastic_sampler, patching=patching) - else: - raise ValueError(f"Unknown sampling method {cfg.sampling.type}") + generator = Generator( + net_reg=net_reg, + net_res=net_res, + batch_size=cfg.generation.seed_batch_size, + ensemble_size=cfg.generation.num_ensembles, + hr_mean_conditioning=cfg.generation.hr_mean_conditioning, + n_out_channels=img_out_channels, + inference_mode=cfg.generation.inference_mode, + dist=dist, + ) - # Main generation definition - def generate_fn(image_lr, lead_time_label): - with nvtx.annotate("generate_fn", color="green"): - # (1, C, H, W) - image_lr = image_lr.to(memory_format=torch.channels_last) - - if net_reg: - with nvtx.annotate("regression_model", color="yellow"): - image_reg = regression_step( - net=net_reg, - img_lr=image_lr, - latents_shape=( - cfg.generation.seed_batch_size, - img_out_channels, - img_shape[0], - img_shape[1], - ), # (batch_size, C, H, W) - lead_time_label=lead_time_label, - ) - if net_res: - if cfg.generation.hr_mean_conditioning: - mean_hr = image_reg[0:1] - else: - mean_hr = None - with nvtx.annotate("diffusion model", color="purple"): - image_res = diffusion_step( - net=net_res, - sampler_fn=sampler_fn, - img_shape=img_shape, - img_out_channels=img_out_channels, - rank_batches=rank_batches, - img_lr=image_lr.expand( - cfg.generation.seed_batch_size, -1, -1, -1 - ), #.to(memory_format=torch.channels_last), - rank=dist.rank, - device=device, - mean_hr=mean_hr, - lead_time_label=lead_time_label, - ) - if cfg.generation.inference_mode == "regression": - image_out = image_reg - elif cfg.generation.inference_mode == "diffusion": - image_out = image_res - else: - image_out = image_reg[0:1,::] + image_res - - # Gather tensors on rank 0 - if dist.world_size > 1: - if dist.rank == 0: - gathered_tensors = [ - torch.zeros_like( - image_out, dtype=image_out.dtype, device=image_out.device - ) - for _ in range(dist.world_size) - ] - else: - gathered_tensors = None - - torch.distributed.barrier() - gather( - image_out, - gather_list=gathered_tensors if dist.rank == 0 else None, - dst=0, - ) - - if dist.rank == 0: - if cfg.generation.inference_mode != "regression": - return torch.cat(gathered_tensors), image_reg[0:1,::] - return torch.cat(gathered_tensors), None - else: - return None, None - else: - #TODO do this for multi-gpu setting above too - if cfg.generation.inference_mode != "regression": - return image_out, image_reg - return image_out, None + # Parse the patch shape + if cfg.generation.patching: + patch_shape_x = cfg.generation.patch_shape_x + patch_shape_y = cfg.generation.patch_shape_y + else: + patch_shape_x, patch_shape_y = None, None + patch_shape = (patch_shape_y, patch_shape_x) + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + generator.initialize_patching(img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) + sampler_params = cfg.sampler.params if "params" in cfg.sampler else {} + generator.initialize_sampler(cfg.sampler.type, **sampler_params) # generate images output_path = getattr(cfg.generation.io, "output_path", "./outputs") @@ -348,7 +241,8 @@ def elapsed_time(self, _): .to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - image_out, image_reg = generate_fn(image_lr,lead_time_label) + # image_out, image_reg = generate_fn(image_lr,lead_time_label) + image_out, image_reg = generator.generate(image_lr,lead_time_label) if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing @@ -391,90 +285,5 @@ def elapsed_time(self, _): logger0.info("Generation Completed.") -def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): - - os.makedirs(output_path, exist_ok=True) - - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - - target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) - baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) - if mean_pred is not None: - mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - - - freqs = {} - power = {} - for idx, channel in enumerate(output_channels): - input_channel_idx = input_channels.index(channel) - - if channel.name=="tp": - target[idx,::] = prepare_precipitaiton(target[idx,:,:]) - prediction[idx,::] = prepare_precipitaiton(prediction[idx,:,:]) - baseline[input_channel_idx,:,:] = prepare_precipitaiton(baseline[input_channel_idx]) - if mean_pred is not None: - mean_pred[idx,::] = prepare_precipitaiton(mean_pred[idx,::]) - - _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg')) - _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg')) - _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg')) - if mean_pred is not None: - _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-mean_prediction.jpg')) - - _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) - _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) - if mean_pred is not None: - _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) - - - plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) - plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) - if mean_pred is not None: - plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) - - b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) - freqs['baseline'] = b_freq - power['baseline'] = b_power - #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) - t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) - freqs['target'] = t_freq - power['target'] = t_power - p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) - freqs['prediction'] = p_freq - power['prediction'] = p_power - if mean_pred is not None: - mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0) - freqs['mean_prediction'] = mp_freq - power['mean_prediction'] = mp_power - plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) - - -def prepare_precipitaiton(precip_array): - precip_array = np.clip(precip_array, 0, None) - epsilon = 1e-2 - precip_array = precip_array + epsilon - precip_array = np.log(precip_array) - # log_min, log_max = precip_array.min(), precip_array.max() - # precip_array = (precip_array-log_min)/(log_max-log_min) - return precip_array - - -def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): - - """Plot observed or interpolated data in a scatter plot.""" - # TODO: Refactor this somehow, it's not really generalizing well across variables. - fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - ax.coastlines() - ax.gridlines(draw_labels=True) - plt.colorbar(p, label="K", orientation="horizontal") - plt.savefig(filename) - plt.close('all') - if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/inference/generator.py b/src/hirad/inference/generator.py new file mode 100644 index 0000000..3b9a538 --- /dev/null +++ b/src/hirad/inference/generator.py @@ -0,0 +1,141 @@ +from typing import Callable +from functools import partial +import nvtx +import numpy as np +import torch +from torch.distributed import gather +from hirad.utils.inference_utils import regression_step, diffusion_step +from hirad.distributed import DistributedManager +from hirad.utils.patching import GridPatching2D +from hirad.inference import stochastic_sampler, deterministic_sampler + +class Generator(): + def __init__(self, + net_reg: torch.nn.Module, + net_res: torch.nn.Module, + batch_size: int, + ensemble_size: int, + hr_mean_conditioning: bool, + n_out_channels: int, + inference_mode: str, + dist: DistributedManager, + ): + + self.net_reg = net_reg + self.net_res = net_res + self.batch_size = batch_size + self.hr_mean_conditioning = hr_mean_conditioning + self.n_out_channels = n_out_channels + self.inference_mode = inference_mode + self.ensemble_size = ensemble_size + self.dist = dist + self.get_rank_batches() + self.patching = None + + def get_rank_batches(self): + seeds = list(np.arange(self.ensemble_size)) + num_batches = ( + (len(seeds) - 1) // (self.batch_size * self.dist.world_size) + 1 + ) * self.dist.world_size + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + self.rank_batches = all_batches[self.dist.rank :: self.dist.world_size] + + def initialize_sampler(self, sampler_type, **sampler_args): + if sampler_type == "deterministic": + if self.hr_mean_conditioning: + raise NotImplementedError( + "High-res mean conditioning is not yet implemented for the deterministic sampler" + ) + self.sampler = partial( + deterministic_sampler, + **sampler_args + ) + elif sampler_type == "stochastic": + self.sampler = partial(stochastic_sampler, patching=self.patching) + else: + raise ValueError(f"Unknown sampling method {sampler_type}") + + def initialize_patching(self, img_shape, patch_shape, boundary_pix, overlap_pix): + self.patching = GridPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=boundary_pix, + overlap_pix=overlap_pix, + ) + + def generate(self, image_lr, lead_time_label=None): + with nvtx.annotate("generate_fn", color="green"): + # (1, C, H, W) + image_lr = image_lr.to(memory_format=torch.channels_last) + img_shape = image_lr.shape[-2:] + + if self.net_reg: + with nvtx.annotate("regression_model", color="yellow"): + image_reg = regression_step( + net=self.net_reg, + img_lr=image_lr, + latents_shape=( + self.batch_size, + self.n_out_channels, + img_shape[0], + img_shape[1], + ), # (batch_size, C, H, W) + lead_time_label=lead_time_label, + ) + if self.net_res: + if self.hr_mean_conditioning: + mean_hr = image_reg[0:1] + else: + mean_hr = None + with nvtx.annotate("diffusion model", color="purple"): + image_res = diffusion_step( + net=self.net_res, + sampler_fn=self.sampler, + img_shape=img_shape, + img_out_channels=self.n_out_channels, + rank_batches=self.rank_batches, + img_lr=image_lr.expand( + self.batch_size, -1, -1, -1 + ).to(memory_format=torch.channels_last), #.to(memory_format=torch.channels_last), + rank=self.dist.rank, + device=image_lr.device, + mean_hr=mean_hr, + lead_time_label=lead_time_label, + ) + if self.inference_mode == "regression": + image_out = image_reg + elif self.inference_mode == "diffusion": + image_out = image_res + else: + image_out = image_reg[0:1,::] + image_res + + # Gather tensors on rank 0 + if self.dist.world_size > 1: + if self.dist.rank == 0: + gathered_tensors = [ + torch.zeros_like( + image_out, dtype=image_out.dtype, device=image_out.device + ) + for _ in range(self.dist.world_size) + ] + else: + gathered_tensors = None + + torch.distributed.barrier() + gather( + image_out, + gather_list=gathered_tensors if self.dist.rank == 0 else None, + dst=0, + ) + + if self.dist.rank == 0: + if self.inference_mode != "regression": + return torch.cat(gathered_tensors), image_reg[0:1,::] + return torch.cat(gathered_tensors), None + else: + return None, None + else: + #TODO do this for multi-gpu setting above too + if self.inference_mode != "regression": + return image_out, image_reg + return image_out, None diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh index cf2f88f..d8515db 100644 --- a/src/hirad/train_diffusion.sh +++ b/src/hirad/train_diffusion.sh @@ -1,25 +1,23 @@ #!/bin/bash -#SBATCH --job-name="testrun" +#SBATCH --job-name="corrdiff-second-stage" ### HARDWARE ### -#SBATCH --partition=debug -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 +#SBATCH --partition=normal +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=72 -#SBATCH --time=00:30:00 +#SBATCH --time=24:00:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/diffusion_full.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/diffusion_full.err ### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a122 +#SBATCH -A c38 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -32,14 +30,16 @@ export MASTER_ADDR export MASTER_PORT=29500 # Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute cores per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate +srun --container-writable --environment=modulus_env bash -c " + cd HiRAD-Gen + pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml " \ No newline at end of file diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh index c065477..7499bcf 100644 --- a/src/hirad/train_regression.sh +++ b/src/hirad/train_regression.sh @@ -1,25 +1,23 @@ #!/bin/bash -#SBATCH --job-name="testrun" +#SBATCH --job-name="corrdiff-first-stage" ### HARDWARE ### -#SBATCH --partition=debug -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 +#SBATCH --partition=normal +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=72 -#SBATCH --time=00:30:00 +#SBATCH --time=06:00:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/regression_full_run.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/regression_full_run.err ### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a122 +#SBATCH -A c38 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -32,14 +30,19 @@ export MASTER_ADDR export MASTER_PORT=29500 # Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute cores per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS - +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate +# srun bash -c " +# . ./train_env/bin/activate +# python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +# " +srun --container-writable --environment=modulus_env bash -c " + cd HiRAD-Gen + pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml " \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 12b6942..7ea7a71 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -500,7 +500,6 @@ def main(cfg: DictConfig) -> None: / num_accumulation_rounds / len(patch_nums_iter) ) - loss_accum += loss / num_accumulation_rounds with nvtx.annotate(f"loss backward", color="yellow"): loss.backward() diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 1dcfe5e..7698a82 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -15,12 +15,17 @@ # limitations under the License. from typing import Optional +import os import nvtx +import numpy as np import torch import tqdm +from matplotlib import pyplot as plt +import cartopy.crs as ccrs from .function_utils import StackedRandomGenerator +from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra ############################################################################ # CorrDiff Generation Utilities # @@ -194,3 +199,89 @@ def diffusion_step( ) all_images.append(images) return torch.cat(all_images) + + +def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + + os.makedirs(output_path, exist_ok=True) + + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + + + freqs = {} + power = {} + for idx, channel in enumerate(output_channels): + input_channel_idx = input_channels.index(channel) + + if channel.name=="tp": + target[idx,::] = _prepare_precipitaiton(target[idx,:,:]) + prediction[idx,::] = _prepare_precipitaiton(prediction[idx,:,:]) + baseline[input_channel_idx,:,:] = _prepare_precipitaiton(baseline[input_channel_idx]) + if mean_pred is not None: + mean_pred[idx,::] = _prepare_precipitaiton(mean_pred[idx,::]) + + _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg')) + _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg')) + _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg')) + if mean_pred is not None: + _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-mean_prediction.jpg')) + + _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) + _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) + if mean_pred is not None: + _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) + + + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) + if mean_pred is not None: + plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) + + b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) + freqs['prediction'] = p_freq + power['prediction'] = p_power + if mean_pred is not None: + mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0) + freqs['mean_prediction'] = mp_freq + power['mean_prediction'] = mp_power + plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) + + +def _prepare_precipitaiton(precip_array): + precip_array = np.clip(precip_array, 0, None) + epsilon = 1e-2 + precip_array = precip_array + epsilon + precip_array = np.log(precip_array) + # log_min, log_max = precip_array.min(), precip_array.max() + # precip_array = (precip_array-log_min)/(log_max-log_min) + return precip_array + + +def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + + """Plot observed or interpolated data in a scatter plot.""" + # TODO: Refactor this somehow, it's not really generalizing well across variables. + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="K", orientation="horizontal") + plt.savefig(filename) + plt.close('all') \ No newline at end of file From be6170dfd5302157f021254dd3b2200c0a772c14 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 1 Jul 2025 18:34:13 +0200 Subject: [PATCH 059/189] Add some initial CRPS code --- src/hirad/eval/__init__.py | 2 +- src/hirad/eval/metrics.py | 23 +++++++++++++++++++++++ src/hirad/generate.sh | 5 +++-- src/hirad/utils/inference_utils.py | 9 ++++++++- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py index 13d9eb3..a4b9475 100644 --- a/src/hirad/eval/__init__.py +++ b/src/hirad/eval/__init__.py @@ -1,2 +1,2 @@ -from .metrics import compute_mae, average_power_spectrum +from .metrics import compute_mae, average_power_spectrum, crps from .plotting import plot_error_projection, plot_power_spectra diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index 1170ea1..e508cf2 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -4,6 +4,8 @@ import torch from scipy.signal import periodogram +import xskillscore +import xarray as xr # set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) @@ -55,3 +57,24 @@ def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default logging.info(f'power spectra shape={power_spectra.shape}') return freqs, power_spectra + +def crps(prediction_ensemble, target, average_over_area=True, average_over_channels=True): + # Plot CRPS + observations = xr.DataArray(target, + coords = [('channel', np.arange(target.shape[0])), + ('x', np.arange(target.shape[1])), + ('y', np.arange(target.shape[2]))]) + + forecasts = xr.DataArray(prediction_ensemble, + coords = [('member', np.arange(prediction_ensemble.shape[0])), + ('channel', np.arange(prediction_ensemble.shape[1])), + ('x', np.arange(prediction_ensemble.shape[2])), + ('y', np.arange(prediction_ensemble.shape[3]))]) + dim = [] + if average_over_area: + dim.append('x') + dim.append('y') + if average_over_channels: + dim.append('channel') + crps = xskillscore.crps_ensemble(observations=observations, forecasts=forecasts, dim=dim) + crps = crps.to_numpy() diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 2607431..7237d2a 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -12,8 +12,8 @@ #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/regression_generation.log -#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/regression_generation.err +#SBATCH --output=/capstor/scratch/cscs/$USER/logs/regression_generation.log +#SBATCH --error=/capstor/scratch/cscs/$USER/logs/regression_generation.err ### ENVIRONMENT #### #SBATCH -A c38 @@ -47,5 +47,6 @@ srun --container-writable --environment=modulus_env bash -c " cd HiRAD-Gen pip install -e . --no-dependencies pip install Cartopy==0.22.0 + pip install xskillscore python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 7698a82..515e000 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -25,7 +25,7 @@ import cartopy.crs as ccrs from .function_utils import StackedRandomGenerator -from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra +from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra, crps ############################################################################ # CorrDiff Generation Utilities # @@ -211,12 +211,19 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, output_channels = dataset.output_channels() target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + # prediction.shape = (num_channels, X, Y) prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + # prediction_ensemble.shape = (num_ensembles, num_channels, X, Y) + prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),1) #.reshape(len(output_channels),-1) baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) if mean_pred is not None: mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + # Plot CRPS + crps_score = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=True) + _plot_projection(longitudes, latitudes, crps_score, os.path.join(output_path, f'{time_step}-crps.jpg')) + # Plot power spectra freqs = {} power = {} for idx, channel in enumerate(output_channels): From 833242b2021cff35b382cdbb904779d2efce3586 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 2 Jul 2025 12:11:24 +0200 Subject: [PATCH 060/189] Fixes to CRPS --- src/hirad/eval/metrics.py | 1 + src/hirad/generate.sh | 4 ++-- src/hirad/utils/inference_utils.py | 5 ++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index e508cf2..6b2a4b2 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -78,3 +78,4 @@ def crps(prediction_ensemble, target, average_over_area=True, average_over_chann dim.append('channel') crps = xskillscore.crps_ensemble(observations=observations, forecasts=forecasts, dim=dim) crps = crps.to_numpy() + return crps diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 7237d2a..3876ade 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -12,8 +12,8 @@ #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/capstor/scratch/cscs/$USER/logs/regression_generation.log -#SBATCH --error=/capstor/scratch/cscs/$USER/logs/regression_generation.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/regression_generation.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/regression_generation.err ### ENVIRONMENT #### #SBATCH -A c38 diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 515e000..699d7b6 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -214,7 +214,10 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, # prediction.shape = (num_channels, X, Y) prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) # prediction_ensemble.shape = (num_ensembles, num_channels, X, Y) - prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),1) #.reshape(len(output_channels),-1) + prediction_ensemble = np.ndarray(image_pred.shape) + for i in range(image_pred.shape[0]): + prediction_ensemble[i,::] = np.flip(dataset.denormalize_output(image_pred[i,::].squeeze()),1) + prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),2) #.reshape(len(output_channels),-1) baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) if mean_pred is not None: mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) From 68ab63e4efb09765e84d06776b53ed8a2fd2686a Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 2 Jul 2025 12:49:48 +0200 Subject: [PATCH 061/189] Adds time capability to crps --- src/hirad/eval/metrics.py | 28 +++++++++++++++++++++------- src/hirad/utils/inference_utils.py | 1 - 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index 6b2a4b2..01594d6 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -58,19 +58,33 @@ def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default return freqs, power_spectra -def crps(prediction_ensemble, target, average_over_area=True, average_over_channels=True): - # Plot CRPS +def crps(prediction_ensemble, target, average_over_area=True, average_over_channels=True, average_over_time=True): + # Assumes that prediction_ensemble is in form: + # (member, channel, x, y) or + # (time, member, channel, x, y) + # Returns: a k-dimensional array of continuous ranked probability scores, + # where k is the number of dimensions that were not averaged over. + # For example, if average_over_area is False (and all others true), will + # return an ndarray of shape (X,Y) observations = xr.DataArray(target, coords = [('channel', np.arange(target.shape[0])), ('x', np.arange(target.shape[1])), ('y', np.arange(target.shape[2]))]) - forecasts = xr.DataArray(prediction_ensemble, - coords = [('member', np.arange(prediction_ensemble.shape[0])), - ('channel', np.arange(prediction_ensemble.shape[1])), - ('x', np.arange(prediction_ensemble.shape[2])), - ('y', np.arange(prediction_ensemble.shape[3]))]) + forecasts_coords = [('member', np.arange(prediction_ensemble.shape[-4])), + ('channel', np.arange(prediction_ensemble.shape[-3])), + ('x', np.arange(prediction_ensemble.shape[-2])), + ('y', np.arange(prediction_ensemble.shape[-1]))] + + if prediction_ensemble.ndim > 4: + forecasts.coords.insert(0, ('time', np.arange(prediction_ensemble.shape[-5]))) + + + forecasts = xr.DataArray(prediction_ensemble, coords = forecasts_coords) + dim = [] + if prediction_ensemble.ndim > 4 and average_over_time: + dim.append('time') if average_over_area: dim.append('x') dim.append('y') diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 699d7b6..af6fda5 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -292,6 +292,5 @@ def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) ax.coastlines() ax.gridlines(draw_labels=True) - plt.colorbar(p, label="K", orientation="horizontal") plt.savefig(filename) plt.close('all') \ No newline at end of file From 826601537092d05a706214bc8e1c8bc5a58112b0 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 2 Jul 2025 13:23:11 +0200 Subject: [PATCH 062/189] Plot CRPS per channel --- src/hirad/utils/inference_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index af6fda5..9047bcd 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -224,7 +224,10 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, # Plot CRPS crps_score = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=True) - _plot_projection(longitudes, latitudes, crps_score, os.path.join(output_path, f'{time_step}-crps.jpg')) + _plot_projection(longitudes, latitudes, crps_score, os.path.join(output_path, f'{time_step}-crps-all.jpg')) + crps_score_channels = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) + for channel_num in range(crps_score_channels.shape[0]): + _plot_projection(longitudes, latitudes, crps_score_channels[channel_num,::], os.path.join(output_path, f'{time_step}-crps-{output_channels[channel_num].name}.jpg')) # Plot power spectra freqs = {} @@ -292,5 +295,6 @@ def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) ax.coastlines() ax.gridlines(draw_labels=True) + plt.colorbar(p, orientation="horizontal") plt.savefig(filename) plt.close('all') \ No newline at end of file From a89fa4684fe25fd513ad69e839db023124bcbc56 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Fri, 4 Jul 2025 10:01:00 +0200 Subject: [PATCH 063/189] Adds plotting CRPS over time (not yet tested) --- src/hirad/eval/metrics.py | 15 +++++++----- src/hirad/utils/inference_utils.py | 38 ++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index 01594d6..add4023 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -66,21 +66,24 @@ def crps(prediction_ensemble, target, average_over_area=True, average_over_chann # where k is the number of dimensions that were not averaged over. # For example, if average_over_area is False (and all others true), will # return an ndarray of shape (X,Y) - observations = xr.DataArray(target, - coords = [('channel', np.arange(target.shape[0])), - ('x', np.arange(target.shape[1])), - ('y', np.arange(target.shape[2]))]) + target_coords = [('channel', np.arange(target.shape[-3])), + ('x', np.arange(target.shape[-2])), + ('y', np.arange(target.shape[-1]))] + forecasts_coords = [('member', np.arange(prediction_ensemble.shape[-4])), ('channel', np.arange(prediction_ensemble.shape[-3])), ('x', np.arange(prediction_ensemble.shape[-2])), ('y', np.arange(prediction_ensemble.shape[-1]))] - if prediction_ensemble.ndim > 4: - forecasts.coords.insert(0, ('time', np.arange(prediction_ensemble.shape[-5]))) + if prediction_ensemble.ndim > 4 and target.ndim > 3: + forecasts_coords.insert(0, ('time', np.arange(prediction_ensemble.shape[-5]))) + target_coords.insert(0, ('time', np.arange(target.shape[-4]))) + forecasts = xr.DataArray(prediction_ensemble, coords = forecasts_coords) + observations = xr.DataArray(target, coords = target_coords) dim = [] if prediction_ensemble.ndim > 4 and average_over_time: diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 9047bcd..f90b8a0 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -201,6 +201,26 @@ def diffusion_step( return torch.cat(all_images) +def save_results_as_torch(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + + os.makedirs(output_path, exist_ok=True) + + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + # prediction.shape = (num_channels, X, Y) + # prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + # prediction_ensemble.shape = (num_ensembles, num_channels, X, Y) + prediction_ensemble = np.ndarray(image_pred.shape) + for i in range(image_pred.shape[0]): + prediction_ensemble[i,::] = np.flip(dataset.denormalize_output(image_pred[i,::].squeeze()),1) + prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),2) #.reshape(len(output_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + torch.save(target, os.path.join(output_path, f'{time_step}-target')) + torch.save(prediction_ensemble, os.path.join(output_path, f'{time_step}-predictions')) + torch.save(baseline, os.path.join(output_path, f'{time_step}-baseline')) + + def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): os.makedirs(output_path, exist_ok=True) @@ -275,6 +295,24 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, power['mean_prediction'] = mp_power plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) +def plot_crps_over_time(times, dataset, output_path): + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + + prediction_ensemble = torch.load(os.join(output_path, f'{times[0]}-predictions')) + all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) + all_targets = np.ndarray(len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3]) + for i in range(len(times)): + prediction_ensemble = torch.load(os.join(output_path, f'{times[i]}-predictions')) + all_predictions[i,::] = prediction_ensemble + target = torch.load(os.join(output_path, f'{times[i]}-target')) + all_targets[i,::] = target + score_over_time_channels = crps_score = crps(all_predictions, all_targets, average_over_area=True, average_over_channels=False, average_over_time=False) + score_over_area_channels = crps(all_predictions, all_targets, average_over_area=False, average_over_channels=False, average_over_time=True) + for channel_num in range(score_over_area_channels.shape[0]): + _plot_projection(longitudes, latitudes, score_over_area_channels[channel_num,::], os.path.join(output_path, f'all-time-crps-{output_channels[channel_num].name}.jpg')) def _prepare_precipitaiton(precip_array): precip_array = np.clip(precip_array, 0, None) From 99b3b3770673fc57ef91e36a968290ce93637c98 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 8 Jul 2025 12:58:06 +0200 Subject: [PATCH 064/189] Add plotting CRPS across time as well as area. --- src/hirad/inference/generate.py | 23 +++++++++++++++++++++-- src/hirad/utils/inference_utils.py | 28 +++++++++++++++++++++------- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 77e78b2..7696f7f 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -13,7 +13,7 @@ from hirad.models import EDMPrecondSuperResolution, UNet from hirad.inference import Generator -from hirad.utils.inference_utils import save_images +from hirad.utils.inference_utils import save_images, save_results_as_torch, plot_crps_over_time from hirad.utils.function_utils import get_time_from_range from hirad.utils.checkpoint import load_checkpoint @@ -246,9 +246,23 @@ def elapsed_time(self, _): if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing + + if not cfg.generation.times_range: + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + times[sampler[time_index]], + dataset, + image_out.cpu().numpy(), + image_tar.cpu().numpy(), + image_lr.cpu().numpy(), + image_reg.cpu().numpy() if image_reg is not None else None, + ) + ) writer_threads.append( writer_executor.submit( - save_images, + save_results_as_torch, output_path, times[sampler[time_index]], dataset, @@ -284,6 +298,11 @@ def elapsed_time(self, _): f.close() logger0.info("Generation Completed.") + if cfg.generation.times_range: + # reassign times + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt + plot_crps_over_time(times, dataset, output_path) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index f90b8a0..a57a107 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -16,6 +16,7 @@ from typing import Optional import os +import logging import nvtx import numpy as np @@ -295,24 +296,37 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, power['mean_prediction'] = mp_power plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) -def plot_crps_over_time(times, dataset, output_path): +def plot_crps_over_time(times, dataset, output_path): longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] - prediction_ensemble = torch.load(os.join(output_path, f'{times[0]}-predictions')) + prediction_ensemble = torch.load(os.path.join(output_path, f'{times[0]}-predictions'), weights_only=False) all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) - all_targets = np.ndarray(len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3]) + all_targets = np.ndarray((len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) for i in range(len(times)): - prediction_ensemble = torch.load(os.join(output_path, f'{times[i]}-predictions')) + prediction_ensemble = torch.load(os.path.join(output_path, f'{times[i]}-predictions'), weights_only=False) all_predictions[i,::] = prediction_ensemble - target = torch.load(os.join(output_path, f'{times[i]}-target')) + target = torch.load(os.path.join(output_path, f'{times[i]}-target'), weights_only=False) all_targets[i,::] = target - score_over_time_channels = crps_score = crps(all_predictions, all_targets, average_over_area=True, average_over_channels=False, average_over_time=False) + score_over_time_channels = crps(all_predictions, all_targets, average_over_area=True, average_over_channels=False, average_over_time=False) score_over_area_channels = crps(all_predictions, all_targets, average_over_area=False, average_over_channels=False, average_over_time=True) for channel_num in range(score_over_area_channels.shape[0]): - _plot_projection(longitudes, latitudes, score_over_area_channels[channel_num,::], os.path.join(output_path, f'all-time-crps-{output_channels[channel_num].name}.jpg')) + _plot_projection(longitudes, latitudes, score_over_area_channels[channel_num,::], os.path.join(output_path, f'crps-time-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) + _plot_score_vs_t(score_over_time_channels[:, channel_num], times, os.path.join(output_path, f'crps-area-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) + +def _plot_score_vs_t(score: np.array, times: np.array, filename: str): + fig = plt.figure() + ax = plt.subplot() + p = plt.plot(times, score) + #plt.ylabel('CRPS') + #plt.xlabel('time') + plt.xticks([times[0],times[-1]]) + plt.savefig(filename) + plt.close('all') def _prepare_precipitaiton(precip_array): precip_array = np.clip(precip_array, 0, None) From 88ee662293e1bf3606652d59049eb9e369fed1e6 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 8 Jul 2025 13:08:37 +0200 Subject: [PATCH 065/189] add visualization during training --- src/hirad/inference/generate.py | 4 +- src/hirad/inference/generator.py | 1 - src/hirad/inference/stochastic_sampler.py | 9 +- src/hirad/models/layers.py | 4 +- src/hirad/training/train.py | 139 +++++++++++++++++++++- src/hirad/utils/inference_utils.py | 84 +++++++++---- 6 files changed, 207 insertions(+), 34 deletions(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 77e78b2..804e141 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -230,6 +230,8 @@ def elapsed_time(self, _): if time_index == warmup_steps: start.record() + savedir = os.path.join(output_path,f"{times[sampler[time_index]]}") + os.makedirs(savedir,exist_ok=True) # continue if lead_time_label: lead_time_label = lead_time_label[0].to(dist.device).contiguous() @@ -249,7 +251,7 @@ def elapsed_time(self, _): writer_threads.append( writer_executor.submit( save_images, - output_path, + savedir, times[sampler[time_index]], dataset, image_out.cpu().numpy(), diff --git a/src/hirad/inference/generator.py b/src/hirad/inference/generator.py index 3b9a538..b1e10e1 100644 --- a/src/hirad/inference/generator.py +++ b/src/hirad/inference/generator.py @@ -66,7 +66,6 @@ def initialize_patching(self, img_shape, patch_shape, boundary_pix, overlap_pix) def generate(self, image_lr, lead_time_label=None): with nvtx.annotate("generate_fn", color="green"): # (1, C, H, W) - image_lr = image_lr.to(memory_format=torch.channels_last) img_shape = image_lr.shape[-2:] if self.net_reg: diff --git a/src/hirad/inference/stochastic_sampler.py b/src/hirad/inference/stochastic_sampler.py index 198fde4..606c911 100644 --- a/src/hirad/inference/stochastic_sampler.py +++ b/src/hirad/inference/stochastic_sampler.py @@ -130,8 +130,8 @@ def stochastic_sampler( # Adjust noise levels based on what's supported by the network. # Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution. - sigma_min = max(sigma_min, net.sigma_min) - sigma_max = min(sigma_max, net.sigma_max) + sigma_min = max(sigma_min, net.module.sigma_min if hasattr(net, "module") else net.sigma_min) + sigma_max = min(sigma_max, net.module.sigma_max if hasattr(net, "module") else net.sigma_max) if patching is not None and not isinstance(patching, GridPatching2D): raise ValueError("patching must be an instance of GridPatching2D.") @@ -162,7 +162,8 @@ def stochastic_sampler( * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) ) ** rho t_steps = torch.cat( - [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + [net.module.round_sigma(t_steps) if hasattr(net, "module") else net.round_sigma(t_steps), + torch.zeros_like(t_steps[:1])] ) # t_N = 0 batch_size = img_lr.shape[0] @@ -198,7 +199,7 @@ def patch_embedding_selector(emb): x_cur = x_next # Increase noise temporarily. gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 - t_hat = net.round_sigma(t_cur + gamma * t_cur) + t_hat = net.module.round_sigma(t_cur + gamma * t_cur) if hasattr(net, "module") else net.round_sigma(t_cur + gamma * t_cur) x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index 96bb37f..4da26b1 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -26,7 +26,7 @@ import numpy as np import nvtx import torch -import torch.cuda.amp as amp +import torch.amp as amp from einops import rearrange from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh @@ -700,7 +700,7 @@ def forward(self, x, emb): # w = AttentionOp.apply(q, k) # a = torch.einsum("nqk,nck->ncq", w, v) # Compute attention in one step - with amp.autocast(enabled=self.amp_mode): + with amp.autocast(x.device.type, enabled=self.amp_mode): attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = self.proj(attn.reshape(*x.shape)).add_(x) x = x * self.skip_scale diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 7ea7a71..931fa66 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -1,6 +1,8 @@ import os import time +from concurrent.futures import ThreadPoolExecutor + import psutil import hydra from omegaconf import DictConfig, OmegaConf @@ -20,11 +22,12 @@ is_time_for_periodic_task, handle_and_clip_gradients from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.utils.patching import RandomPatching2D -from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR +from hirad.utils.function_utils import get_time_from_range +from hirad.utils.inference_utils import save_images +from hirad.models import UNet, EDMPrecondSuperResolution from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE -from hirad.datasets import init_train_valid_datasets_from_config - -from matplotlib import pyplot as plt +from hirad.datasets import init_train_valid_datasets_from_config, get_dataset_and_sampler_inference +from hirad.inference import Generator torch._dynamo.reset() # Increase the cache size limit @@ -85,6 +88,14 @@ def main(cfg: DictConfig) -> None: ) if dist.rank==0 and not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) # added creating checkpoint dir + visualize_checkpoints = False + if hasattr(cfg, "generation"): + visualization_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), "visualization" + ) + if dist.rank==0 and not os.path.exists(visualization_dir): + os.makedirs(visualization_dir) # added creating checkpoint dir + visualize_checkpoints = True if cfg.training.hp.batch_size_per_gpu == "auto": cfg.training.hp.batch_size_per_gpu = ( cfg.training.hp.total_batch_size // dist.world_size @@ -127,6 +138,22 @@ def main(cfg: DictConfig) -> None: else: prob_channels = None + if visualize_checkpoints: + # Parse the inference input times + if cfg.generation.times_range and cfg.generation.times: + raise ValueError("Either times_range or times must be provided, but not both") + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt + else: + times = cfg.generation.times + viz_dataset_cfg = OmegaConf.to_container(cfg.generation.dataset) + visualization_dataset, visualization_sampler = get_dataset_and_sampler_inference( + dataset_cfg=viz_dataset_cfg, times=times + ) + visualization_data_loader = torch.utils.data.DataLoader( + dataset=visualization_dataset, sampler=visualization_sampler, batch_size=1, pin_memory=True + ) + # Parse the patch shape #TODO figure out patched diffusion and how to use it if ( @@ -395,6 +422,27 @@ def main(cfg: DictConfig) -> None: except: cur_nimg = 0 + # init the generator for inference to visualize checkpoint results + if visualize_checkpoints: + generator = Generator( + net_reg= model if regression_net is None else regression_net, + net_res= model if regression_net is not None else None, + batch_size=int(cfg.generation.num_ensembles//dist.world_size), + ensemble_size=cfg.generation.num_ensembles, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + n_out_channels=img_out_channels, + inference_mode="all" if regression_net is not None else "regression", + dist=dist, + ) + if use_patching: + generator.initialize_patching(img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) + sampler_params = cfg.generation.sampler.params if "params" in cfg.generation.sampler else {} + generator.initialize_sampler(cfg.generation.sampler.type, **sampler_params) + ############################################################################ # MAIN TRAINING LOOP # ############################################################################ @@ -713,6 +761,89 @@ def main(cfg: DictConfig) -> None: epoch=cur_nimg, ) + # Visualize samples + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + if visualize_checkpoints: + with nvtx.annotate("validation", color="red"): + if dist.rank == 0: + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + times = visualization_dataset.time() + time_index = -1 + for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( + iter(visualization_data_loader) + ): + time_index += 1 + logger0.info(f"starting index: {time_index}") + + # continue + if lead_time_label_viz: + lead_time_label_viz = lead_time_label_viz[0].to(dist.device).contiguous() + else: + lead_time_label_viz = None + + if use_apex_gn: + img_clean_viz = img_clean_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_viz = img_lr_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean_viz = ( + img_clean_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_viz = ( + img_lr_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + image_pred_viz, image_reg_viz = generator.generate(img_lr_viz, lead_time_label_viz) + if dist.rank == 0: + # write out data in a seperate thread so we don't hold up inferencing + output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") + if dist.rank==0 and not os.path.exists(output_path): + os.makedirs(output_path) + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + times[visualization_sampler[time_index]], + visualization_dataset, + image_pred_viz.cpu().numpy(), + img_clean_viz.cpu().numpy(), + img_lr_viz.cpu().numpy(), + image_reg_viz.cpu().numpy() if image_reg_viz is not None else None, + ) + ) + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() + + # Done. logger0.info("Training Completed.") diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 7698a82..5053448 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -84,7 +84,7 @@ def regression_step( if lead_time_label is not None: x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label) else: - x = net(x=x_hat[0:1], img_lr=img_lr) + x = net(x=x_hat[0:1], img_lr=img_lr, force_fp32=False) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -201,6 +201,11 @@ def diffusion_step( return torch.cat(all_images) +############################################################################ +# Visualization Utilities # +############################################################################ + + def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): os.makedirs(output_path, exist_ok=True) @@ -211,7 +216,7 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, output_channels = dataset.output_channels() target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred.squeeze()),-2) #.reshape(len(output_channels),-1) baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) if mean_pred is not None: mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) @@ -220,31 +225,60 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, freqs = {} power = {} for idx, channel in enumerate(output_channels): + channel_dir = channel.name + "_" + channel.level if channel.level else channel.name + output_path_channel = os.path.join(output_path, channel_dir) + if not os.path.exists(output_path_channel): + os.makedirs(output_path_channel) input_channel_idx = input_channels.index(channel) if channel.name=="tp": target[idx,::] = _prepare_precipitaiton(target[idx,:,:]) - prediction[idx,::] = _prepare_precipitaiton(prediction[idx,:,:]) - baseline[input_channel_idx,:,:] = _prepare_precipitaiton(baseline[input_channel_idx]) + prediction[:,idx,::] = _prepare_precipitaiton(prediction[:,idx,:,:]) + baseline[input_channel_idx,:,:] = _prepare_precipitaiton(baseline[input_channel_idx,::]) if mean_pred is not None: mean_pred[idx,::] = _prepare_precipitaiton(mean_pred[idx,::]) - - _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg')) - _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg')) - _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg')) + + if mean_pred is not None: + vmin, vmax = calculate_bounds(target[idx,:,:], + prediction[:,idx,:,:], + baseline[input_channel_idx,:,:], + mean_pred[idx,:,:]) + else: + vmin, vmax = calculate_bounds(target[idx,:,:], + prediction[:,idx,:,:], + baseline[input_channel_idx,:,:]) + _plot_projection(longitudes, latitudes, target[idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-target.jpg'), + vmin=vmin, vmax=vmax) if mean_pred is not None: - _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-mean_prediction.jpg')) + for member_idx in range(prediction.shape[0]): + _plot_projection(longitudes, latitudes, prediction[member_idx,idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction_{member_idx}.jpg'), + vmin=vmin, vmax=vmax) + else: + _plot_projection(longitudes, latitudes, + prediction[0,idx,:,:], os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction.jpg'), + vmin=vmin, vmax=vmax) + _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-input.jpg'), + vmin=vmin, vmax=vmax) + if mean_pred is not None: + _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], + os.path.join(output_path_channel, f'{time_step}-{channel.name}-mean_prediction.jpg'), + vmin=vmin, vmax=vmax) _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) - _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-baseline-error.jpg')) if mean_pred is not None: - _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) - - - plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) - plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) + for member_idx in range(prediction.shape[0]): + _, prediction_errors = compute_mae(prediction[member_idx,idx,:,:], target[idx,:,:]) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction_{member_idx}-error.jpg')) + else: + _, prediction_errors = compute_mae(prediction[0,idx,:,:], target[idx,:,:]) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction-error.jpg')) if mean_pred is not None: - plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) + _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:]) + plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-mean-prediction-error.jpg')) b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) freqs['baseline'] = b_freq @@ -253,21 +287,22 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) freqs['target'] = t_freq power['target'] = t_power - p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) + p_freq, p_power = average_power_spectrum(prediction[-1,idx,:,:].squeeze(), 2.0) freqs['prediction'] = p_freq power['prediction'] = p_power if mean_pred is not None: mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0) freqs['mean_prediction'] = mp_freq power['mean_prediction'] = mp_power - plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) + plot_power_spectra(freqs, power, channel.name, os.path.join(output_path_channel, f'{time_step}-{channel.name}-spectra.jpg')) def _prepare_precipitaiton(precip_array): precip_array = np.clip(precip_array, 0, None) - epsilon = 1e-2 - precip_array = precip_array + epsilon - precip_array = np.log(precip_array) + precip_array = np.where(precip_array == 0, 1e-6, precip_array) + # epsilon = 1e-2 + # precip_array = precip_array + epsilon + precip_array = np.log10(precip_array) # log_min, log_max = precip_array.min(), precip_array.max() # precip_array = (precip_array-log_min)/(log_max-log_min) return precip_array @@ -284,4 +319,9 @@ def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array ax.gridlines(draw_labels=True) plt.colorbar(p, label="K", orientation="horizontal") plt.savefig(filename) - plt.close('all') \ No newline at end of file + plt.close('all') + +def calculate_bounds(*arrays: np.ndarray) -> tuple[float]: + vmin = min(*[np.min(array).item() for array in arrays]) + vmax = max(*[np.max(array).item() for array in arrays]) + return vmin, vmax \ No newline at end of file From 48ae52baa379425ce1f1aaad5d067861413d97b3 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 9 Jul 2025 14:05:43 +0200 Subject: [PATCH 066/189] fix calculating crps only for diffusion output --- .../conf/dataset/era_cosmo_inference.yaml | 2 ++ src/hirad/conf/generate_era_cosmo.yaml | 6 ++--- src/hirad/conf/generation/era_cosmo.yaml | 13 +++++----- .../conf/generation/era_cosmo_training.yaml | 18 +++++++++++++ .../conf/training/era_cosmo_diffusion.yaml | 15 +++++------ .../conf/training/era_cosmo_regression.yaml | 10 +++---- .../conf/training_era_cosmo_diffusion.yaml | 9 ++++--- .../conf/training_era_cosmo_regression.yaml | 9 ++++--- src/hirad/generate.sh | 13 +++++----- src/hirad/inference/generate.py | 1 + src/hirad/inference/generator.py | 7 +++-- src/hirad/train_diffusion.sh | 8 +++--- src/hirad/train_regression.sh | 8 +++--- src/hirad/training/train_dummy.py | 14 ++++++++++ src/hirad/utils/inference_utils.py | 26 +++++++++---------- 15 files changed, 98 insertions(+), 61 deletions(-) create mode 100644 src/hirad/conf/dataset/era_cosmo_inference.yaml create mode 100644 src/hirad/conf/generation/era_cosmo_training.yaml create mode 100644 src/hirad/training/train_dummy.py diff --git a/src/hirad/conf/dataset/era_cosmo_inference.yaml b/src/hirad/conf/dataset/era_cosmo_inference.yaml new file mode 100644 index 0000000..f819b18 --- /dev/null +++ b/src/hirad/conf/dataset/era_cosmo_inference.yaml @@ -0,0 +1,2 @@ +type: era5_cosmo +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index 90a5948..f9f629c 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -1,15 +1,15 @@ hydra: job: chdir: true - name: generation_regression_valid + name: diffusion_era5_cosmo_7500000_test run: - dir: /capstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: /capstor/scratch/cscs/pstamenk/outputs/generation/${hydra:job.name} # Get defaults defaults: - _self_ # Dataset - - dataset/era_cosmo + - dataset/era_cosmo_inference # Sampler - sampler/stochastic diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 2ccbb71..7061889 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -2,7 +2,7 @@ num_ensembles: 8 # Number of ensembles to generate per input seed_batch_size: 4 # Size of the batched inference -inference_mode: regression +inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y @@ -18,8 +18,8 @@ hr_mean_conditioning: True times_range: null times: - 20200926-1800 - # - 20160101-0600 - # - 20160101-1200 + - 20200927-0000 + has_laed_time: False perf: @@ -30,15 +30,16 @@ perf: # whether to use torch.compile on the diffusion model # this will make the first time stamp generation very slow due to compilation overheads # but will significantly speed up subsequent inference runs - num_writer_workers: 1 + num_writer_workers: 8 # number of workers to use for writing file # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: /capstor/scratch/cscs/pstamenk/diffusion_checkpoints + # res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/diffusion_full/checkpoints_diffusion + res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_era5_cosmo/checkpoints_diffusion # res_ckpt_path: null # Checkpoint filename for the diffusion model - reg_ckpt_path: /capstor/scratch/cscs/pstamenk/regression_checkpoints + reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo_training.yaml b/src/hirad/conf/generation/era_cosmo_training.yaml new file mode 100644 index 0000000..4374fc8 --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo_training.yaml @@ -0,0 +1,18 @@ +defaults: + - ../sampler@sampler: stochastic + - ../dataset@dataset: era_cosmo_inference + +num_ensembles: 16 + # Number of ensembles to generate per input +# overlap_pixels: 0 + # Number of overlapping pixels between adjacent patches +# boundary_pixels: 0 + # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary + # artifact. +times_range: null +times: + - 20200926-1800 + - 20200927-0000 + +perf: + num_writer_workers: 10 \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index 07cbb03..e0d096c 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,8 +1,8 @@ # Hyperparameters hp: - training_duration: 5000000 + training_duration: 1000 # Training duration based on the number of processed samples - total_batch_size: 128 + total_batch_size: 64 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU @@ -29,15 +29,14 @@ perf: # I/O io: - regression_checkpoint_path: /capstor/scratch/cscs/boeschf/HiRAD-Gen/outputs_full/regression/checkpoints_regression/ + regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression # Where to load the regression checkpoint - print_progress_freq: 5000 + print_progress_freq: 128 # How often to print progress - save_checkpoint_freq: 250000 + save_checkpoint_freq: 512 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 25000 + validation_freq: 256 # how often to record the validation loss, measured in number of processed samples - validation_steps: 4 - # how many loss evaluations are used to compute the validation loss per checkpoint + validation_steps: 2 # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 98c6c24..caa94ab 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,6 +1,6 @@ # Hyperparameters hp: - training_duration: 500000 + training_duration: 1000 # Training duration based on the number of processed samples total_batch_size: 64 # Total batch size @@ -31,12 +31,12 @@ perf: # I/O io: - print_progress_freq: 1024 + print_progress_freq: 128 # How often to print progress - save_checkpoint_freq: 25000 + save_checkpoint_freq: 512 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 5000 + validation_freq: 256 # how often to record the validation loss, measured in number of processed samples - validation_steps: 10 + validation_steps: 2 # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 0a069e9..7a38fbd 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: diffusion + name: diffusion_era5_cosmo_test run: - dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} # Get defaults defaults: @@ -18,4 +18,7 @@ defaults: - model_size/normal # Training - - training/era_cosmo_diffusion \ No newline at end of file + - training/era_cosmo_diffusion + + # Inference visualization + - generation/era_cosmo_training \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index 1de83d9..045eba4 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: regression + name: regression_era5_cosmo_test run: - dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} + dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} # Get defaults defaults: @@ -18,4 +18,7 @@ defaults: - model_size/normal # Training - - training/era_cosmo_regression \ No newline at end of file + - training/era_cosmo_regression + + # Inference visualization + - generation/era_cosmo_training \ No newline at end of file diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 3876ade..5ca8b97 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -5,18 +5,19 @@ ### HARDWARE ### #SBATCH --partition=debug #SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/regression_generation.log -#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/regression_generation.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.err ### ENVIRONMENT #### -#SBATCH -A c38 +#SBATCH -A a122 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -46,7 +47,5 @@ export OMP_NUM_THREADS=72 srun --container-writable --environment=modulus_env bash -c " cd HiRAD-Gen pip install -e . --no-dependencies - pip install Cartopy==0.22.0 - pip install xskillscore python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index ae6d9d9..cea82a2 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -245,6 +245,7 @@ def elapsed_time(self, _): image_tar = image_tar.to(device=device).to(torch.float32) # image_out, image_reg = generate_fn(image_lr,lead_time_label) image_out, image_reg = generator.generate(image_lr,lead_time_label) + if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing diff --git a/src/hirad/inference/generator.py b/src/hirad/inference/generator.py index b1e10e1..ad051a9 100644 --- a/src/hirad/inference/generator.py +++ b/src/hirad/inference/generator.py @@ -102,7 +102,7 @@ def generate(self, image_lr, lead_time_label=None): lead_time_label=lead_time_label, ) if self.inference_mode == "regression": - image_out = image_reg + image_out = image_reg[0:1,::] elif self.inference_mode == "diffusion": image_out = image_res else: @@ -130,11 +130,10 @@ def generate(self, image_lr, lead_time_label=None): if self.dist.rank == 0: if self.inference_mode != "regression": return torch.cat(gathered_tensors), image_reg[0:1,::] - return torch.cat(gathered_tensors), None + return torch.cat(gathered_tensors)[0:1,::], None else: return None, None else: - #TODO do this for multi-gpu setting above too if self.inference_mode != "regression": - return image_out, image_reg + return image_out, image_reg[0:1,::] return image_out, None diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh index d8515db..badeb46 100644 --- a/src/hirad/train_diffusion.sh +++ b/src/hirad/train_diffusion.sh @@ -3,18 +3,18 @@ #SBATCH --job-name="corrdiff-second-stage" ### HARDWARE ### -#SBATCH --partition=normal +#SBATCH --partition=debug #SBATCH --nodes=2 #SBATCH --ntasks-per-node=4 #SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=72 -#SBATCH --time=24:00:00 +#SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/diffusion_full.log -#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/diffusion_full.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/training_diffusion_test.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/training_diffusion_test.err ### ENVIRONMENT #### #SBATCH -A c38 diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh index 7499bcf..31d8014 100644 --- a/src/hirad/train_regression.sh +++ b/src/hirad/train_regression.sh @@ -3,18 +3,18 @@ #SBATCH --job-name="corrdiff-first-stage" ### HARDWARE ### -#SBATCH --partition=normal +#SBATCH --partition=debug #SBATCH --nodes=2 #SBATCH --ntasks-per-node=4 #SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=72 -#SBATCH --time=06:00:00 +#SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/regression_full_run.log -#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/regression_full_run.err +#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/training_regression_test.log +#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/training_regression_test.err ### ENVIRONMENT #### #SBATCH -A c38 diff --git a/src/hirad/training/train_dummy.py b/src/hirad/training/train_dummy.py new file mode 100644 index 0000000..cc47ea9 --- /dev/null +++ b/src/hirad/training/train_dummy.py @@ -0,0 +1,14 @@ +import hydra +from omegaconf import DictConfig, OmegaConf +import json + + +@hydra.main(version_base=None, config_path="../conf", config_name="training") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + cfg = OmegaConf.to_container(cfg) + print(json.dumps(cfg, indent=2)) + # print(cfg.pretty()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index f3c9f74..ed50b05 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -215,10 +215,7 @@ def save_results_as_torch(output_path, time_step, dataset, image_pred, image_hr, # prediction.shape = (num_channels, X, Y) # prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) # prediction_ensemble.shape = (num_ensembles, num_channels, X, Y) - prediction_ensemble = np.ndarray(image_pred.shape) - for i in range(image_pred.shape[0]): - prediction_ensemble[i,::] = np.flip(dataset.denormalize_output(image_pred[i,::].squeeze()),1) - prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),2) #.reshape(len(output_channels),-1) + prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),-2) baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) if mean_pred is not None: mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) @@ -227,8 +224,8 @@ def save_results_as_torch(output_path, time_step, dataset, image_pred, image_hr, torch.save(baseline, os.path.join(output_path, f'{time_step}-baseline')) -def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): - +def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + os.makedirs(output_path, exist_ok=True) longitudes = dataset.longitude() @@ -237,17 +234,18 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, output_channels = dataset.output_channels() target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - prediction = np.flip(dataset.denormalize_output(image_pred.squeeze()),-2) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred),-2) #.reshape(len(output_channels),-1) baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) if mean_pred is not None: mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) # Plot CRPS - crps_score = crps(prediction, target, average_over_area=False, average_over_channels=True) - _plot_projection(longitudes, latitudes, crps_score, os.path.join(output_path, f'{time_step}-crps-all.jpg')) - crps_score_channels = crps(prediction, target, average_over_area=False, average_over_channels=False) - for channel_num in range(crps_score_channels.shape[0]): - _plot_projection(longitudes, latitudes, crps_score_channels[channel_num,::], os.path.join(output_path, f'{time_step}-crps-{output_channels[channel_num].name}.jpg')) + if prediction.shape[0] > 1: + crps_score = crps(prediction, target, average_over_area=False, average_over_channels=True) + _plot_projection(longitudes, latitudes, crps_score, os.path.join(output_path, f'{time_step}-crps-all.jpg')) + crps_score_channels = crps(prediction, target, average_over_area=False, average_over_channels=False) + for channel_num in range(crps_score_channels.shape[0]): + _plot_projection(longitudes, latitudes, crps_score_channels[channel_num,::], os.path.join(output_path, f'{time_step}-crps-{output_channels[channel_num].name}.jpg')) # Plot power spectra freqs = {} @@ -278,7 +276,7 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path_channel, f'{time_step}-{channel.name}-target.jpg'), vmin=vmin, vmax=vmax) - if mean_pred is not None: + if prediction.shape[0] > 1: for member_idx in range(prediction.shape[0]): _plot_projection(longitudes, latitudes, prediction[member_idx,idx,:,:], os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction_{member_idx}.jpg'), @@ -297,7 +295,7 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-baseline-error.jpg')) - if mean_pred is not None: + if prediction.shape[0] > 1: for member_idx in range(prediction.shape[0]): _, prediction_errors = compute_mae(prediction[member_idx,idx,:,:], target[idx,:,:]) plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path_channel, f'{time_step}-{channel.name}-prediction_{member_idx}-error.jpg')) From b92abde6e052981a85f7210ca167edf5d894154b Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 9 Jul 2025 15:46:08 +0200 Subject: [PATCH 067/189] environment files --- Dockerfile | 10 ++++++++++ modulus_env.toml | 10 ++++++++++ src/hirad/generate.sh | 5 ++--- 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 Dockerfile create mode 100644 modulus_env.toml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..93f389c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06 + +# setup +RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade pip + +# Install the rest of dependencies. +RUN pip install \ + Cartopy==0.22.0 \ + xskillscore diff --git a/modulus_env.toml b/modulus_env.toml new file mode 100644 index 0000000..f44a7a7 --- /dev/null +++ b/modulus_env.toml @@ -0,0 +1,10 @@ +image = "/capstor/scratch/cscs/pstamenk/hirad.sqsh" + +mounts = ["/capstor", "/iopsstor”, “/users”] + +# The initial directory in the container. +workdir = "${PWD}" + +[annotations] +com.hooks.aws_ofi_nccl.enabled = "true" +com.hooks.aws_ofi_nccl.variant = "cuda12" \ No newline at end of file diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 5ca8b97..57320bf 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -17,7 +17,7 @@ #SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.err ### ENVIRONMENT #### -#SBATCH -A a122 +#SBATCH -A c38 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -44,8 +44,7 @@ export OMP_NUM_THREADS=72 # echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun --container-writable --environment=modulus_env bash -c " - cd HiRAD-Gen +srun --environment=./modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file From 03386edeea614b44794d7b5a2381e302cbb4ef09 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 9 Jul 2025 18:09:46 +0200 Subject: [PATCH 068/189] small fix environment --- modulus_env.toml | 2 +- src/hirad/conf/training/era_cosmo_diffusion.yaml | 4 ++-- src/hirad/conf/training/era_cosmo_regression.yaml | 2 +- src/hirad/generate.sh | 2 +- src/hirad/inference/README.md | 0 src/hirad/train_diffusion.sh | 5 ++--- src/hirad/train_regression.sh | 5 ++--- 7 files changed, 9 insertions(+), 11 deletions(-) create mode 100644 src/hirad/inference/README.md diff --git a/modulus_env.toml b/modulus_env.toml index f44a7a7..55f43d8 100644 --- a/modulus_env.toml +++ b/modulus_env.toml @@ -1,6 +1,6 @@ image = "/capstor/scratch/cscs/pstamenk/hirad.sqsh" -mounts = ["/capstor", "/iopsstor”, “/users”] +mounts = ["/capstor", "/iopsstor", "/users"] # The initial directory in the container. workdir = "${PWD}" diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index e0d096c..5a14a6a 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,6 +1,6 @@ # Hyperparameters hp: - training_duration: 1000 + training_duration: 1024 # Training duration based on the number of processed samples total_batch_size: 64 # Total batch size @@ -33,7 +33,7 @@ io: # Where to load the regression checkpoint print_progress_freq: 128 # How often to print progress - save_checkpoint_freq: 512 + save_checkpoint_freq: 1024 # How often to save the checkpoints, measured in number of processed samples validation_freq: 256 # how often to record the validation loss, measured in number of processed samples diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index caa94ab..6d7216e 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,6 +1,6 @@ # Hyperparameters hp: - training_duration: 1000 + training_duration: 1024 # Training duration based on the number of processed samples total_batch_size: 64 # Total batch size diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 57320bf..a9843e0 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -17,7 +17,7 @@ #SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.err ### ENVIRONMENT #### -#SBATCH -A c38 +#SBATCH -A a122 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM diff --git a/src/hirad/inference/README.md b/src/hirad/inference/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh index badeb46..58e6ccb 100644 --- a/src/hirad/train_diffusion.sh +++ b/src/hirad/train_diffusion.sh @@ -17,7 +17,7 @@ #SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/training_diffusion_test.err ### ENVIRONMENT #### -#SBATCH -A c38 +#SBATCH -A a122 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -38,8 +38,7 @@ export MASTER_PORT=29500 export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun --container-writable --environment=modulus_env bash -c " - cd HiRAD-Gen +srun --environment=./modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml " \ No newline at end of file diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh index 31d8014..8c681db 100644 --- a/src/hirad/train_regression.sh +++ b/src/hirad/train_regression.sh @@ -17,7 +17,7 @@ #SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/training_regression_test.err ### ENVIRONMENT #### -#SBATCH -A c38 +#SBATCH -A a122 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -41,8 +41,7 @@ export OMP_NUM_THREADS=72 # . ./train_env/bin/activate # python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml # " -srun --container-writable --environment=modulus_env bash -c " - cd HiRAD-Gen +srun --environment=./modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml " \ No newline at end of file From f7a0b4d624a29257e3b951b23486233158e1699e Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Fri, 11 Jul 2025 12:31:47 +0200 Subject: [PATCH 069/189] Make some paths relative --- src/hirad/conf/dataset/era_cosmo.yaml | 2 +- src/hirad/conf/generate_era_cosmo.yaml | 4 ++++ src/hirad/conf/generation/era_cosmo.yaml | 14 +++++++------- src/hirad/conf/training/era_cosmo_regression.yaml | 4 ++++ src/hirad/conf/training_era_cosmo_diffusion.yaml | 2 +- src/hirad/conf/training_era_cosmo_regression.yaml | 2 +- src/hirad/generate.sh | 5 +++++ src/hirad/train_regression.sh | 3 +-- 8 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index 5d32f4e..beec8a2 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,3 +1,3 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/train +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index f9f629c..cc42751 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -3,7 +3,11 @@ hydra: chdir: true name: diffusion_era5_cosmo_7500000_test run: +<<<<<<< Updated upstream dir: /capstor/scratch/cscs/pstamenk/outputs/generation/${hydra:job.name} +======= + dir: /capstor/scratch/cscs/mmcgloho/outputs/${hydra:job.name} +>>>>>>> Stashed changes # Get defaults defaults: diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 7061889..3891b7b 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -15,12 +15,12 @@ patching: False hr_mean_conditioning: True # sample_res: full # Sampling resolution -times_range: null -times: - - 20200926-1800 - - 20200927-0000 - -has_laed_time: False +times_range: ['20200101-0000','20200102-0000',1] +times: null +# - 20200926-1800 + #- 20160101-0600 + # - 20160101-1200 +has_lead_time: False perf: force_fp16: False @@ -42,4 +42,4 @@ io: reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model - output_path: ./images \ No newline at end of file + output_path: /capstor/scratch/cscs/mmcgloho/outputs/era-cosmo-1h/ \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 6d7216e..db64726 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,6 +1,10 @@ # Hyperparameters hp: +<<<<<<< Updated upstream training_duration: 1024 +======= + training_duration: 1000 +>>>>>>> Stashed changes # Training duration based on the number of processed samples total_batch_size: 64 # Total batch size diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 7a38fbd..08c8d6a 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -15,7 +15,7 @@ defaults: # Model - model/era_cosmo_diffusion - - model_size/normal + - model_size/mini # Training - training/era_cosmo_diffusion diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index 045eba4..83a4f94 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -15,7 +15,7 @@ defaults: # Model - model/era_cosmo_regression - - model_size/normal + - model_size/mini # Training - training/era_cosmo_regression diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index a9843e0..a27a727 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -13,8 +13,13 @@ #SBATCH --exclusive ### OUTPUT ### +<<<<<<< Updated upstream #SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.log #SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.err +======= +#SBATCH --output=/capstor/scratch/cscs/mmcgloho/logs/regression_generation.log +#SBATCH --error=/capstor/scratch/cscs/mmcgloho/logs/regression_generation.err +>>>>>>> Stashed changes ### ENVIRONMENT #### #SBATCH -A a122 diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh index 8c681db..a75cc06 100644 --- a/src/hirad/train_regression.sh +++ b/src/hirad/train_regression.sh @@ -13,8 +13,7 @@ #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/training_regression_test.log -#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/training_regression_test.err +#SBATCH --output=./logs/regression_full_run.log ### ENVIRONMENT #### #SBATCH -A a122 From 0224692916323416c89771053b7903c51cd99ddb Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Fri, 11 Jul 2025 12:47:44 +0200 Subject: [PATCH 070/189] Fix errors from merge process --- src/hirad/conf/dataset/era_cosmo.yaml | 2 +- src/hirad/conf/generate_era_cosmo.yaml | 6 +----- src/hirad/conf/generation/era_cosmo.yaml | 2 +- src/hirad/conf/training/era_cosmo_regression.yaml | 6 +----- src/hirad/generate.sh | 12 ++++-------- 5 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index beec8a2..5d32f4e 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,3 +1,3 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/train validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index cc42751..d448b58 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -3,11 +3,7 @@ hydra: chdir: true name: diffusion_era5_cosmo_7500000_test run: -<<<<<<< Updated upstream - dir: /capstor/scratch/cscs/pstamenk/outputs/generation/${hydra:job.name} -======= - dir: /capstor/scratch/cscs/mmcgloho/outputs/${hydra:job.name} ->>>>>>> Stashed changes + dir: ./outputs/generation/${hydra:job.name} # Get defaults defaults: diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 3891b7b..c4f71dc 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -42,4 +42,4 @@ io: reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model - output_path: /capstor/scratch/cscs/mmcgloho/outputs/era-cosmo-1h/ \ No newline at end of file + output_path: ./outputs/evaluation \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index db64726..b45b206 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,13 +1,9 @@ # Hyperparameters hp: -<<<<<<< Updated upstream training_duration: 1024 -======= - training_duration: 1000 ->>>>>>> Stashed changes # Training duration based on the number of processed samples total_batch_size: 64 - # Total batch size + # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression. batch_size_per_gpu: "auto" # Batch size per GPU lr: 0.0002 diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index a27a727..adb9b30 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -13,13 +13,7 @@ #SBATCH --exclusive ### OUTPUT ### -<<<<<<< Updated upstream -#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.log -#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/generation_diffusion_test.err -======= -#SBATCH --output=/capstor/scratch/cscs/mmcgloho/logs/regression_generation.log -#SBATCH --error=/capstor/scratch/cscs/mmcgloho/logs/regression_generation.err ->>>>>>> Stashed changes +#SBATCH --output=./logs/regression_generation.log ### ENVIRONMENT #### #SBATCH -A a122 @@ -52,4 +46,6 @@ export OMP_NUM_THREADS=72 srun --environment=./modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml -" \ No newline at end of file +" + #pip install Cartopy==0.22.0 + #pip install xskillscore \ No newline at end of file From 9cf4d18e3692e834f3719ee38974649d03162610 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Fri, 11 Jul 2025 17:12:39 +0200 Subject: [PATCH 071/189] Move environments to ci/ directory --- .gitignore | 1 + ci/cscs.yml | 2 +- ci/docker/Dockerfile.ci | 15 +++++++++++++++ Dockerfile => ci/docker/Dockerfile.corrdiff | 0 modulus_env.toml => ci/edf/modulus_env.toml | 0 src/hirad/generate.sh | 2 +- 6 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 ci/docker/Dockerfile.ci rename Dockerfile => ci/docker/Dockerfile.corrdiff (100%) rename modulus_env.toml => ci/edf/modulus_env.toml (100%) diff --git a/.gitignore b/.gitignore index 7189efd..17c4ea5 100644 --- a/.gitignore +++ b/.gitignore @@ -175,6 +175,7 @@ pyrightconfig.json *.torch plots/* *.npz +outputs/* # conda .conda/* diff --git a/ci/cscs.yml b/ci/cscs.yml index fc92645..b96e65a 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -12,7 +12,7 @@ build_job: stage: build extends: .container-builder-cscs-gh200 variables: - DOCKERFILE: ci/docker/Dockerfile + DOCKERFILE: ci/docker/Dockerfile.ci #test_job: # stage: test diff --git a/ci/docker/Dockerfile.ci b/ci/docker/Dockerfile.ci new file mode 100644 index 0000000..80925d3 --- /dev/null +++ b/ci/docker/Dockerfile.ci @@ -0,0 +1,15 @@ +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06 + +# setup +RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade pip + +# Install the rest of dependencies. +RUN pip install \ + Cartopy==0.22.0 \ + xskillscore + + + + + diff --git a/Dockerfile b/ci/docker/Dockerfile.corrdiff similarity index 100% rename from Dockerfile rename to ci/docker/Dockerfile.corrdiff diff --git a/modulus_env.toml b/ci/edf/modulus_env.toml similarity index 100% rename from modulus_env.toml rename to ci/edf/modulus_env.toml diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index adb9b30..8f38326 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -43,7 +43,7 @@ export OMP_NUM_THREADS=72 # echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun --environment=./modulus_env.toml bash -c " +srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml " From 482938812490badd0e47f864e8b603b60d8ddf69 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Fri, 11 Jul 2025 17:16:28 +0200 Subject: [PATCH 072/189] Update paths for CRPS plotting --- src/hirad/utils/inference_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index ed50b05..8e95f1c 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -330,19 +330,20 @@ def plot_crps_over_time(times, dataset, output_path): start_time=times[0] end_time=times[-1] - prediction_ensemble = torch.load(os.path.join(output_path, f'{times[0]}-predictions'), weights_only=False) + # Load one prediction ensemble to get the shape + prediction_ensemble = torch.load(os.path.join(output_path, times[0], f'{times[0]}-predictions'), weights_only=False) all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) all_targets = np.ndarray((len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) for i in range(len(times)): - prediction_ensemble = torch.load(os.path.join(output_path, f'{times[i]}-predictions'), weights_only=False) + prediction_ensemble = torch.load(os.path.join(output_path, times[i], f'{times[i]}-predictions'), weights_only=False) all_predictions[i,::] = prediction_ensemble - target = torch.load(os.path.join(output_path, f'{times[i]}-target'), weights_only=False) + target = torch.load(os.path.join(output_path, times[i], f'{times[i]}-target'), weights_only=False) all_targets[i,::] = target score_over_time_channels = crps(all_predictions, all_targets, average_over_area=True, average_over_channels=False, average_over_time=False) score_over_area_channels = crps(all_predictions, all_targets, average_over_area=False, average_over_channels=False, average_over_time=True) for channel_num in range(score_over_area_channels.shape[0]): - _plot_projection(longitudes, latitudes, score_over_area_channels[channel_num,::], os.path.join(output_path, f'crps-time-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) - _plot_score_vs_t(score_over_time_channels[:, channel_num], times, os.path.join(output_path, f'crps-area-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) + _plot_projection(longitudes, latitudes, score_over_area_channels[channel_num,::], os.path.join(output_path, f'crps-area-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) + _plot_score_vs_t(score_over_time_channels[:, channel_num], times, os.path.join(output_path, f'crps-time-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) def _plot_score_vs_t(score: np.array, times: np.array, filename: str): fig = plt.figure() From 07db9f397b7b563566c3e938bcea725017c7baf4 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 14 Jul 2025 09:48:29 +0200 Subject: [PATCH 073/189] =?UTF-8?q?Change=20slurm=20account=20to=20new=20p?= =?UTF-8?q?roject=20=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/hirad/generate.sh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh index 8f38326..1e3e9f5 100644 --- a/src/hirad/generate.sh +++ b/src/hirad/generate.sh @@ -16,7 +16,7 @@ #SBATCH --output=./logs/regression_generation.log ### ENVIRONMENT #### -#SBATCH -A a122 +#SBATCH -A a161 # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -46,6 +46,4 @@ export OMP_NUM_THREADS=72 srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml -" - #pip install Cartopy==0.22.0 - #pip install xskillscore \ No newline at end of file +" \ No newline at end of file From 36953f84cc574173ccc6052042ed7acb0c79c2f8 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 10:04:53 +0200 Subject: [PATCH 074/189] Starting plotting script --- src/hirad/conf/compute_eval.yaml | 12 ++ src/hirad/eval/__init__.py | 4 +- src/hirad/eval/compute_eval.py | 63 +++++++++++ src/hirad/eval/plotting.py | 183 ++++++++++++++++++++++++++++++- 4 files changed, 257 insertions(+), 5 deletions(-) create mode 100644 src/hirad/conf/compute_eval.yaml create mode 100644 src/hirad/eval/compute_eval.py diff --git a/src/hirad/conf/compute_eval.yaml b/src/hirad/conf/compute_eval.yaml new file mode 100644 index 0000000..069ad1e --- /dev/null +++ b/src/hirad/conf/compute_eval.yaml @@ -0,0 +1,12 @@ +hydra: + job: + chdir: true + name: diffusion_era5_cosmo_7500000_test + run: + dir: ./outputs/generation/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + # Dataset + - dataset/era_cosmo_inference diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py index a4b9475..db34154 100644 --- a/src/hirad/eval/__init__.py +++ b/src/hirad/eval/__init__.py @@ -1,2 +1,2 @@ -from .metrics import compute_mae, average_power_spectrum, crps -from .plotting import plot_error_projection, plot_power_spectra +from .metrics import absolute_error, compute_mae, average_power_spectrum, crps +from .plotting import plot_error_projection, plot_power_spectra, compute_crps_over_time, compute_crps_over_time_and_area, plot_crps_over_time_and_area diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py new file mode 100644 index 0000000..a8c11d8 --- /dev/null +++ b/src/hirad/eval/compute_eval.py @@ -0,0 +1,63 @@ +import hydra +import os +import json +from omegaconf import OmegaConf, DictConfig +import torch +import torch._dynamo +import numpy as np +import contextlib + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from concurrent.futures import ThreadPoolExecutor + +from hirad.eval import compute_crps_over_time, plot_crps_over_time_and_area, compute_crps_over_time_and_area +from hirad.models import EDMPrecondSuperResolution, UNet +from hirad.inference import Generator +from hirad.utils.inference_utils import save_images, save_results_as_torch +from hirad.utils.function_utils import get_time_from_range +from hirad.utils.checkpoint import load_checkpoint + +from hirad.datasets import get_dataset_and_sampler_inference + +from hirad.utils.train_helpers import set_patch_shape + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + + + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + + # Initialize logger + logger = PythonLogger("generate") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + + + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if "has_lead_time" in cfg.generation: + has_lead_time = cfg.generation["has_lead_time"] + else: + has_lead_time = False + dataset, sampler = get_dataset_and_sampler_inference( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + output_path = getattr(cfg.generation.io, "output_path", "./outputs") + + #plot_crps_over_time_and_area(times, dataset, output_path) + #compute_crps_over_time(times, dataset, output_path) + compute_crps_over_time_and_area(times, dataset, output_path) + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 1ca11c2..58d37a3 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -1,20 +1,197 @@ import logging +import os + +from hirad.eval import crps, absolute_error import cartopy.crs as ccrs import matplotlib.pyplot as plt import numpy as np +import torch -def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str): +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label: str, vmin=None, vmax=None): + """Plot observed or interpolated data in a scatter plot.""" fig = plt.figure() fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) logging.info(f'plotting values to {filename}') - p = ax.scatter(x=longitudes, y=latitudes, c=values) + p = ax.scatter(x=longitudes, y=latitudes, c=values, vmin=vmin, vmax=vmax) ax.coastlines() ax.gridlines(draw_labels=True) - plt.colorbar(p, label="absolute error", orientation="horizontal") + plt.colorbar(p, label=label, orientation="horizontal") plt.savefig(filename) plt.close('all') +def compute_crps_over_time(times, dataset, output_path): + logging.info('computing crps') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + # Load one prediction ensemble to get the shape + prediction_ensemble = torch.load(os.path.join(output_path, times[0], f'{times[0]}-predictions'), weights_only=False) + #all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) + #all_targets = np.ndarray((len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) + + for i in range(len(times)): + logging.info(f'computing crps for {times[i]}') + prediction_ensemble = torch.load(os.path.join(output_path, times[i], f'{times[i]}-predictions'), weights_only=False) + baseline = torch.load(os.path.join(output_path, times[i], f'{times[i]}-baseline'), weights_only=False) + target = torch.load(os.path.join(output_path, times[i], f'{times[i]}-target'), weights_only=False) + + # Calculate CRPS + crps_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) + + # Calculate ensemble mean error + ensemble_mean = np.mean(prediction_ensemble, 0) + ensemble_mean_error = absolute_error(ensemble_mean, target) + + # Calculate interpolation error (baseline #1) + interpolation_error = np.zeros(baseline.shape) + for j in range(len(output_channels)): + index = -1 + for k in range(len(input_channels)): + if input_channels[k].name == output_channels[j].name: + index = k + if k > -1: + interpolation_error[j,::] = baseline[k,::] - target[j,::] + + # Calculate persistence error (baseline #2) + persistence_error = np.zeros(baseline.shape) + if i > 0: + prev = torch.load(os.path.join(output_path, times[i-1], f'{times[i-1]}-target'), weights_only=False) + persistence_error = absolute_error(prev, target) + else: + persistence_error = absolute_error(target, target) + + + torch.save(crps_area, os.path.join(output_path, times[i], f'{times[i]}-crps')) + torch.save(ensemble_mean_error, os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error')) + torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) + torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) + +def plot_crps_over_time_and_area(times, dataset, output_path): + logging.info('plotting crps') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + maxes = torch.load(os.path.join(output_path, f'crps-maxes-{start_time}-{end_time}'), weights_only=False) + mins = torch.load(os.path.join(output_path, f'crps-mins-{start_time}-{end_time}'), weights_only=False) + crps_time = torch.load(os.path.join(output_path, f'crps-time-{start_time}-{end_time}'), weights_only=False) + crps_area = torch.load(os.path.join(output_path, f'crps-area-{start_time}-{end_time}'), weights_only=False) + interpolation_time = torch.load(os.path.join(output_path, f'interpolation-time-{start_time}-{end_time}'), weights_only=False) + persistence_time = torch.load(os.path.join(output_path, f'persistence-time-{start_time}-{end_time}'), weights_only=False) + + + for j in range(crps_area.shape[0]): + plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + plot_error_projection(interpolation_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + _plot_score_vs_t(crps_time[j,::], times, os.path.join(output_path, f'NEW-crps-time-{start_time}-{end_time}-{output_channels[j].name}.jpg')) + + for j in range(len(output_channels)): + for i in range(len(times)): + plot_error_projection(crps_area[j,::], latitudes, longitudes, + os.path.join(output_path, 'animations', output_channels[j].name, f'{times[i]}.jpg'), + label=output_channels[j].name) + + +def compute_crps_over_time_and_area(times, dataset, output_path): + logging.info('computing crps and errors') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + logging.info('calculating min/max') + + crps_area = torch.load(os.path.join(output_path, times[0], f'{times[0]}-crps'), weights_only=False) + mins = np.ones((crps_area.shape[0])) * 999999999 + maxes = np.zeros((crps_area.shape[0])) + + for i in range(len(times)): + if i % (24*5) == 0: + logging.info(f'on time {times[i]}') + # Shape = channels, x, y + crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps'), weights_only=False) + for j in range(crps_area.shape[0]): + max = np.max(crps_area[j,::]) + min = np.min(crps_area[j,::]) + if max > maxes[j]: + maxes[j] = max + if min < mins[j]: + mins[j] = min + logging.info(f'maxes are {maxes}') + logging.info(f'mins are {mins}') + + torch.save(maxes, os.path.join(output_path, f'crps-maxes-{times[0]}-{times[len(times)-1]}')) + torch.save(mins, os.path.join(output_path, f'crps-mins-{times[0]}-{times[len(times)-1]}')) + + # make area and time plot + total_crps_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + total_ensemble_mean_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + total_interpolation_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + total_persistence_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + + crps_over_time = np.zeros((crps_area.shape[0], len(times))) + ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) + interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) + persistence_over_time = np.zeros((crps_area.shape[0], len(times))) + for i in range(len(times)): + if i % (24*5) == 0: + logging.info(f'on time {times[i]}') + crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps'), weights_only=False) + total_crps_area = total_crps_area + crps_area + ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error'), weights_only=False) + total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area + interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) + total_interpolation_area = total_interpolation_area + interpolation_area + persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-persistence-error'), weights_only=False) + total_persistence_area = total_persistence_area + persistence_area + + for j in range(crps_area.shape[0]): + crps_over_time[j,i] = np.mean(crps_area[j,::]) + ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) + interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) + persistence_over_time[j,i] = np.mean(persistence_area[j,::]) + mean_crps_area = total_crps_area / len(times) + mean_ensemble_mean_area = total_ensemble_mean_area / len(times) + mean_interpolation_area = total_interpolation_area / len(times) + mean_persistence_area = total_persistence_area / len(times) + torch.save(mean_crps_area, os.path.join(output_path, f'crps-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_ensemble_mean_area, os.path.join(output_path, f'mae-ensemble-mean-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_interpolation_area, os.path.join(output_path, f'mae-interpolation-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_persistence_area, os.path.join(output_path, f'mae-persistence-area-{times[0]}-{times[len(times)-1]}')) + + + torch.save(crps_over_time, os.path.join(output_path, f'crps-time-{times[0]}-{times[len(times)-1]}')) + torch.save(ensemble_mean_over_time, os.path.join(output_path, f'mae-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) + torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) + torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) + +def _plot_score_vs_t(score: np.array, times: np.array, filename: str): + fig = plt.figure() + ax = plt.subplot() + p = plt.plot(times, score) + #plt.ylabel('CRPS') + #plt.xlabel('time') + plt.xticks([times[0],times[-1]]) + plt.savefig(filename) + plt.close('all') + + + + def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): fig = plt.figure() for k in freqs.keys(): From 872ab702781d6762a15ea2b01ccf5868dc0b7745 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 10:05:07 +0200 Subject: [PATCH 075/189] add time point to logging --- src/hirad/inference/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index cea82a2..d26a168 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -225,7 +225,7 @@ def elapsed_time(self, _): ): time_index += 1 if dist.rank == 0: - logger0.info(f"starting index: {time_index}") + logger0.info(f"starting index: {time_index} time: {times[sampler[time_index]]}") if time_index == warmup_steps: start.record() From 38d1d88d6da8759e12f20ac656d4adb7c88b0de0 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 10:05:19 +0200 Subject: [PATCH 076/189] eval script --- src/hirad/eval.sh | 49 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 src/hirad/eval.sh diff --git a/src/hirad/eval.sh b/src/hirad/eval.sh new file mode 100644 index 0000000..64bb995 --- /dev/null +++ b/src/hirad/eval.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +##SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/eval_compute.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/eval/compute_eval.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file From 97434911b3479b6fb54f54fab8f654b07f2062ad Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 10:33:47 +0200 Subject: [PATCH 077/189] add ensemble mean MAE, and add a comparison plot over time --- src/hirad/eval/compute_eval.py | 4 +- src/hirad/eval/plotting.py | 133 +++++++++++++++++++-------------- 2 files changed, 78 insertions(+), 59 deletions(-) diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py index a8c11d8..e611373 100644 --- a/src/hirad/eval/compute_eval.py +++ b/src/hirad/eval/compute_eval.py @@ -52,9 +52,9 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) output_path = getattr(cfg.generation.io, "output_path", "./outputs") - #plot_crps_over_time_and_area(times, dataset, output_path) - #compute_crps_over_time(times, dataset, output_path) + compute_crps_over_time(times, dataset, output_path) compute_crps_over_time_and_area(times, dataset, output_path) + plot_crps_over_time_and_area(times, dataset, output_path) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 58d37a3..0924e3d 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -20,6 +20,44 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.savefig(filename) plt.close('all') +def _plot_score_vs_t(score: np.array, times: np.array, filename: str): + fig = plt.figure() + ax = plt.subplot() + p = plt.plot(times, score) + #plt.ylabel('CRPS') + #plt.xlabel('time') + plt.xticks([times[0],times[-1]]) + plt.savefig(filename) + plt.close('all') + +def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str): + fig = plt.figure() + ax = plt.subplot() + colors = {'red', 'green', 'blue', 'orange'} # TODO, add more + for k in scores.keys(): + p = ax.plot(times, scores[k], color=colors[k]) + p.set_label(k) + ax.legend() + #plt.ylabel('CRPS') + #plt.xlabel('time') + ax.set_xticks([times[0],times[-1]]) + plt.savefig(filename) + plt.close('all') + +def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): + fig = plt.figure() + for k in freqs.keys(): + plt.loglog(freqs[k], spec[k], label=k) + plt.title(channel_name) + plt.legend() + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + #plt.psd(x) + logging.info(f'plotting values to {filename}') + plt.savefig(filename) + plt.close('all') + def compute_crps_over_time(times, dataset, output_path): logging.info('computing crps') longitudes = dataset.longitude() @@ -71,39 +109,6 @@ def compute_crps_over_time(times, dataset, output_path): torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) -def plot_crps_over_time_and_area(times, dataset, output_path): - logging.info('plotting crps') - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - start_time=times[0] - end_time=times[-1] - - maxes = torch.load(os.path.join(output_path, f'crps-maxes-{start_time}-{end_time}'), weights_only=False) - mins = torch.load(os.path.join(output_path, f'crps-mins-{start_time}-{end_time}'), weights_only=False) - crps_time = torch.load(os.path.join(output_path, f'crps-time-{start_time}-{end_time}'), weights_only=False) - crps_area = torch.load(os.path.join(output_path, f'crps-area-{start_time}-{end_time}'), weights_only=False) - interpolation_time = torch.load(os.path.join(output_path, f'interpolation-time-{start_time}-{end_time}'), weights_only=False) - persistence_time = torch.load(os.path.join(output_path, f'persistence-time-{start_time}-{end_time}'), weights_only=False) - - - for j in range(crps_area.shape[0]): - plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - plot_error_projection(interpolation_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - _plot_score_vs_t(crps_time[j,::], times, os.path.join(output_path, f'NEW-crps-time-{start_time}-{end_time}-{output_channels[j].name}.jpg')) - - for j in range(len(output_channels)): - for i in range(len(times)): - plot_error_projection(crps_area[j,::], latitudes, longitudes, - os.path.join(output_path, 'animations', output_channels[j].name, f'{times[i]}.jpg'), - label=output_channels[j].name) - - def compute_crps_over_time_and_area(times, dataset, output_path): logging.info('computing crps and errors') longitudes = dataset.longitude() @@ -157,7 +162,8 @@ def compute_crps_over_time_and_area(times, dataset, output_path): interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) total_interpolation_area = total_interpolation_area + interpolation_area persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-persistence-error'), weights_only=False) - total_persistence_area = total_persistence_area + persistence_area + if i>0: + total_persistence_area = total_persistence_area + persistence_area for j in range(crps_area.shape[0]): crps_over_time[j,i] = np.mean(crps_area[j,::]) @@ -167,7 +173,7 @@ def compute_crps_over_time_and_area(times, dataset, output_path): mean_crps_area = total_crps_area / len(times) mean_ensemble_mean_area = total_ensemble_mean_area / len(times) mean_interpolation_area = total_interpolation_area / len(times) - mean_persistence_area = total_persistence_area / len(times) + mean_persistence_area = total_persistence_area / (len(times)-1) torch.save(mean_crps_area, os.path.join(output_path, f'crps-area-{times[0]}-{times[len(times)-1]}')) torch.save(mean_ensemble_mean_area, os.path.join(output_path, f'mae-ensemble-mean-area-{times[0]}-{times[len(times)-1]}')) torch.save(mean_interpolation_area, os.path.join(output_path, f'mae-interpolation-area-{times[0]}-{times[len(times)-1]}')) @@ -179,29 +185,42 @@ def compute_crps_over_time_and_area(times, dataset, output_path): torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) -def _plot_score_vs_t(score: np.array, times: np.array, filename: str): - fig = plt.figure() - ax = plt.subplot() - p = plt.plot(times, score) - #plt.ylabel('CRPS') - #plt.xlabel('time') - plt.xticks([times[0],times[-1]]) - plt.savefig(filename) - plt.close('all') +def plot_crps_over_time_and_area(times, dataset, output_path): + logging.info('plotting crps and errors') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + maxes = torch.load(os.path.join(output_path, f'crps-maxes-{start_time}-{end_time}'), weights_only=False) + mins = torch.load(os.path.join(output_path, f'crps-mins-{start_time}-{end_time}'), weights_only=False) + crps_time = torch.load(os.path.join(output_path, f'crps-time-{start_time}-{end_time}'), weights_only=False) + crps_area = torch.load(os.path.join(output_path, f'crps-area-{start_time}-{end_time}'), weights_only=False) + ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) + ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) + interpolation_time = torch.load(os.path.join(output_path, f'mae-interpolation-time-{start_time}-{end_time}'), weights_only=False) + interpolation_area = torch.load(os.path.join(output_path, f'mae-interpolation-area-{start_time}-{end_time}'), weights_only=False) + persistence_time = torch.load(os.path.join(output_path, f'mae-persistence-time-{start_time}-{end_time}'), weights_only=False) + persistence_area = torch.load(os.path.join(output_path, f'mae-persistence-area-{start_time}-{end_time}'), weights_only=False) + for j in range(crps_area.shape[0]): + plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + plot_error_projection(interpolation_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + plot_error_projection(persistence_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name) + + _plot_score_vs_t(crps_time[j,::], times, os.path.join(output_path, f'NEW-crps-time-{start_time}-{end_time}-{output_channels[j].name}.jpg')) -def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): - fig = plt.figure() - for k in freqs.keys(): - plt.loglog(freqs[k], spec[k], label=k) - plt.title(channel_name) - plt.legend() - plt.xlabel("Frequency (1/km)") - plt.ylabel("Power Spectrum") - plt.ylim(bottom=1e-1) - #plt.psd(x) - logging.info(f'plotting values to {filename}') - plt.savefig(filename) - plt.close('all') \ No newline at end of file + maes = {} + maes['ensemble mean'] = ensemble_mean_time[j,::] + maes['interpolation'] = interpolation_time[j,::] + maes['persistence'] = persistence_time[j,::] + plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-mae-time-{start_time}-{end_time}-{output_channels[j].name}.jpg')) + From ebdec81e42785e7f53f255c2ba54caa0e4d5c692 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 13:12:23 +0200 Subject: [PATCH 078/189] Do some extra plotting tests to ensure that CRPS reduces to MAE --- src/hirad/eval/plotting.py | 121 +++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 53 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 0924e3d..56cc69f 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -20,27 +20,22 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.savefig(filename) plt.close('all') -def _plot_score_vs_t(score: np.array, times: np.array, filename: str): +def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str, xlabel='', ylabel='', title=''): fig = plt.figure() ax = plt.subplot() - p = plt.plot(times, score) - #plt.ylabel('CRPS') - #plt.xlabel('time') - plt.xticks([times[0],times[-1]]) - plt.savefig(filename) - plt.close('all') - -def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str): - fig = plt.figure() - ax = plt.subplot() - colors = {'red', 'green', 'blue', 'orange'} # TODO, add more + colors = ['red', 'green', 'blue', 'orange'] # TODO, add more + i=0 for k in scores.keys(): - p = ax.plot(times, scores[k], color=colors[k]) + p, = ax.plot(times, scores[k], color=colors[i]) + i=i+1 p.set_label(k) ax.legend() #plt.ylabel('CRPS') #plt.xlabel('time') ax.set_xticks([times[0],times[-1]]) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) plt.savefig(filename) plt.close('all') @@ -72,39 +67,55 @@ def compute_crps_over_time(times, dataset, output_path): #all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) #all_targets = np.ndarray((len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) + output_to_input_channel_map = {} + for j in range(len(output_channels)): + index = -1 + for k in range(len(input_channels)): + if input_channels[k].name == output_channels[j].name: + logging.info(f'found index of {j}:{output_channels[j].name} at {k}') + index = k + output_to_input_channel_map[j] = index + + for i in range(len(times)): - logging.info(f'computing crps for {times[i]}') prediction_ensemble = torch.load(os.path.join(output_path, times[i], f'{times[i]}-predictions'), weights_only=False) baseline = torch.load(os.path.join(output_path, times[i], f'{times[i]}-baseline'), weights_only=False) target = torch.load(os.path.join(output_path, times[i], f'{times[i]}-target'), weights_only=False) - # Calculate CRPS - crps_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) - # Calculate ensemble mean error ensemble_mean = np.mean(prediction_ensemble, 0) ensemble_mean_error = absolute_error(ensemble_mean, target) # Calculate interpolation error (baseline #1) - interpolation_error = np.zeros(baseline.shape) + interpolation_error = np.zeros(target.shape) for j in range(len(output_channels)): - index = -1 - for k in range(len(input_channels)): - if input_channels[k].name == output_channels[j].name: - index = k + k = output_to_input_channel_map[j] if k > -1: - interpolation_error[j,::] = baseline[k,::] - target[j,::] + interpolation_error[j,::] = absolute_error(baseline[k,::], target[j,::]) # Calculate persistence error (baseline #2) persistence_error = np.zeros(baseline.shape) + prev=[] if i > 0: prev = torch.load(os.path.join(output_path, times[i-1], f'{times[i-1]}-target'), weights_only=False) persistence_error = absolute_error(prev, target) else: + # persist the next-time-point target-- this is fiction but it keeps the plots from looking weird. + prev = torch.load(os.path.join(output_path, times[i+1], f'{times[i+1]}-target'), weights_only=False) persistence_error = absolute_error(target, target) + + # Calculate CRPS + crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) + crps_interpolation_area = crps(np.broadcast_to(baseline, np.insert(baseline.shape, 0, 8)), target, average_over_area=False, average_over_channels=False) + crps_ensemble_mean_area = crps(np.broadcast_to(ensemble_mean, np.insert(ensemble_mean.shape, 0, 8)), target, average_over_area=False, average_over_channels=False) + crps_persistence_area = crps(np.broadcast_to(prev, np.insert(prev.shape, 0, 8)), target, average_over_area=False, average_over_channels=False) - torch.save(crps_area, os.path.join(output_path, times[i], f'{times[i]}-crps')) + + torch.save(crps_diffusion_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble')) + torch.save(crps_interpolation_area, os.path.join(output_path, times[i], f'{times[i]}-crps-interpolation')) + torch.save(crps_ensemble_mean_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble-mean')) + torch.save(crps_persistence_area, os.path.join(output_path, times[i], f'{times[i]}-crps-persistence')) torch.save(ensemble_mean_error, os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error')) torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) @@ -121,26 +132,6 @@ def compute_crps_over_time_and_area(times, dataset, output_path): logging.info('calculating min/max') crps_area = torch.load(os.path.join(output_path, times[0], f'{times[0]}-crps'), weights_only=False) - mins = np.ones((crps_area.shape[0])) * 999999999 - maxes = np.zeros((crps_area.shape[0])) - - for i in range(len(times)): - if i % (24*5) == 0: - logging.info(f'on time {times[i]}') - # Shape = channels, x, y - crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps'), weights_only=False) - for j in range(crps_area.shape[0]): - max = np.max(crps_area[j,::]) - min = np.min(crps_area[j,::]) - if max > maxes[j]: - maxes[j] = max - if min < mins[j]: - mins[j] = min - logging.info(f'maxes are {maxes}') - logging.info(f'mins are {mins}') - - torch.save(maxes, os.path.join(output_path, f'crps-maxes-{times[0]}-{times[len(times)-1]}')) - torch.save(mins, os.path.join(output_path, f'crps-mins-{times[0]}-{times[len(times)-1]}')) # make area and time plot total_crps_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) @@ -149,14 +140,22 @@ def compute_crps_over_time_and_area(times, dataset, output_path): total_persistence_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) crps_over_time = np.zeros((crps_area.shape[0], len(times))) + crps_ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) + crps_persistence_over_time = np.zeros((crps_area.shape[0], len(times))) + crps_interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) persistence_over_time = np.zeros((crps_area.shape[0], len(times))) for i in range(len(times)): if i % (24*5) == 0: logging.info(f'on time {times[i]}') - crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps'), weights_only=False) + crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble'), weights_only=False) total_crps_area = total_crps_area + crps_area + + crps_ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble-mean'), weights_only=False) + crps_interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-interpolation'), weights_only=False) + crps_persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-persistence'), weights_only=False) + ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error'), weights_only=False) total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) @@ -167,6 +166,9 @@ def compute_crps_over_time_and_area(times, dataset, output_path): for j in range(crps_area.shape[0]): crps_over_time[j,i] = np.mean(crps_area[j,::]) + crps_ensemble_mean_over_time[j,i] = np.mean(crps_ensemble_mean_area[j,::]) + crps_interpolation_over_time[j,i] = np.mean(crps_interpolation_area[j,::]) + crps_persistence_over_time[j,i] = np.mean(crps_persistence_area[j,::]) ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) persistence_over_time[j,i] = np.mean(persistence_area[j,::]) @@ -174,13 +176,19 @@ def compute_crps_over_time_and_area(times, dataset, output_path): mean_ensemble_mean_area = total_ensemble_mean_area / len(times) mean_interpolation_area = total_interpolation_area / len(times) mean_persistence_area = total_persistence_area / (len(times)-1) - torch.save(mean_crps_area, os.path.join(output_path, f'crps-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_crps_area, os.path.join(output_path, f'crps-ensemble-area-{times[0]}-{times[len(times)-1]}')) torch.save(mean_ensemble_mean_area, os.path.join(output_path, f'mae-ensemble-mean-area-{times[0]}-{times[len(times)-1]}')) torch.save(mean_interpolation_area, os.path.join(output_path, f'mae-interpolation-area-{times[0]}-{times[len(times)-1]}')) torch.save(mean_persistence_area, os.path.join(output_path, f'mae-persistence-area-{times[0]}-{times[len(times)-1]}')) + # Little hack to make the plots look nicer, without having to change dimensions. + persistence_over_time[:,0] = persistence_over_time[:,1] + + torch.save(crps_over_time, os.path.join(output_path, f'crps-ensemble-time-{times[0]}-{times[len(times)-1]}')) + torch.save(crps_ensemble_mean_over_time, os.path.join(output_path, f'crps-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) + torch.save(crps_interpolation_over_time, os.path.join(output_path, f'crps-interpolation-time-{times[0]}-{times[len(times)-1]}')) + torch.save(crps_persistence_over_time, os.path.join(output_path, f'crps-persistence-time-{times[0]}-{times[len(times)-1]}')) - torch.save(crps_over_time, os.path.join(output_path, f'crps-time-{times[0]}-{times[len(times)-1]}')) torch.save(ensemble_mean_over_time, os.path.join(output_path, f'mae-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) @@ -194,9 +202,10 @@ def plot_crps_over_time_and_area(times, dataset, output_path): start_time=times[0] end_time=times[-1] - maxes = torch.load(os.path.join(output_path, f'crps-maxes-{start_time}-{end_time}'), weights_only=False) - mins = torch.load(os.path.join(output_path, f'crps-mins-{start_time}-{end_time}'), weights_only=False) - crps_time = torch.load(os.path.join(output_path, f'crps-time-{start_time}-{end_time}'), weights_only=False) + crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-time-{start_time}-{end_time}'), weights_only=False) + crps_ensemble_mean_time = torch.load(os.path.join(output_path, f'crps-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) + crps_interpolation_time = torch.load(os.path.join(output_path, f'crps-interpolation-time-{start_time}-{end_time}'), weights_only=False) + crps_persistence_time = torch.load(os.path.join(output_path, f'crps-persistence-time-{start_time}-{end_time}'), weights_only=False) crps_area = torch.load(os.path.join(output_path, f'crps-area-{start_time}-{end_time}'), weights_only=False) ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) @@ -216,11 +225,17 @@ def plot_crps_over_time_and_area(times, dataset, output_path): plot_error_projection(persistence_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name) - _plot_score_vs_t(crps_time[j,::], times, os.path.join(output_path, f'NEW-crps-time-{start_time}-{end_time}-{output_channels[j].name}.jpg')) + crps_scores = {} + crps_scores['ensemble predictions'] = crps_ensemble_time[j,::] + crps_scores['ensemble mean'] = crps_ensemble_mean_time[j,::] + crps_scores['interpolation'] = crps_interpolation_time[j,::] + crps_scores['persistence'] = crps_persistence_time[j,::] + plot_scores_vs_t(crps_scores, times, os.path.join(output_path, f'NEW-crps-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'CRPS: {output_channels[j].name}', xlabel='time', ylabel='CRPS') maes = {} - maes['ensemble mean'] = ensemble_mean_time[j,::] maes['interpolation'] = interpolation_time[j,::] + maes['ensemble mean'] = ensemble_mean_time[j,::] + maes['crps'] = crps_ensemble_time[j,:] maes['persistence'] = persistence_time[j,::] - plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-mae-time-{start_time}-{end_time}-{output_channels[j].name}.jpg')) + plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-mae-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') From 02e8788665ae35ce46efaff8bf0d42832209f4c7 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 13:15:11 +0200 Subject: [PATCH 079/189] Remove superfluous plotting, now that we know CRPS reduvces to MAE --- src/hirad/eval/plotting.py | 35 ++--------------------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 56cc69f..cbb0eb0 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -107,15 +107,8 @@ def compute_crps_over_time(times, dataset, output_path): # Calculate CRPS crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) - crps_interpolation_area = crps(np.broadcast_to(baseline, np.insert(baseline.shape, 0, 8)), target, average_over_area=False, average_over_channels=False) - crps_ensemble_mean_area = crps(np.broadcast_to(ensemble_mean, np.insert(ensemble_mean.shape, 0, 8)), target, average_over_area=False, average_over_channels=False) - crps_persistence_area = crps(np.broadcast_to(prev, np.insert(prev.shape, 0, 8)), target, average_over_area=False, average_over_channels=False) - torch.save(crps_diffusion_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble')) - torch.save(crps_interpolation_area, os.path.join(output_path, times[i], f'{times[i]}-crps-interpolation')) - torch.save(crps_ensemble_mean_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble-mean')) - torch.save(crps_persistence_area, os.path.join(output_path, times[i], f'{times[i]}-crps-persistence')) torch.save(ensemble_mean_error, os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error')) torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) @@ -140,9 +133,6 @@ def compute_crps_over_time_and_area(times, dataset, output_path): total_persistence_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) crps_over_time = np.zeros((crps_area.shape[0], len(times))) - crps_ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) - crps_persistence_over_time = np.zeros((crps_area.shape[0], len(times))) - crps_interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) persistence_over_time = np.zeros((crps_area.shape[0], len(times))) @@ -152,10 +142,6 @@ def compute_crps_over_time_and_area(times, dataset, output_path): crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble'), weights_only=False) total_crps_area = total_crps_area + crps_area - crps_ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble-mean'), weights_only=False) - crps_interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-interpolation'), weights_only=False) - crps_persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-persistence'), weights_only=False) - ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error'), weights_only=False) total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) @@ -166,9 +152,6 @@ def compute_crps_over_time_and_area(times, dataset, output_path): for j in range(crps_area.shape[0]): crps_over_time[j,i] = np.mean(crps_area[j,::]) - crps_ensemble_mean_over_time[j,i] = np.mean(crps_ensemble_mean_area[j,::]) - crps_interpolation_over_time[j,i] = np.mean(crps_interpolation_area[j,::]) - crps_persistence_over_time[j,i] = np.mean(crps_persistence_area[j,::]) ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) persistence_over_time[j,i] = np.mean(persistence_area[j,::]) @@ -185,10 +168,6 @@ def compute_crps_over_time_and_area(times, dataset, output_path): persistence_over_time[:,0] = persistence_over_time[:,1] torch.save(crps_over_time, os.path.join(output_path, f'crps-ensemble-time-{times[0]}-{times[len(times)-1]}')) - torch.save(crps_ensemble_mean_over_time, os.path.join(output_path, f'crps-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) - torch.save(crps_interpolation_over_time, os.path.join(output_path, f'crps-interpolation-time-{times[0]}-{times[len(times)-1]}')) - torch.save(crps_persistence_over_time, os.path.join(output_path, f'crps-persistence-time-{times[0]}-{times[len(times)-1]}')) - torch.save(ensemble_mean_over_time, os.path.join(output_path, f'mae-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) @@ -202,10 +181,7 @@ def plot_crps_over_time_and_area(times, dataset, output_path): start_time=times[0] end_time=times[-1] - crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-time-{start_time}-{end_time}'), weights_only=False) - crps_ensemble_mean_time = torch.load(os.path.join(output_path, f'crps-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) - crps_interpolation_time = torch.load(os.path.join(output_path, f'crps-interpolation-time-{start_time}-{end_time}'), weights_only=False) - crps_persistence_time = torch.load(os.path.join(output_path, f'crps-persistence-time-{start_time}-{end_time}'), weights_only=False) + crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-ensemble-time-{start_time}-{end_time}'), weights_only=False) crps_area = torch.load(os.path.join(output_path, f'crps-area-{start_time}-{end_time}'), weights_only=False) ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) @@ -225,17 +201,10 @@ def plot_crps_over_time_and_area(times, dataset, output_path): plot_error_projection(persistence_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name) - crps_scores = {} - crps_scores['ensemble predictions'] = crps_ensemble_time[j,::] - crps_scores['ensemble mean'] = crps_ensemble_mean_time[j,::] - crps_scores['interpolation'] = crps_interpolation_time[j,::] - crps_scores['persistence'] = crps_persistence_time[j,::] - plot_scores_vs_t(crps_scores, times, os.path.join(output_path, f'NEW-crps-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'CRPS: {output_channels[j].name}', xlabel='time', ylabel='CRPS') - maes = {} maes['interpolation'] = interpolation_time[j,::] maes['ensemble mean'] = ensemble_mean_time[j,::] maes['crps'] = crps_ensemble_time[j,:] maes['persistence'] = persistence_time[j,::] - plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-mae-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') + plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') From 8401a248560171be472bb1517ad3718a8d9933af Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 13:23:29 +0200 Subject: [PATCH 080/189] fix filename --- src/hirad/eval/plotting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index cbb0eb0..2fbdab0 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -72,7 +72,6 @@ def compute_crps_over_time(times, dataset, output_path): index = -1 for k in range(len(input_channels)): if input_channels[k].name == output_channels[j].name: - logging.info(f'found index of {j}:{output_channels[j].name} at {k}') index = k output_to_input_channel_map[j] = index @@ -182,7 +181,7 @@ def plot_crps_over_time_and_area(times, dataset, output_path): end_time=times[-1] crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-ensemble-time-{start_time}-{end_time}'), weights_only=False) - crps_area = torch.load(os.path.join(output_path, f'crps-area-{start_time}-{end_time}'), weights_only=False) + crps_area = torch.load(os.path.join(output_path, f'crps-ensemble-area-{start_time}-{end_time}'), weights_only=False) ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) interpolation_time = torch.load(os.path.join(output_path, f'mae-interpolation-time-{start_time}-{end_time}'), weights_only=False) From e0f9de878889a86193063f3006ecfb0a1fde5c80 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 15:53:17 +0200 Subject: [PATCH 081/189] Move script methods into script --- src/hirad/eval/__init__.py | 2 +- src/hirad/eval/compute_eval.py | 169 +++++++++++++++- src/hirad/eval/crps.py | 348 --------------------------------- src/hirad/eval/plotting.py | 159 +-------------- 4 files changed, 169 insertions(+), 509 deletions(-) delete mode 100644 src/hirad/eval/crps.py diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py index db34154..82e9073 100644 --- a/src/hirad/eval/__init__.py +++ b/src/hirad/eval/__init__.py @@ -1,2 +1,2 @@ from .metrics import absolute_error, compute_mae, average_power_spectrum, crps -from .plotting import plot_error_projection, plot_power_spectra, compute_crps_over_time, compute_crps_over_time_and_area, plot_crps_over_time_and_area +from .plotting import plot_error_projection, plot_power_spectra, plot_scores_vs_t \ No newline at end of file diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py index e611373..305f640 100644 --- a/src/hirad/eval/compute_eval.py +++ b/src/hirad/eval/compute_eval.py @@ -1,4 +1,5 @@ import hydra +import logging import os import json from omegaconf import OmegaConf, DictConfig @@ -11,7 +12,7 @@ from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor -from hirad.eval import compute_crps_over_time, plot_crps_over_time_and_area, compute_crps_over_time_and_area +from hirad.eval import absolute_error, crps, plot_scores_vs_t from hirad.models import EDMPrecondSuperResolution, UNet from hirad.inference import Generator from hirad.utils.inference_utils import save_images, save_results_as_torch @@ -52,10 +53,174 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) output_path = getattr(cfg.generation.io, "output_path", "./outputs") - compute_crps_over_time(times, dataset, output_path) + compute_crps_per_time(times, dataset, output_path) compute_crps_over_time_and_area(times, dataset, output_path) plot_crps_over_time_and_area(times, dataset, output_path) + +def _get_data_path(output_path, time=None, filename=None): + return os.path.join(output_path, time, filename) + +def load_data(output_path, time=None, filename=None): + return torch.load(_get_data_path(output_path, time, filename), weights_only=False) + +def save_data(data, output_path, time=None, filename=None): + path = _get_data_path(output_path, time, filename) + torch.save(data, path) +def compute_crps_per_time(times, dataset, output_path): + logging.info('Computing CRPS for each time point') + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + # Load one prediction ensemble to get the shape + prediction_ensemble = torch.load(os.path.join(output_path, start_time, f'{start_time}-predictions'), weights_only=False) + + # Get a map of output to input channel, for building baseline errors + output_to_input_channel_map = {} + for j in range(len(output_channels)): + index = -1 + for k in range(len(input_channels)): + if input_channels[k].name == output_channels[j].name: + index = k + output_to_input_channel_map[j] = index + + + for i in range(len(times)): + curr_time = times[i] + if i % (24*5) == 0: + logging.info(f'on time {curr_time}') + prediction_ensemble = torch.load(os.path.join(output_path, curr_time, f'{curr_time}-predictions'), weights_only=False) + baseline = torch.load(os.path.join(output_path, curr_time, f'{curr_time}-baseline'), weights_only=False) + target = torch.load(os.path.join(output_path, curr_time, f'{curr_time}-target'), weights_only=False) + + # Calculate ensemble mean error + ensemble_mean = np.mean(prediction_ensemble, 0) + ensemble_mean_error = absolute_error(ensemble_mean, target) + + # Calculate interpolation error (baseline #1) + interpolation_error = np.zeros(target.shape) + for j in range(len(output_channels)): + k = output_to_input_channel_map[j] + if k > -1: + interpolation_error[j,::] = absolute_error(baseline[k,::], target[j,::]) + + # Calculate persistence error (baseline #2) + persistence_error = np.zeros(baseline.shape) + if i > 0: + prev = torch.load(os.path.join(output_path, times[i-1], f'{times[i-1]}-target'), weights_only=False) + persistence_error = absolute_error(prev, target) + else: + # for the first time point, persist the next-time-point target. + # This is fiction but it keeps the plots from looking weird. + prev = torch.load(os.path.join(output_path, times[i+1], f'{times[i+1]}-target'), weights_only=False) + persistence_error = absolute_error(prev, target) + + + # Calculate CRPS + crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) + + torch.save(crps_diffusion_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble')) + torch.save(ensemble_mean_error, os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error')) + torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) + torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) + +def compute_crps_over_time_and_area(times, dataset, output_path): + logging.info('computing crps and errors') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + logging.info('calculating min/max') + + crps_area = torch.load(os.path.join(output_path, times[0], f'{times[0]}-crps'), weights_only=False) + + # make area and time plot + total_crps_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + total_ensemble_mean_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + total_interpolation_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + total_persistence_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) + + crps_over_time = np.zeros((crps_area.shape[0], len(times))) + ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) + interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) + persistence_over_time = np.zeros((crps_area.shape[0], len(times))) + for i in range(len(times)): + if i % (24*5) == 0: + logging.info(f'on time {times[i]}') + crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble'), weights_only=False) + total_crps_area = total_crps_area + crps_area + + ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error'), weights_only=False) + total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area + interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) + total_interpolation_area = total_interpolation_area + interpolation_area + persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-persistence-error'), weights_only=False) + if i>0: + total_persistence_area = total_persistence_area + persistence_area + + for j in range(crps_area.shape[0]): + crps_over_time[j,i] = np.mean(crps_area[j,::]) + ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) + interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) + persistence_over_time[j,i] = np.mean(persistence_area[j,::]) + mean_crps_area = total_crps_area / len(times) + mean_ensemble_mean_area = total_ensemble_mean_area / len(times) + mean_interpolation_area = total_interpolation_area / len(times) + mean_persistence_area = total_persistence_area / (len(times)-1) + torch.save(mean_crps_area, os.path.join(output_path, f'crps-ensemble-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_ensemble_mean_area, os.path.join(output_path, f'mae-ensemble-mean-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_interpolation_area, os.path.join(output_path, f'mae-interpolation-area-{times[0]}-{times[len(times)-1]}')) + torch.save(mean_persistence_area, os.path.join(output_path, f'mae-persistence-area-{times[0]}-{times[len(times)-1]}')) + + # Little hack to make the plots look nicer, without having to change dimensions. + persistence_over_time[:,0] = persistence_over_time[:,1] + + torch.save(crps_over_time, os.path.join(output_path, f'crps-ensemble-time-{times[0]}-{times[len(times)-1]}')) + torch.save(ensemble_mean_over_time, os.path.join(output_path, f'mae-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) + torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) + torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) + +def plot_crps_over_time_and_area(times, dataset, output_path): + logging.info('plotting crps and errors') + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + start_time=times[0] + end_time=times[-1] + + crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-ensemble-time-{start_time}-{end_time}'), weights_only=False) + crps_area = torch.load(os.path.join(output_path, f'crps-ensemble-area-{start_time}-{end_time}'), weights_only=False) + ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) + ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) + interpolation_time = torch.load(os.path.join(output_path, f'mae-interpolation-time-{start_time}-{end_time}'), weights_only=False) + interpolation_area = torch.load(os.path.join(output_path, f'mae-interpolation-area-{start_time}-{end_time}'), weights_only=False) + persistence_time = torch.load(os.path.join(output_path, f'mae-persistence-time-{start_time}-{end_time}'), weights_only=False) + persistence_area = torch.load(os.path.join(output_path, f'mae-persistence-area-{start_time}-{end_time}'), weights_only=False) + + + for j in range(crps_area.shape[0]): + plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: CRPS: {output_channels[j].name}') + plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: Ensemble mean: {output_channels[j].name}') + plot_error_projection(interpolation_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: Interpolation: {output_channels[j].name}') + plot_error_projection(persistence_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + label=output_channels[j].name, title=f'Mean absolute error: Persistence: {output_channels[j].name}') + + maes = {} + maes['interpolation'] = interpolation_time[j,::] + maes['ensemble mean'] = ensemble_mean_time[j,::] + maes['crps'] = crps_ensemble_time[j,:] + maes['persistence'] = persistence_time[j,::] + plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') + diff --git a/src/hirad/eval/crps.py b/src/hirad/eval/crps.py deleted file mode 100644 index 2ccdb42..0000000 --- a/src/hirad/eval/crps.py +++ /dev/null @@ -1,348 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union - -import numpy as np -import torch - -from .histogram import cdf as cdf_function - -Tensor = torch.Tensor - - -@torch.jit.script -def _kernel_crps_implementation(pred: Tensor, obs: Tensor, biased: bool) -> Tensor: - """An O(m log m) implementation of the kernel CRPS formulas""" - skill = torch.abs(pred - obs[..., None]).mean(-1) - pred, _ = torch.sort(pred) - - # derivation of fast implementation of spread-portion of CRPS formula when x is sorted - # sum_(i,j=1)^m |x_i - x_j| = sum_(i j) |x_i - x_j| - # = 2 sum_(i <= j) |x_i -x_j| - # = 2 sum_(i <= j) (x_j - x_i) - # = 2 sum_(i <= j) x_j - 2 sum_(i <= j) x_i - # = 2 sum_(j=1)^m j x_j - 2 sum (m - i + 1) x_i - # = 2 sum_(i=1)^m (2i - m - 1) x_i - m = pred.size(-1) - i = torch.arange(1, m + 1, device=pred.device, dtype=pred.dtype) - denom = m * m if biased else m * (m - 1) - factor = (2 * i - m - 1) / denom - spread = torch.sum(factor * pred, dim=-1) - return skill - spread - - -def kcrps(pred: Tensor, obs: Tensor, dim: int = 0, biased: bool = True): - """Estimate the CRPS from a finite ensemble - - Computes the local Continuous Ranked Probability Score (CRPS) by using - the kernel version of CRPS. The cost is O(m log m). - - Creates a map of CRPS and does not accumulate over lat/lon regions. - Approximates: - .. math:: - CRPS(X, y) = E[X - y] - 0.5 E[X-X'] - - with - .. math:: - sum_i=1^m |X_i - y| / m - 1/(2m^2) sum_i,j=1^m |x_i - x_j| - - Parameters - ---------- - pred : Tensor - Tensor containing the ensemble predictions. The ensemble dimension - is assumed to be the leading dimension unless 'dim' is specified. - obs : Union[Tensor, np.ndarray] - Tensor or array containing an observation over which the CRPS is computed - with respect to. - dim : int, optional - The dimension over which to compute the CRPS, assumed to be 0. - biased : - When False, uses the unbiased estimators described in (Zamo and Naveau, 2018):: - - E|X-y|/m - 1/(2m(m-1)) sum_(i,j=1)|x_i - x_j| - - Unlike ``crps`` this is fair for finite ensembles. Non-fair ``crps`` favors less - dispersive ensembles since it is biased high by E|X- X'|/ m where m is the - ensemble size. - - Returns - ------- - Tensor - Map of CRPS - """ - pred = torch.movedim(pred, dim, -1) - return _kernel_crps_implementation(pred, obs, biased=biased) - - -def _crps_gaussian(mean: Tensor, std: Tensor, obs: Union[Tensor, np.ndarray]) -> Tensor: - """ - Computes the local Continuous Ranked Probability Score (CRPS) - using assuming that the forecast distribution is normal. - - Creates a map of CRPS and does not accumulate over lat/lon regions. - - Computes: - - .. math:: - - CRPS(mean, std, y) = std * [ \\frac{1}{\\sqrt{\\pi}}} - 2 \\phi ( \\frac{x-mean}{std} ) - - ( \\frac{x-mean}{std} ) * (2 \\Phi(\\frac{x-mean}{std}) - 1) ] - - where \\phi and \\Phi are the normal gaussian pdf/cdf respectively. - - Parameters - ---------- - mean : Tensor - Tensor of mean of forecast distribution. - std : Tensor - Tensor of standard deviation of forecast distribution. - obs : Union[Tensor, np.ndarray] - Tensor or array containing an observation over which the CRPS is computed - with respect to. Broadcasting dimensions must be compatible with the non-zeroth - dimensions of bins and cdf. - - Returns - ------- - Tensor - Map of CRPS - """ - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).to(mean.device) - # Check shape compatibility - if mean.shape != std.shape: - raise ValueError( - "Mean and standard deviation must have" - + "compatible shapes but found" - + str(mean.shape) - + " and " - + str(std.shape) - + "." - ) - if mean.shape != obs.shape: - raise ValueError( - "Mean and obs must have" - + "compatible shapes but found" - + str(mean.shape) - + " and " - + str(obs.shape) - + "." - ) - - d = (obs - mean) / std - phi = torch.exp(-0.5 * d**2) / torch.sqrt(torch.as_tensor(2 * torch.pi)) - - # Note, simplified expression below is not exactly Gaussian CDF - Phi = torch.erf(d / torch.sqrt(torch.as_tensor(2.0))) - - return std * (2 * phi + d * Phi - 1.0 / torch.sqrt(torch.as_tensor(torch.pi))) - - -def _crps_from_cdf( - bin_edges: Tensor, cdf: Tensor, obs: Union[Tensor, np.ndarray] -) -> Tensor: - """Computes the local Continuous Ranked Probability Score (CRPS) - using a cumulative distribution function. - - Creates a map of CRPS and does not accumulate over lat/lon regions. - - Computes: - - .. math:: - - CRPS(X, y) = \\int[ (F(x) - 1[x - y])^2 ] dx - - where F is the empirical cdf of X. - - Parameters - ---------- - bins_edges : Tensor - Tensor [N+1, ...] containing bin edges. The leading dimension must represent the - N+1 bin edges. - cdf : Tensor - Tensor [N, ...] containing a cdf, defined over bins. The non-zeroth dimensions - of bins and cdf must be compatible. - obs : Union[Tensor, np.ndarray] - Tensor or array containing an observation over which the CRPS is computed - with respect to. Broadcasting dimensions must be compatible with the non-zeroth - dimensions of bins and cdf. - - Returns - ------- - Tensor - Map of CRPS - """ - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).to(cdf.device) - if bin_edges.shape[1:] != cdf.shape[1:]: - raise ValueError( - "Expected bins and cdf to have compatible non-zeroth dimensions but have shapes" - + str(bin_edges.shape[1:]) - + " and " - + str(cdf.shape[1:]) - + "." - ) - if bin_edges.shape[1:] != obs.shape: - raise ValueError( - "Expected bins and observations to have compatible broadcasting dimensions but have shapes" - + str(bin_edges.shape[1:]) - + " and " - + str(obs.shape) - + "." - ) - if bin_edges.shape[0] != cdf.shape[0] + 1: - raise ValueError( - "Expected zeroth dimension of cdf to be equal to the zeroth dimension of bins + 1 but have shapes" - + str(bin_edges.shape[0]) - + " and " - + str(cdf.shape[0]) - + "+1." - ) - dbins = bin_edges[1, ...] - bin_edges[0, ...] - bin_mids = 0.5 * (bin_edges[1:] + bin_edges[:-1]) - obs = torch.ge(bin_mids, obs).int() - return torch.sum(torch.abs(cdf - obs) ** 2 * dbins, dim=0) - - -def _crps_from_counts( - bin_edges: Tensor, counts: Tensor, obs: Union[Tensor, np.ndarray] -) -> Tensor: - """Computes the local Continuous Ranked Probability Score (CRPS) - using a histogram of counts. - - Creates a map of CRPS and does not accumulate over lat/lon regions. - - Computes: - - .. math:: - - CRPS(X, y) = int[ (F(x) - 1[x - y])^2 ] dx - - where F is the empirical cdf of X. - - Parameters - ---------- - bins_edges : Tensor - Tensor [N+1, ...] containing bin edges. The leading dimension must represent the - N+1 bin edges. - counts : Tensor - Tensor [N, ...] containing counts, defined over bins. The non-zeroth dimensions - of bins and counts must be compatible. - obs : Union[Tensor, np.ndarray] - Tensor or array containing an observation over which the CRPS is computed - with respect to. Broadcasting dimensions must be compatible with the non-zeroth - dimensions of bins and counts. - - Returns - ------- - Tensor - Map of CRPS - """ - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).to(counts.device) - if bin_edges.shape[1:] != counts.shape[1:]: - raise ValueError( - "Expected bins and cdf to have compatible non-zeroth dimensions but have shapes" - + str(bin_edges.shape[1:]) - + " and " - + str(counts.shape[1:]) - + "." - ) - if bin_edges.shape[1:] != obs.shape: - raise ValueError( - "Expected bins and observations to have compatible broadcasting dimensions but have shapes" - + str(bin_edges.shape[1:]) - + " and " - + str(obs.shape) - + "." - ) - if bin_edges.shape[0] != counts.shape[0] + 1: - raise ValueError( - "Expected zeroth dimension of cdf to be equal to the zeroth dimension of bins + 1 but have shapes" - + str(bin_edges.shape[0]) - + " and " - + str(counts.shape[0]) - + "+1." - ) - cdf_hat = torch.cumsum(counts / torch.sum(counts, dim=0), dim=0) - return _crps_from_cdf(bin_edges, cdf_hat, obs) - - -def crps( - pred: Tensor, obs: Union[Tensor, np.ndarray], dim: int = 0, method: str = "kernel" -) -> Tensor: - """ - Computes the local Continuous Ranked Probability Score (CRPS). - - Creates a map of CRPS and does not accumulate over any other dimensions (e.g., lat/lon regions). - - Parameters - ---------- - pred : Tensor - Tensor containing the ensemble predictions. - obs : Union[Tensor, np.ndarray] - Tensor or array containing an observation over which the CRPS is computed - with respect to. - dim : int, Optional - Dimension with which to calculate the CRPS over, the ensemble dimension. - Assumed to be zero. - method: str, Optional - The method to calculate the crps. Can either be "kernel", "sort" or "histogram". - - The "kernel" method implements - .. math:: - CRPS(x, y) = E[X-y] - 0.5*E[X-X'] - - This method scales as O(n^2) where n is the number of ensemble members and - can potentially induce large memory consumption as the algorithm attempts - to vectorize over this O(n^2) operation. - - The "sort" method compute the exact CRPS using the CDF method - .. math:: - CRPS(x, y) = int [F(x) - 1(x-y)]^2 dx - - where F is the empirical CDF and 1(x-y) = 1 if x > y. - - This method is more memory efficient than the kernel method, and uses O(n - log n) compute instead of O(n^2), where n is the number of ensemble members. - - The "histogram" method computes an approximate CRPS using the CDF method - .. math:: - CRPS(x, y) = int [F(x) - 1(x-y)]^2 dx - - where F is the empirical CDF, estimated via a histogram of the samples. The - number of bins used is the lesser of the square root of the number of samples - and 100. For more control over the implementation of this method consider using - `cdf_function` to construct a cdf and `_crps_from_cdf` to compute CRPS. - - Returns - ------- - Tensor - Map of CRPS - """ - if method not in ["kernel", "sort", "histogram"]: - raise ValueError("Method must either be 'kernel', 'sort' or 'histogram'.") - - n = pred.shape[dim] - obs = torch.as_tensor(obs, device=pred.device, dtype=pred.dtype) - if method in ["kernel", "sort"]: - return kcrps(pred, obs, dim=dim) - else: - pred = pred.unsqueeze(0).transpose(0, dim + 1).squeeze(dim + 1) - number_of_bins = max(int(np.sqrt(n)), 100) - bin_edges, cdf = cdf_function(pred, bins=number_of_bins) - _crps = _crps_from_cdf(bin_edges, cdf, obs) - return _crps \ No newline at end of file diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 2fbdab0..0d43714 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -8,7 +8,7 @@ import numpy as np import torch -def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label: str, vmin=None, vmax=None): +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label: str, title='', vmin=None, vmax=None): """Plot observed or interpolated data in a scatter plot.""" fig = plt.figure() fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) @@ -30,8 +30,6 @@ def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: st i=i+1 p.set_label(k) ax.legend() - #plt.ylabel('CRPS') - #plt.xlabel('time') ax.set_xticks([times[0],times[-1]]) plt.xlabel(xlabel) plt.ylabel(ylabel) @@ -52,158 +50,3 @@ def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): logging.info(f'plotting values to {filename}') plt.savefig(filename) plt.close('all') - -def compute_crps_over_time(times, dataset, output_path): - logging.info('computing crps') - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - start_time=times[0] - end_time=times[-1] - - # Load one prediction ensemble to get the shape - prediction_ensemble = torch.load(os.path.join(output_path, times[0], f'{times[0]}-predictions'), weights_only=False) - #all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) - #all_targets = np.ndarray((len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) - - output_to_input_channel_map = {} - for j in range(len(output_channels)): - index = -1 - for k in range(len(input_channels)): - if input_channels[k].name == output_channels[j].name: - index = k - output_to_input_channel_map[j] = index - - - for i in range(len(times)): - prediction_ensemble = torch.load(os.path.join(output_path, times[i], f'{times[i]}-predictions'), weights_only=False) - baseline = torch.load(os.path.join(output_path, times[i], f'{times[i]}-baseline'), weights_only=False) - target = torch.load(os.path.join(output_path, times[i], f'{times[i]}-target'), weights_only=False) - - # Calculate ensemble mean error - ensemble_mean = np.mean(prediction_ensemble, 0) - ensemble_mean_error = absolute_error(ensemble_mean, target) - - # Calculate interpolation error (baseline #1) - interpolation_error = np.zeros(target.shape) - for j in range(len(output_channels)): - k = output_to_input_channel_map[j] - if k > -1: - interpolation_error[j,::] = absolute_error(baseline[k,::], target[j,::]) - - # Calculate persistence error (baseline #2) - persistence_error = np.zeros(baseline.shape) - prev=[] - if i > 0: - prev = torch.load(os.path.join(output_path, times[i-1], f'{times[i-1]}-target'), weights_only=False) - persistence_error = absolute_error(prev, target) - else: - # persist the next-time-point target-- this is fiction but it keeps the plots from looking weird. - prev = torch.load(os.path.join(output_path, times[i+1], f'{times[i+1]}-target'), weights_only=False) - persistence_error = absolute_error(target, target) - - - # Calculate CRPS - crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) - - torch.save(crps_diffusion_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble')) - torch.save(ensemble_mean_error, os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error')) - torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) - torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) - -def compute_crps_over_time_and_area(times, dataset, output_path): - logging.info('computing crps and errors') - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - start_time=times[0] - end_time=times[-1] - - logging.info('calculating min/max') - - crps_area = torch.load(os.path.join(output_path, times[0], f'{times[0]}-crps'), weights_only=False) - - # make area and time plot - total_crps_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - total_ensemble_mean_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - total_interpolation_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - total_persistence_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - - crps_over_time = np.zeros((crps_area.shape[0], len(times))) - ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) - interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) - persistence_over_time = np.zeros((crps_area.shape[0], len(times))) - for i in range(len(times)): - if i % (24*5) == 0: - logging.info(f'on time {times[i]}') - crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble'), weights_only=False) - total_crps_area = total_crps_area + crps_area - - ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error'), weights_only=False) - total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area - interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) - total_interpolation_area = total_interpolation_area + interpolation_area - persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-persistence-error'), weights_only=False) - if i>0: - total_persistence_area = total_persistence_area + persistence_area - - for j in range(crps_area.shape[0]): - crps_over_time[j,i] = np.mean(crps_area[j,::]) - ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) - interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) - persistence_over_time[j,i] = np.mean(persistence_area[j,::]) - mean_crps_area = total_crps_area / len(times) - mean_ensemble_mean_area = total_ensemble_mean_area / len(times) - mean_interpolation_area = total_interpolation_area / len(times) - mean_persistence_area = total_persistence_area / (len(times)-1) - torch.save(mean_crps_area, os.path.join(output_path, f'crps-ensemble-area-{times[0]}-{times[len(times)-1]}')) - torch.save(mean_ensemble_mean_area, os.path.join(output_path, f'mae-ensemble-mean-area-{times[0]}-{times[len(times)-1]}')) - torch.save(mean_interpolation_area, os.path.join(output_path, f'mae-interpolation-area-{times[0]}-{times[len(times)-1]}')) - torch.save(mean_persistence_area, os.path.join(output_path, f'mae-persistence-area-{times[0]}-{times[len(times)-1]}')) - - # Little hack to make the plots look nicer, without having to change dimensions. - persistence_over_time[:,0] = persistence_over_time[:,1] - - torch.save(crps_over_time, os.path.join(output_path, f'crps-ensemble-time-{times[0]}-{times[len(times)-1]}')) - torch.save(ensemble_mean_over_time, os.path.join(output_path, f'mae-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) - torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) - torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) - -def plot_crps_over_time_and_area(times, dataset, output_path): - logging.info('plotting crps and errors') - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - start_time=times[0] - end_time=times[-1] - - crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-ensemble-time-{start_time}-{end_time}'), weights_only=False) - crps_area = torch.load(os.path.join(output_path, f'crps-ensemble-area-{start_time}-{end_time}'), weights_only=False) - ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) - ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) - interpolation_time = torch.load(os.path.join(output_path, f'mae-interpolation-time-{start_time}-{end_time}'), weights_only=False) - interpolation_area = torch.load(os.path.join(output_path, f'mae-interpolation-area-{start_time}-{end_time}'), weights_only=False) - persistence_time = torch.load(os.path.join(output_path, f'mae-persistence-time-{start_time}-{end_time}'), weights_only=False) - persistence_area = torch.load(os.path.join(output_path, f'mae-persistence-area-{start_time}-{end_time}'), weights_only=False) - - - for j in range(crps_area.shape[0]): - plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - plot_error_projection(interpolation_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - plot_error_projection(persistence_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), - label=output_channels[j].name) - - maes = {} - maes['interpolation'] = interpolation_time[j,::] - maes['ensemble mean'] = ensemble_mean_time[j,::] - maes['crps'] = crps_ensemble_time[j,:] - maes['persistence'] = persistence_time[j,::] - plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') - From 2449db4a5d6c7b28d682672a280cce9c1ce25dbc Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 17:43:10 +0200 Subject: [PATCH 082/189] Cleanup eval script a bit --- src/hirad/eval/compute_eval.py | 120 ++++++++++++++++----------------- 1 file changed, 58 insertions(+), 62 deletions(-) diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py index 305f640..b4570d5 100644 --- a/src/hirad/eval/compute_eval.py +++ b/src/hirad/eval/compute_eval.py @@ -4,7 +4,6 @@ import json from omegaconf import OmegaConf, DictConfig import torch -import torch._dynamo import numpy as np import contextlib @@ -12,7 +11,7 @@ from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor -from hirad.eval import absolute_error, crps, plot_scores_vs_t +from hirad.eval import absolute_error, crps, plot_scores_vs_t, plot_error_projection from hirad.models import EDMPrecondSuperResolution, UNet from hirad.inference import Generator from hirad.utils.inference_utils import save_images, save_results_as_torch @@ -31,12 +30,9 @@ def main(cfg: DictConfig) -> None: DistributedManager.initialize() dist = DistributedManager() device = dist.device - # Initialize logger logger = PythonLogger("generate") # General python logger - logger0 = RankZeroLoggingWrapper(logger, dist) - if cfg.generation.times_range: times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") @@ -49,8 +45,6 @@ def main(cfg: DictConfig) -> None: dataset, sampler = get_dataset_and_sampler_inference( dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time ) - img_shape = dataset.image_shape() - img_out_channels = len(dataset.output_channels()) output_path = getattr(cfg.generation.io, "output_path", "./outputs") compute_crps_per_time(times, dataset, output_path) @@ -91,9 +85,9 @@ def compute_crps_per_time(times, dataset, output_path): curr_time = times[i] if i % (24*5) == 0: logging.info(f'on time {curr_time}') - prediction_ensemble = torch.load(os.path.join(output_path, curr_time, f'{curr_time}-predictions'), weights_only=False) - baseline = torch.load(os.path.join(output_path, curr_time, f'{curr_time}-baseline'), weights_only=False) - target = torch.load(os.path.join(output_path, curr_time, f'{curr_time}-target'), weights_only=False) + prediction_ensemble = load_data(output_path, time=curr_time, filename=f'{curr_time}-predictions') + baseline = load_data(output_path, time=curr_time, filename=f'{curr_time}-baseline') + target = load_data(output_path, time=curr_time, filename=f'{curr_time}-target') # Calculate ensemble mean error ensemble_mean = np.mean(prediction_ensemble, 0) @@ -107,63 +101,60 @@ def compute_crps_per_time(times, dataset, output_path): interpolation_error[j,::] = absolute_error(baseline[k,::], target[j,::]) # Calculate persistence error (baseline #2) - persistence_error = np.zeros(baseline.shape) + persistence_error = np.zeros(target.shape) if i > 0: - prev = torch.load(os.path.join(output_path, times[i-1], f'{times[i-1]}-target'), weights_only=False) + prev = load_data(output_path, time=times[i-1], filename=f'{times[i-1]}-target') persistence_error = absolute_error(prev, target) else: # for the first time point, persist the next-time-point target. # This is fiction but it keeps the plots from looking weird. - prev = torch.load(os.path.join(output_path, times[i+1], f'{times[i+1]}-target'), weights_only=False) + prev = load_data(output_path, time=times[i+1], filename=f'{times[i+1]}-target') persistence_error = absolute_error(prev, target) # Calculate CRPS crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False) - torch.save(crps_diffusion_area, os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble')) - torch.save(ensemble_mean_error, os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error')) - torch.save(interpolation_error, os.path.join(output_path, times[i], f'{times[i]}-interpolation-error')) - torch.save(persistence_error, os.path.join(output_path, times[i], f'{times[i]}-persistence-error')) + save_data(crps_diffusion_area, output_path, time=curr_time, filename=f'{curr_time}-crps-ensemble') + save_data(ensemble_mean_error, output_path, time=curr_time, filename=f'{curr_time}-ensemble-mean-error') + save_data(interpolation_error, output_path, time=curr_time, filename=f'{curr_time}-interpolation-error') + save_data(persistence_error, output_path, time=curr_time, filename=f'{curr_time}-persistence-error') def compute_crps_over_time_and_area(times, dataset, output_path): logging.info('computing crps and errors') - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() start_time=times[0] end_time=times[-1] - logging.info('calculating min/max') - - crps_area = torch.load(os.path.join(output_path, times[0], f'{times[0]}-crps'), weights_only=False) + # shape = (channels, x, y) + crps_area = load_data(output_path, time=start_time, filename=f'{start_time}-crps-ensemble') + num_channels = crps_area.shape[0] # make area and time plot - total_crps_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - total_ensemble_mean_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - total_interpolation_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - total_persistence_area = np.zeros((crps_area.shape[0],crps_area.shape[1],crps_area.shape[2])) - - crps_over_time = np.zeros((crps_area.shape[0], len(times))) - ensemble_mean_over_time = np.zeros((crps_area.shape[0], len(times))) - interpolation_over_time = np.zeros((crps_area.shape[0], len(times))) - persistence_over_time = np.zeros((crps_area.shape[0], len(times))) + total_crps_area = np.zeros_like(crps_area) + total_ensemble_mean_area = np.zeros_like(crps_area) + total_interpolation_area = np.zeros_like(crps_area) + total_persistence_area = np.zeros_like(crps_area) + + crps_over_time = np.zeros((num_channels, len(times))) + ensemble_mean_over_time = np.zeros_like(crps_over_time) + interpolation_over_time = np.zeros_like(crps_over_time) + persistence_over_time = np.zeros_like(crps_over_time) for i in range(len(times)): + curr_time = times[i] if i % (24*5) == 0: logging.info(f'on time {times[i]}') - crps_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-crps-ensemble'), weights_only=False) + crps_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-crps-ensemble') total_crps_area = total_crps_area + crps_area - ensemble_mean_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-ensemble-mean-error'), weights_only=False) + ensemble_mean_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-ensemble-mean-error') total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area - interpolation_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-interpolation-error'), weights_only=False) + interpolation_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-interpolation-error') total_interpolation_area = total_interpolation_area + interpolation_area - persistence_area = torch.load(os.path.join(output_path, times[i], f'{times[i]}-persistence-error'), weights_only=False) + persistence_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-persistence-error') if i>0: total_persistence_area = total_persistence_area + persistence_area - for j in range(crps_area.shape[0]): + for j in range(num_channels): crps_over_time[j,i] = np.mean(crps_area[j,::]) ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::]) interpolation_over_time[j,i] = np.mean(interpolation_area[j,::]) @@ -172,54 +163,59 @@ def compute_crps_over_time_and_area(times, dataset, output_path): mean_ensemble_mean_area = total_ensemble_mean_area / len(times) mean_interpolation_area = total_interpolation_area / len(times) mean_persistence_area = total_persistence_area / (len(times)-1) - torch.save(mean_crps_area, os.path.join(output_path, f'crps-ensemble-area-{times[0]}-{times[len(times)-1]}')) - torch.save(mean_ensemble_mean_area, os.path.join(output_path, f'mae-ensemble-mean-area-{times[0]}-{times[len(times)-1]}')) - torch.save(mean_interpolation_area, os.path.join(output_path, f'mae-interpolation-area-{times[0]}-{times[len(times)-1]}')) - torch.save(mean_persistence_area, os.path.join(output_path, f'mae-persistence-area-{times[0]}-{times[len(times)-1]}')) + save_data(mean_crps_area, output_path, filename=f'crps-ensemble-area-{start_time}-{end_time}') + save_data(mean_ensemble_mean_area, output_path, filename=f'mae-ensemble-mean-area-{start_time}-{end_time}') + save_data(mean_interpolation_area, output_path, filename=f'mae-interpolation-area-{start_time}-{end_time}') + save_data(mean_persistence_area, output_path, filename=f'mae-persistence-area-{start_time}-{end_time}') # Little hack to make the plots look nicer, without having to change dimensions. persistence_over_time[:,0] = persistence_over_time[:,1] - torch.save(crps_over_time, os.path.join(output_path, f'crps-ensemble-time-{times[0]}-{times[len(times)-1]}')) - torch.save(ensemble_mean_over_time, os.path.join(output_path, f'mae-ensemble-mean-time-{times[0]}-{times[len(times)-1]}')) - torch.save(interpolation_over_time, os.path.join(output_path, f'mae-interpolation-time-{times[0]}-{times[len(times)-1]}')) - torch.save(persistence_over_time, os.path.join(output_path, f'mae-persistence-time-{times[0]}-{times[len(times)-1]}')) + save_data(crps_over_time, output_path, filename=f'crps-ensemble-time-{start_time}-{end_time}') + save_data(ensemble_mean_over_time, output_path, filename=f'mae-ensemble-mean-time-{start_time}-{end_time}') + save_data(interpolation_over_time, output_path, filename=f'mae-interpolation-time-{start_time}-{end_time}') + save_data(persistence_over_time, output_path, filename=f'mae-persistence-time-{start_time}-{end_time}') def plot_crps_over_time_and_area(times, dataset, output_path): logging.info('plotting crps and errors') longitudes = dataset.longitude() latitudes = dataset.latitude() - input_channels = dataset.input_channels() output_channels = dataset.output_channels() start_time=times[0] end_time=times[-1] - crps_ensemble_time = torch.load(os.path.join(output_path, f'crps-ensemble-time-{start_time}-{end_time}'), weights_only=False) - crps_area = torch.load(os.path.join(output_path, f'crps-ensemble-area-{start_time}-{end_time}'), weights_only=False) - ensemble_mean_time = torch.load(os.path.join(output_path, f'mae-ensemble-mean-time-{start_time}-{end_time}'), weights_only=False) - ensemble_mean_area = torch.load(os.path.join(output_path, f'mae-ensemble-mean-area-{start_time}-{end_time}'), weights_only=False) - interpolation_time = torch.load(os.path.join(output_path, f'mae-interpolation-time-{start_time}-{end_time}'), weights_only=False) - interpolation_area = torch.load(os.path.join(output_path, f'mae-interpolation-area-{start_time}-{end_time}'), weights_only=False) - persistence_time = torch.load(os.path.join(output_path, f'mae-persistence-time-{start_time}-{end_time}'), weights_only=False) - persistence_area = torch.load(os.path.join(output_path, f'mae-persistence-area-{start_time}-{end_time}'), weights_only=False) + crps_area = load_data(output_path, filename=f'crps-ensemble-area-{start_time}-{end_time}') + ensemble_mean_area = load_data(output_path, filename=f'mae-ensemble-mean-area-{start_time}-{end_time}') + interpolation_area = load_data(output_path, filename=f'mae-interpolation-area-{start_time}-{end_time}') + persistence_area = load_data(output_path, filename=f'mae-persistence-area-{start_time}-{end_time}') + crps_ensemble_time = load_data(output_path, filename=f'crps-ensemble-time-{start_time}-{end_time}') + ensemble_mean_time = load_data(output_path, filename=f'mae-ensemble-mean-time-{start_time}-{end_time}') + interpolation_time = load_data(output_path, filename=f'mae-interpolation-time-{start_time}-{end_time}') + persistence_time = load_data(output_path, filename=f'mae-persistence-time-{start_time}-{end_time}') - for j in range(crps_area.shape[0]): - plot_error_projection(crps_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + for j in range(len(output_channels)): + plot_error_projection(crps_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name, title=f'Mean absolute error: CRPS: {output_channels[j].name}') - plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name, title=f'Mean absolute error: Ensemble mean: {output_channels[j].name}') - plot_error_projection(interpolation_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + plot_error_projection(interpolation_area[j,::], latitudes, longitudes, + _get_data_path(output_path, filename=f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name, title=f'Mean absolute error: Interpolation: {output_channels[j].name}') - plot_error_projection(persistence_area[j,::], latitudes, longitudes, os.path.join(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + plot_error_projection(persistence_area[j,::], latitudes, longitudes, + _get_data_path(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name, title=f'Mean absolute error: Persistence: {output_channels[j].name}') - + maes = {} maes['interpolation'] = interpolation_time[j,::] maes['ensemble mean'] = ensemble_mean_time[j,::] maes['crps'] = crps_ensemble_time[j,:] maes['persistence'] = persistence_time[j,::] - plot_scores_vs_t(maes, times, os.path.join(output_path, f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') + plot_scores_vs_t(maes, times, + _get_data_path(output_path, filename=f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') From 9aff70344d70cb89e640a975979d9312ddd67f75 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 17:54:34 +0200 Subject: [PATCH 083/189] Fully separate CRPS and MAE plots for long time ranges --- src/hirad/eval/compute_eval.py | 10 +++--- src/hirad/inference/generate.py | 7 +--- src/hirad/utils/inference_utils.py | 51 +++--------------------------- 3 files changed, 11 insertions(+), 57 deletions(-) diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py index b4570d5..b64a55d 100644 --- a/src/hirad/eval/compute_eval.py +++ b/src/hirad/eval/compute_eval.py @@ -52,7 +52,10 @@ def main(cfg: DictConfig) -> None: plot_crps_over_time_and_area(times, dataset, output_path) def _get_data_path(output_path, time=None, filename=None): - return os.path.join(output_path, time, filename) + if time: + return os.path.join(output_path, time, filename) + else: + return os.path.join(output_path, filename) def load_data(output_path, time=None, filename=None): return torch.load(_get_data_path(output_path, time, filename), weights_only=False) @@ -66,7 +69,6 @@ def compute_crps_per_time(times, dataset, output_path): input_channels = dataset.input_channels() output_channels = dataset.output_channels() start_time=times[0] - end_time=times[-1] # Load one prediction ensemble to get the shape prediction_ensemble = torch.load(os.path.join(output_path, start_time, f'{start_time}-predictions'), weights_only=False) @@ -205,7 +207,7 @@ def plot_crps_over_time_and_area(times, dataset, output_path): _get_data_path(output_path, filename=f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name, title=f'Mean absolute error: Interpolation: {output_channels[j].name}') plot_error_projection(persistence_area[j,::], latitudes, longitudes, - _get_data_path(output_path, f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), + _get_data_path(output_path, filename=f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'), label=output_channels[j].name, title=f'Mean absolute error: Persistence: {output_channels[j].name}') maes = {} @@ -218,7 +220,5 @@ def plot_crps_over_time_and_area(times, dataset, output_path): title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') - - if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index d26a168..2320656 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -13,7 +13,7 @@ from hirad.models import EDMPrecondSuperResolution, UNet from hirad.inference import Generator -from hirad.utils.inference_utils import save_images, save_results_as_torch, plot_crps_over_time +from hirad.utils.inference_utils import save_images, save_results_as_torch from hirad.utils.function_utils import get_time_from_range from hirad.utils.checkpoint import load_checkpoint @@ -301,11 +301,6 @@ def elapsed_time(self, _): f.close() logger0.info("Generation Completed.") - if cfg.generation.times_range: - # reassign times - times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt - plot_crps_over_time(times, dataset, output_path) - if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 8e95f1c..14be446 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -239,14 +239,6 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, if mean_pred is not None: mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - # Plot CRPS - if prediction.shape[0] > 1: - crps_score = crps(prediction, target, average_over_area=False, average_over_channels=True) - _plot_projection(longitudes, latitudes, crps_score, os.path.join(output_path, f'{time_step}-crps-all.jpg')) - crps_score_channels = crps(prediction, target, average_over_area=False, average_over_channels=False) - for channel_num in range(crps_score_channels.shape[0]): - _plot_projection(longitudes, latitudes, crps_score_channels[channel_num,::], os.path.join(output_path, f'{time_step}-crps-{output_channels[channel_num].name}.jpg')) - # Plot power spectra freqs = {} power = {} @@ -258,11 +250,11 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, input_channel_idx = input_channels.index(channel) if channel.name=="tp": - target[idx,::] = _prepare_precipitaiton(target[idx,:,:]) - prediction[:,idx,::] = _prepare_precipitaiton(prediction[:,idx,:,:]) - baseline[input_channel_idx,:,:] = _prepare_precipitaiton(baseline[input_channel_idx,::]) + target[idx,::] = _prepare_precipitation(target[idx,:,:]) + prediction[:,idx,::] = _prepare_precipitation(prediction[:,idx,:,:]) + baseline[input_channel_idx,:,:] = _prepare_precipitation(baseline[input_channel_idx,::]) if mean_pred is not None: - mean_pred[idx,::] = _prepare_precipitaiton(mean_pred[idx,::]) + mean_pred[idx,::] = _prepare_precipitation(mean_pred[idx,::]) if mean_pred is not None: vmin, vmax = calculate_bounds(target[idx,:,:], @@ -322,40 +314,7 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, power['mean_prediction'] = mp_power plot_power_spectra(freqs, power, channel.name, os.path.join(output_path_channel, f'{time_step}-{channel.name}-spectra.jpg')) -def plot_crps_over_time(times, dataset, output_path): - longitudes = dataset.longitude() - latitudes = dataset.latitude() - input_channels = dataset.input_channels() - output_channels = dataset.output_channels() - start_time=times[0] - end_time=times[-1] - - # Load one prediction ensemble to get the shape - prediction_ensemble = torch.load(os.path.join(output_path, times[0], f'{times[0]}-predictions'), weights_only=False) - all_predictions = np.ndarray((len(times), prediction_ensemble.shape[0], prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) - all_targets = np.ndarray((len(times), prediction_ensemble.shape[1], prediction_ensemble.shape[2], prediction_ensemble.shape[3])) - for i in range(len(times)): - prediction_ensemble = torch.load(os.path.join(output_path, times[i], f'{times[i]}-predictions'), weights_only=False) - all_predictions[i,::] = prediction_ensemble - target = torch.load(os.path.join(output_path, times[i], f'{times[i]}-target'), weights_only=False) - all_targets[i,::] = target - score_over_time_channels = crps(all_predictions, all_targets, average_over_area=True, average_over_channels=False, average_over_time=False) - score_over_area_channels = crps(all_predictions, all_targets, average_over_area=False, average_over_channels=False, average_over_time=True) - for channel_num in range(score_over_area_channels.shape[0]): - _plot_projection(longitudes, latitudes, score_over_area_channels[channel_num,::], os.path.join(output_path, f'crps-area-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) - _plot_score_vs_t(score_over_time_channels[:, channel_num], times, os.path.join(output_path, f'crps-time-{start_time}-{end_time}-{output_channels[channel_num].name}.jpg')) - -def _plot_score_vs_t(score: np.array, times: np.array, filename: str): - fig = plt.figure() - ax = plt.subplot() - p = plt.plot(times, score) - #plt.ylabel('CRPS') - #plt.xlabel('time') - plt.xticks([times[0],times[-1]]) - plt.savefig(filename) - plt.close('all') - -def _prepare_precipitaiton(precip_array): +def _prepare_precipitation(precip_array): precip_array = np.clip(precip_array, 0, None) precip_array = np.where(precip_array == 0, 1e-6, precip_array) # epsilon = 1e-2 From 20a2ad827503955d474edb478a474ca3ab4542f6 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 16 Jul 2025 18:37:55 +0200 Subject: [PATCH 084/189] Minor improvements to plotting --- src/hirad/eval/compute_eval.py | 8 ++++++-- src/hirad/eval/plotting.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py index b64a55d..8a01db6 100644 --- a/src/hirad/eval/compute_eval.py +++ b/src/hirad/eval/compute_eval.py @@ -6,6 +6,8 @@ import torch import numpy as np import contextlib +import datetime +from pandas import to_datetime from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper @@ -48,7 +50,7 @@ def main(cfg: DictConfig) -> None: output_path = getattr(cfg.generation.io, "output_path", "./outputs") compute_crps_per_time(times, dataset, output_path) - compute_crps_over_time_and_area(times, dataset, output_path) + compute_crps_over_time_and_area(times, output_path) plot_crps_over_time_and_area(times, dataset, output_path) def _get_data_path(output_path, time=None, filename=None): @@ -122,7 +124,7 @@ def compute_crps_per_time(times, dataset, output_path): save_data(interpolation_error, output_path, time=curr_time, filename=f'{curr_time}-interpolation-error') save_data(persistence_error, output_path, time=curr_time, filename=f'{curr_time}-persistence-error') -def compute_crps_over_time_and_area(times, dataset, output_path): +def compute_crps_over_time_and_area(times, output_path): logging.info('computing crps and errors') start_time=times[0] end_time=times[-1] @@ -215,6 +217,8 @@ def plot_crps_over_time_and_area(times, dataset, output_path): maes['ensemble mean'] = ensemble_mean_time[j,::] maes['crps'] = crps_ensemble_time[j,:] maes['persistence'] = persistence_time[j,::] + # TODO: consider casting times to datetime objects to avoid warnings. + # However, this seems to be working OK, and a direct cast causes plotting errors plot_scores_vs_t(maes, times, _get_data_path(output_path, filename=f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'), title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE') diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 0d43714..8d6a514 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -21,12 +21,19 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.close('all') def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str, xlabel='', ylabel='', title=''): + fig = plt.figure() ax = plt.subplot() - colors = ['red', 'green', 'blue', 'orange'] # TODO, add more + colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'] # TODO, add more i=0 for k in scores.keys(): - p, = ax.plot(times, scores[k], color=colors[i]) + style = colors[i] + # If more than 50 points, don't connect lines + if len(times) > 50: + style = style + '.' + else: + style = style + '-' + p, = ax.plot(times, scores[k], style) i=i+1 p.set_label(k) ax.legend() From 734279f974964af2613bf632f407970d06ed5955 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 21 Jul 2025 10:13:42 +0200 Subject: [PATCH 085/189] Update paths to environment for training scripts --- src/hirad/train_diffusion.sh | 2 +- src/hirad/train_regression.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh index 58e6ccb..10eb13f 100644 --- a/src/hirad/train_diffusion.sh +++ b/src/hirad/train_diffusion.sh @@ -38,7 +38,7 @@ export MASTER_PORT=29500 export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun --environment=./modulus_env.toml bash -c " +srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml " \ No newline at end of file diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh index a75cc06..f6e276f 100644 --- a/src/hirad/train_regression.sh +++ b/src/hirad/train_regression.sh @@ -40,7 +40,7 @@ export OMP_NUM_THREADS=72 # . ./train_env/bin/activate # python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml # " -srun --environment=./modulus_env.toml bash -c " +srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml " \ No newline at end of file From a618c2f29e07be1e6ec32d818a97a727d387708c Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 21 Jul 2025 10:40:42 +0200 Subject: [PATCH 086/189] Make label opt argument --- src/hirad/eval/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 8d6a514..a7603b7 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -8,7 +8,7 @@ import numpy as np import torch -def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label: str, title='', vmin=None, vmax=None): +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label='', title='', vmin=None, vmax=None): """Plot observed or interpolated data in a scatter plot.""" fig = plt.figure() fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) From 5fc44e63537df355838f69611ccb849bbc51043c Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 21 Jul 2025 18:03:25 +0200 Subject: [PATCH 087/189] Add test scripts/configs for regression, diffusion, and generation --- src/hirad/conf/generate_era_cosmo_test.yaml | 20 ++++++++ src/hirad/conf/generation/era_cosmo_test.yaml | 42 ++++++++++++++++ .../training_era_cosmo_diffusion_test.yaml | 24 ++++++++++ .../training_era_cosmo_regression_test.yaml | 25 ++++++++++ src/hirad/generate_test.sh | 48 +++++++++++++++++++ src/hirad/train_diffusion_test.sh | 42 ++++++++++++++++ src/hirad/train_regression_test.sh | 42 ++++++++++++++++ 7 files changed, 243 insertions(+) create mode 100644 src/hirad/conf/generate_era_cosmo_test.yaml create mode 100644 src/hirad/conf/generation/era_cosmo_test.yaml create mode 100644 src/hirad/conf/training_era_cosmo_diffusion_test.yaml create mode 100644 src/hirad/conf/training_era_cosmo_regression_test.yaml create mode 100644 src/hirad/generate_test.sh create mode 100644 src/hirad/train_diffusion_test.sh create mode 100644 src/hirad/train_regression_test.sh diff --git a/src/hirad/conf/generate_era_cosmo_test.yaml b/src/hirad/conf/generate_era_cosmo_test.yaml new file mode 100644 index 0000000..bb83890 --- /dev/null +++ b/src/hirad/conf/generate_era_cosmo_test.yaml @@ -0,0 +1,20 @@ +hydra: + job: + chdir: true + name: generation_era5_cosmo_test + run: + dir: ./outputs/generation/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + # Dataset + - dataset/era_cosmo_inference + + # Sampler + - sampler/stochastic + #- sampler/deterministic + + # Generation + - generation/era_cosmo_test + #- generation/patched_based \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo_test.yaml b/src/hirad/conf/generation/era_cosmo_test.yaml new file mode 100644 index 0000000..9e10eb4 --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo_test.yaml @@ -0,0 +1,42 @@ +# TODO: See if there's a way to inherit from era_cosmo.yaml +num_ensembles: 8 + # Number of ensembles to generate per input +seed_batch_size: 4 + # Size of the batched inference +inference_mode: all + # Choose between "all" (regression + diffusion), "regression" or "diffusion" + # Patch size. Patch-based sampling will be utilized if these dimensions differ from + # img_shape_x and img_shape_y +# overlap_pixels: 0 + # Number of overlapping pixels between adjacent patches +# boundary_pixels: 0 + # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary + # artifact. +patching: False +hr_mean_conditioning: True +# sample_res: full + # Sampling resolution +times_range: ['20200131-0000','20200131-2300',1] +times: null +has_lead_time: False + +perf: + force_fp16: False + # Whether to force fp16 precision for the model. If false, it'll use the precision + # specified upon training. + use_torch_compile: False + # whether to use torch.compile on the diffusion model + # this will make the first time stamp generation very slow due to compilation overheads + # but will significantly speed up subsequent inference runs + num_writer_workers: 8 + # number of workers to use for writing file + # To support multiple workers a threadsafe version of the netCDF library must be used + +io: + res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_test/checkpoints_diffusion + # Checkpoint filename for the diffusion model + reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_test/checkpoints_regression + # Checkpoint filename for the mean predictor model + output_path: ./outputs/evaluation + + diff --git a/src/hirad/conf/training_era_cosmo_diffusion_test.yaml b/src/hirad/conf/training_era_cosmo_diffusion_test.yaml new file mode 100644 index 0000000..2723050 --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_diffusion_test.yaml @@ -0,0 +1,24 @@ +hydra: + job: + chdir: true + name: diffusion_era5_cosmo_test + run: + dir: ./outputs/training/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_diffusion + + - model_size/mini + + # Training + - training/era_cosmo_diffusion + + # Inference visualization + - generation/era_cosmo_training \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression_test.yaml b/src/hirad/conf/training_era_cosmo_regression_test.yaml new file mode 100644 index 0000000..96e88fa --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_regression_test.yaml @@ -0,0 +1,25 @@ +hydra: + job: + chdir: true + name: regression_era5_cosmo_test + run: + dir: ./outputs/training/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_regression + + - model_size/mini + + # Training + # Leave same as prod + - training/era_cosmo_regression + + # Inference visualization + - generation/era_cosmo_training \ No newline at end of file diff --git a/src/hirad/generate_test.sh b/src/hirad/generate_test.sh new file mode 100644 index 0000000..0f12f9e --- /dev/null +++ b/src/hirad/generate_test.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="corrdiff-test-genreate" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/generation_test.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/inference/generate.py --config-name=generate_era_cosmo_test.yaml +" \ No newline at end of file diff --git a/src/hirad/train_diffusion_test.sh b/src/hirad/train_diffusion_test.sh new file mode 100644 index 0000000..9c9831f --- /dev/null +++ b/src/hirad/train_diffusion_test.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +#SBATCH --job-name="corrdiff-test-second-stage" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/training_diffusion_test.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion_test.yaml +" \ No newline at end of file diff --git a/src/hirad/train_regression_test.sh b/src/hirad/train_regression_test.sh new file mode 100644 index 0000000..4986fc0 --- /dev/null +++ b/src/hirad/train_regression_test.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +#SBATCH --job-name="corrdiff-test-first-stage" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/training_regression_test.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute cores per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml +" \ No newline at end of file From 284f778ee476aa79ffd3ecbd9e2ef731ba6d58bd Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 21 Jul 2025 18:16:56 +0200 Subject: [PATCH 088/189] Update CI to include regression test script --- ci/cscs.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index b96e65a..1be60ea 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -14,12 +14,12 @@ build_job: variables: DOCKERFILE: ci/docker/Dockerfile.ci -#test_job: -# stage: test -# extends: .container-runner-clariden-gh200 -# image: $PERSIST_IMAGE_NAME -# script: -# - /opt/helloworld/bin/hello -# variables: -# SLURM_JOB_NUM_NODES: 2 -# SLURM_NTASKS: 2 +test_job: + stage: test + extends: .container-runner-clariden-gh200 + image: $PERSIST_IMAGE_NAME + script: + - hirad/train_regression_test.sh + variables: + SLURM_JOB_NUM_NODES: 2 + SLURM_NTASKS: 2 From 995e888a8aefaede077ec7e5ffd136b3ebe19ee6 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 24 Jul 2025 15:58:38 +0200 Subject: [PATCH 089/189] Test script should not be srun if it's going to work with ci/cd --- src/hirad/train_regression_test.sh | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/hirad/train_regression_test.sh b/src/hirad/train_regression_test.sh index 4986fc0..6cc0468 100644 --- a/src/hirad/train_regression_test.sh +++ b/src/hirad/train_regression_test.sh @@ -1,23 +1,8 @@ #!/bin/bash -#SBATCH --job-name="corrdiff-test-first-stage" - -### HARDWARE ### -#SBATCH --partition=debug -#SBATCH --nodes=2 -#SBATCH --ntasks-per-node=4 -#SBATCH --gpus-per-node=4 -#SBATCH --cpus-per-task=72 -#SBATCH --time=00:30:00 -#SBATCH --no-requeue -#SBATCH --exclusive - ### OUTPUT ### #SBATCH --output=./logs/training_regression_test.log -### ENVIRONMENT #### -#SBATCH -A a161 - # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM @@ -36,7 +21,5 @@ export MASTER_PORT=29500 # export OMP_NUM_THREADS=$OMP_THREADS export OMP_NUM_THREADS=72 -srun --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . --no-dependencies - python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml -" \ No newline at end of file +pip install -e . --no-dependencies +python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml \ No newline at end of file From ffc1bedf62628e31050ebf351bd59967f0f55621 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 24 Jul 2025 16:20:27 +0200 Subject: [PATCH 090/189] Try to directly run from cscs.yml --- ci/cscs.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index 1be60ea..1668b81 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -19,7 +19,8 @@ test_job: extends: .container-runner-clariden-gh200 image: $PERSIST_IMAGE_NAME script: - - hirad/train_regression_test.sh + - pip install -e . --no-dependencies + - python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml variables: SLURM_JOB_NUM_NODES: 2 SLURM_NTASKS: 2 From ebc104308bc92c7fbf5bccfe5df602aa56928768 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 24 Jul 2025 18:45:28 +0200 Subject: [PATCH 091/189] plot diurnal cycles of precip amount and wet-hour frequency over time period --- src/hirad/diurnal_cycle.sh | 49 +++++++++ src/hirad/eval/diurnal_cycle_precip.py | 147 +++++++++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 src/hirad/diurnal_cycle.sh create mode 100644 src/hirad/eval/diurnal_cycle_precip.py diff --git a/src/hirad/diurnal_cycle.sh b/src/hirad/diurnal_cycle.sh new file mode 100644 index 0000000..4a021c0 --- /dev/null +++ b/src/hirad/diurnal_cycle.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="plot_diurnal_cycle" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +##SBATCH --time=00:10:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plot_diurnal_cycle.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/eval/diurnal_cycle_precip.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py new file mode 100644 index 0000000..231a5b3 --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -0,0 +1,147 @@ +import logging +from datetime import datetime +from pathlib import Path +from collections import defaultdict + +import hydra +import numpy as np +torch = __import__('torch') +from omegaconf import DictConfig, OmegaConf +import matplotlib.pyplot as plt + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range + +# Constants +CONV_FACTOR = 100 # Convert meters to mm/h +WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h +LOG_INTERVAL = 1 # Log progress every N timesteps + + +def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: + return datetime.strptime(dt, fmt).hour + + +def compute_ensemble(hourly_values): + hours = sorted(hourly_values) + means = [np.mean(hourly_values[h]) for h in hours] + stds = [np.std(hourly_values[h]) for h in hours] + return hours, means, stds + + +def save_plot(hours, lines, labels, ylabel, title, out_path): + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.figure(figsize=(8,4)) + for data, label in zip(lines, labels): + if isinstance(data, tuple): # (mean, std) + mean, std = data + plt.plot(hours, mean, label=label) + plt.fill_between(hours, + np.array(mean)-std, + np.array(mean)+std, + alpha=0.3) + else: + plt.plot(hours, data, label=label) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig(out_path) + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting diurnal cycle computation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + out_root = Path(cfg.generation.io.output_path or './outputs') + load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + tp_out = out_ch['tp']; tp_in = in_ch.get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Prepare data structures + stats = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} + wet_stats = {mode: defaultdict(list) for mode in stats} + + # Collect data + for idx, ts in enumerate(times, 1): + hr = hour_of(ts) + target = load(ts, f"{ts}-target")[tp_out] + baseline = load(ts, f"{ts}-baseline")[tp_in] + stats['target'][hr].append(target) + stats['baseline'][hr].append(baseline) + wet_stats['target'][hr].append((target > WET_THRESHOLD).mean()) + wet_stats['baseline'][hr].append((baseline > WET_THRESHOLD).mean()) + + preds = load(ts, f"{ts}-predictions")[:, tp_out] + for member in preds: + stats['prediction'][hr].append(member.mean()) + wet_stats['prediction'][hr].append((member > WET_THRESHOLD).mean()) + + if idx % LOG_INTERVAL == 0 or idx == len(times): + logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") + + # Compute hourly means + mean_cycle = { + mode: [np.mean(stats[mode][h]) for h in sorted(stats[mode])] + for mode in ['target','baseline'] + } + wet_cycle = { + mode: [np.mean(wet_stats[mode][h]) for h in sorted(wet_stats[mode])] + for mode in ['target','baseline'] + } + logger.info("Computed hourly mean and wet-cycle statistics") + + # Ensemble cycles (mean ± std) + hrs, pred_mean, pred_std = compute_ensemble(stats['prediction']) + _, wet_mean, wet_std = compute_ensemble(wet_stats['prediction']) + logger.info("Computed ensemble statistics") + + # Prepare cyclic series + cycle = lambda x: x + [x[0]] + hrs_c = hrs + [24] + amount_lines = [cycle(mean_cycle['target']), cycle(mean_cycle['baseline']), (cycle(pred_mean), cycle(pred_std))] + wet_lines = [cycle(wet_cycle['target']), cycle(wet_cycle['baseline']), (cycle(wet_mean), cycle(wet_std))] + + # Log the lines to be plotted (debug) + logger.info(f"amount_lines: {amount_lines}") + logger.info(f"wet_lines: {wet_lines}") + + # Plot + plot_paths = [] + fn1 = out_root/'diurnal_cycle_precip_amount.png' + save_plot(hrs_c, amount_lines, ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], 'Rain Rate (mm/h)', + 'Diurnal Cycle of Precip Amount', fn1) + plot_paths.append(fn1) + + fn2 = out_root/'diurnal_cycle_precip_wethours.png' + save_plot(hrs_c, wet_lines * 100., ['COSMO-2','ERA5','Pred Mean ± Std'], 'Wet-Hour Fraction [%]', + 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', fn2) + plot_paths.append(fn2) + + logger.info(f"Plots saved: {', '.join(str(p) for p in plot_paths)}") + +if __name__ == '__main__': + main() From 7470c99713e0eea84b271dccd09c5a080dc14d8a Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Jul 2025 12:36:42 +0200 Subject: [PATCH 092/189] add mlflow logging --- ci/docker/Dockerfile.corrdiff | 3 +- .../conf/logging/era_cosmo_diffusion.yaml | 5 + .../conf/logging/era_cosmo_regression.yaml | 5 + .../conf/training_era_cosmo_diffusion.yaml | 5 +- .../conf/training_era_cosmo_regression.yaml | 7 +- src/hirad/training/train.py | 190 ++++++++++-------- src/hirad/utils/env_info.py | 166 +++++++++++++++ src/hirad/utils/train_helpers.py | 39 ++++ 8 files changed, 330 insertions(+), 90 deletions(-) create mode 100644 src/hirad/conf/logging/era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/logging/era_cosmo_regression.yaml create mode 100644 src/hirad/utils/env_info.py diff --git a/ci/docker/Dockerfile.corrdiff b/ci/docker/Dockerfile.corrdiff index 93f389c..7908197 100644 --- a/ci/docker/Dockerfile.corrdiff +++ b/ci/docker/Dockerfile.corrdiff @@ -7,4 +7,5 @@ RUN pip install --upgrade pip # Install the rest of dependencies. RUN pip install \ Cartopy==0.22.0 \ - xskillscore + xskillscore \ + mlflow diff --git a/src/hirad/conf/logging/era_cosmo_diffusion.yaml b/src/hirad/conf/logging/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..df7f481 --- /dev/null +++ b/src/hirad/conf/logging/era_cosmo_diffusion.yaml @@ -0,0 +1,5 @@ +# set method to mlflow to log with mlflow +method: null +experiment_name: hirad-corrdiff-diffusion +run_name: era-cosmo-1h +uri: null \ No newline at end of file diff --git a/src/hirad/conf/logging/era_cosmo_regression.yaml b/src/hirad/conf/logging/era_cosmo_regression.yaml new file mode 100644 index 0000000..ba4095b --- /dev/null +++ b/src/hirad/conf/logging/era_cosmo_regression.yaml @@ -0,0 +1,5 @@ +# set method to mlflow to log with mlflow +method: null +experiment_name: hirad-corrdiff-regression +run_name: era-cosmo-1h +uri: null \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 08c8d6a..fee7627 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -21,4 +21,7 @@ defaults: - training/era_cosmo_diffusion # Inference visualization - - generation/era_cosmo_training \ No newline at end of file + - generation/era_cosmo_training + + # Logging + - logging/era_cosmo_diffusion \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index 83a4f94..ce04119 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -1,7 +1,7 @@ hydra: job: chdir: true - name: regression_era5_cosmo_test + name: regression_era5_cosmo_mlflow_test_x16 run: dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} @@ -21,4 +21,7 @@ defaults: - training/era_cosmo_regression # Inference visualization - - generation/era_cosmo_training \ No newline at end of file + - generation/era_cosmo_training + + # Logging + - logging/era_cosmo_regression \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 931fa66..a60844c 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -11,19 +11,22 @@ import nvtx import torch from hydra.utils import to_absolute_path -from torch.utils.tensorboard import SummaryWriter +# from torch.utils.tensorboard import SummaryWriter from torch.nn.parallel import DistributedDataParallel +import mlflow # from torchinfo import summary from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ set_patch_shape, compute_num_accumulation_rounds, \ - is_time_for_periodic_task, handle_and_clip_gradients + is_time_for_periodic_task, handle_and_clip_gradients, \ + init_mlflow from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.utils.patching import RandomPatching2D from hirad.utils.function_utils import get_time_from_range from hirad.utils.inference_utils import save_images +from hirad.utils.env_info import get_env_info, flatten_dict from hirad.models import UNet, EDMPrecondSuperResolution from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE from hirad.datasets import init_train_valid_datasets_from_config, get_dataset_and_sampler_inference @@ -66,12 +69,19 @@ def main(cfg: DictConfig) -> None: DistributedManager.initialize() dist = DistributedManager() - if dist.rank==0: - writer = SummaryWriter(log_dir='tensorboard') + OmegaConf.resolve(cfg) + cfg_dict = OmegaConf.to_object(cfg) + + if cfg.logging.method == "mlflow": + init_mlflow(cfg_dict, dist) + if dist.world_size > 1: + torch.distributed.barrier() + elif cfg.logging.method is not None: + raise ValueError("The only available logging method is mlflow. To disable logging set the method to null.") + logger = PythonLogger("main") # general logger logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger - OmegaConf.resolve(cfg) dataset_cfg = OmegaConf.to_container(cfg.dataset) if hasattr(cfg.dataset, "validation_path"): train_test_split = True @@ -96,14 +106,22 @@ def main(cfg: DictConfig) -> None: if dist.rank==0 and not os.path.exists(visualization_dir): os.makedirs(visualization_dir) # added creating checkpoint dir visualize_checkpoints = True + if cfg.training.hp.batch_size_per_gpu == "auto" and \ + cfg.training.hp.total_batch_size == "auto": + raise ValueError("batch_size_per_gpu and total_batch_size can't be both set to 'auto'.") if cfg.training.hp.batch_size_per_gpu == "auto": cfg.training.hp.batch_size_per_gpu = ( cfg.training.hp.total_batch_size // dist.world_size ) + elif cfg.training.hp.total_batch_size == "auto": + cfg.training.hp.total_batch_size = ( + cfg.training.hp.batch_size_per_gpu * dist.world_size + ) + set_seed(dist.rank) configure_cuda_for_consistent_precision() - + # Instantiate the dataset data_loader_kwargs = { "pin_memory": True, @@ -143,7 +161,7 @@ def main(cfg: DictConfig) -> None: if cfg.generation.times_range and cfg.generation.times: raise ValueError("Either times_range or times must be provided, but not both") if cfg.generation.times_range: - times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") else: times = cfg.generation.times viz_dataset_cfg = OmegaConf.to_container(cfg.generation.dataset) @@ -291,7 +309,6 @@ def main(cfg: DictConfig) -> None: raise FileNotFoundError( f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" ) - #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') @@ -567,9 +584,9 @@ def main(cfg: DictConfig) -> None: ) / n_average_loss_running_mean n_average_loss_running_mean += 1 - if dist.rank == 0: - writer.add_scalar("training_loss", average_loss, cur_nimg) - writer.add_scalar( + if dist.rank == 0 and cfg.logging.method == "mlflow": + mlflow.log_metric("training_loss", average_loss, cur_nimg) + mlflow.log_metric( "training_loss_running_mean", average_loss_running_mean, cur_nimg, @@ -598,8 +615,8 @@ def main(cfg: DictConfig) -> None: if cur_nimg >= lr_rampup: g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // cfg.training.hp.lr_decay_rate) current_lr = g["lr"] - if dist.rank == 0: - writer.add_scalar("learning_rate", current_lr, cur_nimg) + if dist.rank == 0 and cfg.logging.method == "mlflow": + mlflow.log_metric("learning_rate", current_lr, cur_nimg) handle_and_clip_gradients( model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold ) @@ -737,9 +754,9 @@ def main(cfg: DictConfig) -> None: valid_loss_sum, op=torch.distributed.ReduceOp.SUM ) average_valid_loss = valid_loss_sum / dist.world_size - if dist.rank == 0: - writer.add_scalar( - "validation_loss", average_valid_loss, cur_nimg + if dist.rank == 0 and cfg.logging.method == "mlflow": + mlflow.log_metric( + "validation_loss", average_valid_loss, cur_nimg ) @@ -770,83 +787,84 @@ def main(cfg: DictConfig) -> None: done, cfg.training.hp.total_batch_size, dist.rank, - ): - if visualize_checkpoints: - with nvtx.annotate("validation", color="red"): + ) and visualize_checkpoints: + with nvtx.annotate("visualization", color="red"): + if dist.rank == 0: + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + times = visualization_dataset.time() + time_index = -1 + for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( + iter(visualization_data_loader) + ): + time_index += 1 + logger0.info(f"starting index: {time_index}") + + # continue + if lead_time_label_viz: + lead_time_label_viz = lead_time_label_viz[0].to(dist.device).contiguous() + else: + lead_time_label_viz = None + + if use_apex_gn: + img_clean_viz = img_clean_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_viz = img_lr_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean_viz = ( + img_clean_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_viz = ( + img_lr_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + image_pred_viz, image_reg_viz = generator.generate(img_lr_viz, lead_time_label_viz) if dist.rank == 0: - writer_executor = ThreadPoolExecutor( - max_workers=cfg.generation.perf.num_writer_workers + # write out data in a seperate thread so we don't hold up inferencing + output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") + if dist.rank==0 and not os.path.exists(output_path): + os.makedirs(output_path) + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + times[visualization_sampler[time_index]], + visualization_dataset, + image_pred_viz.cpu().numpy(), + img_clean_viz.cpu().numpy(), + img_lr_viz.cpu().numpy(), + image_reg_viz.cpu().numpy() if image_reg_viz is not None else None, + ) ) - writer_threads = [] - - times = visualization_dataset.time() - time_index = -1 - for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( - iter(visualization_data_loader) - ): - time_index += 1 - logger0.info(f"starting index: {time_index}") - - # continue - if lead_time_label_viz: - lead_time_label_viz = lead_time_label_viz[0].to(dist.device).contiguous() - else: - lead_time_label_viz = None + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() - if use_apex_gn: - img_clean_viz = img_clean_viz.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - img_lr_viz = img_lr_viz.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - else: - img_clean_viz = ( - img_clean_viz.to(dist.device) - .to(input_dtype) - .contiguous() - ) - img_lr_viz = ( - img_lr_viz.to(dist.device) - .to(input_dtype) - .contiguous() - ) - with torch.autocast( - "cuda", dtype=amp_dtype, enabled=enable_amp - ): - image_pred_viz, image_reg_viz = generator.generate(img_lr_viz, lead_time_label_viz) - if dist.rank == 0: - # write out data in a seperate thread so we don't hold up inferencing - output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") - if dist.rank==0 and not os.path.exists(output_path): - os.makedirs(output_path) - writer_threads.append( - writer_executor.submit( - save_images, - output_path, - times[visualization_sampler[time_index]], - visualization_dataset, - image_pred_viz.cpu().numpy(), - img_clean_viz.cpu().numpy(), - img_lr_viz.cpu().numpy(), - image_reg_viz.cpu().numpy() if image_reg_viz is not None else None, - ) - ) - # make sure all the workers are done writing - if dist.rank == 0: - for thread in list(writer_threads): - thread.result() - writer_threads.remove(thread) - writer_executor.shutdown() + if dist.world_size > 1: + torch.distributed.barrier() # Done. logger0.info("Training Completed.") - if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/hirad/utils/env_info.py b/src/hirad/utils/env_info.py new file mode 100644 index 0000000..f1cfe64 --- /dev/null +++ b/src/hirad/utils/env_info.py @@ -0,0 +1,166 @@ +""" +Environment Introspection Module +================================ + +This module provides functionality to introspect the Python environment, listing all non-standard +library modules along with their versions and Git information if they are part of a Git repository. +It helps in understanding the environment setup by detailing the modules in use, their versions, and +relevant Git metadata. +""" + +import platform +import os +import sys +from types import ModuleType +from typing import Any, Optional +from git import Repo, InvalidGitRepositoryError + + +def get_module_version(module: ModuleType) -> Optional[str]: + """ + Retrieve the version of a module if available. + + This function attempts to get the version of a module by accessing its ``__version__`` + attribute. It checks if the version is a string to ensure correctness. + + :param module: The module whose version is to be retrieved. + :type module: ModuleType + :return: The version string if available and valid, otherwise None. + :rtype: Optional[str] + """ + version = getattr(module, "__version__", None) + if isinstance(version, str): + return version + return None + + +def get_git_info(path: str) -> Optional[dict[str, Any]]: + """ + Collect basic Git metadata for a given repository path. + + This function checks if the given path is part of a Git repository and collects metadata such as + the commit SHA, modified files, untracked files, and remote URLs. + + :param path: The path to check for Git repository metadata. + :type path: str + :return: A dictionary containing Git metadata if the path is a Git repository, otherwise None. + :rtype: Optional[Dict[str, Any]] + """ + try: + repo = Repo(path, search_parent_directories=True) + diff = "" + for diff_item in repo.index.diff(None, create_patch=True): + a_path = diff_item.a_blob.abspath if diff_item.a_blob else "" + b_path = diff_item.b_blob.abspath if diff_item.b_blob else "" + diff_content = diff_item.diff + if isinstance(diff_content, bytes): + diff_content = diff_content.decode("utf-8") + elif diff_content is None: + diff_content = "" + diff += f"--- a{a_path}\n+++ b{b_path}\n{diff_content}\n\n" + git_info = { + "sha1": repo.head.commit.hexsha, + "diff": diff, + "untracked_files": sorted(repo.untracked_files), + "remotes": [r.url for r in repo.remotes], + } + return git_info + except InvalidGitRepositoryError: + return None + + +def get_module_git_info(module: ModuleType) -> Optional[dict[str, Any]]: + """ + Get Git information for a module if its directory is a Git repository. + + This function determines the directory of the module and checks if it is part of a Git + repository. If it is, it collects and returns the Git metadata. + + :param module: The module to check for Git information. + :type module: ModuleType + :return: A dictionary containing Git metadata if the module is in a Git repository, otherwise + None. + :rtype: Optional[Dict[str, Any]] + """ + module_path = getattr(module, "__file__", None) + if module_path is None or not os.path.isabs(module_path): + return None + module_dir = os.path.dirname(module_path) + return get_git_info(module_dir) + + +def get_env_info(flatten: bool = True, exclude_prefixes: list[str] = None) -> tuple[dict[str, dict[str, Any]], str]: + """ + List all non-standard library modules with their versions and Git information if available. + + This function iterates over all loaded modules in the Python environment, filtering out + built-in modules. It collects version and Git information for each remaining module. + + :return: A tuple containing two elements: + - A dictionary mapping module names to their metadata + - A string containing concatenated Git diffs from all modules + :rtype: tuple[dict[str, dict[str, Any]], str] + :rtype: Dict[str, Dict[str, Any]] + """ + env_info = {} + diffs: list[str] = [] + + exclude_prefixes = exclude_prefixes or [] + + for name, module in sys.modules.copy().items(): + if name in sys.builtin_module_names or name.endswith(('.version','._version')): + continue + + if any(name == prefix or name.startswith(prefix + ".") for prefix in exclude_prefixes): + continue + + version = get_module_version(module) + git_info = get_module_git_info(module) + if version is None and git_info is None: + continue + + module_info: dict[str, Any] = {"version": version} + if git_info is not None: + module_diff = git_info.pop("diff") + if module_diff and module_diff not in diffs: + diffs.append(module_diff) + module_info["git"] = git_info + + env_info[name] = module_info + + env_info["python"] = {"version": platform.python_version()} + + diffs_str = "\n".join(diffs) + + if flatten: + return flatten_dict(env_info), diffs_str + + return env_info, diffs_str + + +def flatten_dict( + d: dict[str, Any], parent_key: str = "", sep: str = "." +) -> dict[str, Any]: + """ + Flatten a nested dictionary. + + This function recursively traverses a nested dictionary and flattens it into a single-level + dictionary with keys formed by concatenating the nested keys using a separator. + + :param d: The dictionary to flatten. + :type d: Dict[str, Any] + :param parent_key: The base key to use for concatenation. + :type parent_key: str + :param sep: The separator to use for concatenating keys. + :type sep: str + :return: A flattened dictionary. + :rtype: Dict[str, Any] + """ + items = {} + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.update(flatten_dict(v, new_key, sep=sep)) + else: + items[new_key] = v + return items diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index dc1a5b9..39a8127 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -17,7 +17,10 @@ import torch import numpy as np import warnings +import mlflow +from hirad.distributed import DistributedManager +from hirad.utils.env_info import get_env_info, flatten_dict def set_patch_shape(img_shape, patch_shape): img_shape_y, img_shape_x = img_shape @@ -109,3 +112,39 @@ def is_time_for_periodic_task( return True else: return cur_nimg % freq < batch_size + + +def init_mlflow(cfg: dict, dist: DistributedManager) -> None: + if dist.rank==0: + if dist.world_size>4: + mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) + mlflow.start_run(run_name=cfg.logging.run_name) #, log_system_metrics=True) + else: + mlflow.system_metrics.set_system_metrics_node_id("node-0") + # mlflow.set_system_metrics_sampling_interval(1) + mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) + mlflow.start_run(run_name=cfg.logging.run_name, log_system_metrics=True) + run = mlflow.active_run() + with open("run_id.txt", 'w') as f: + f.write(run.info.run_id) + # log environment info + mlflow.log_params(flatten_dict(cfg)) + mlflow.log_dict(cfg, "config.json") + python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) + mlflow.log_dict(python_environment, "environment.json") + if git_diff: + mlflow.log_text(git_diff, "git_diff.txt") + + if dist.world_size > 4: + torch.distributed.barrier() + + if (dist.rank!=0 and dist._local_rank==0) or (dist.rank==1 and dist.world_size>4): + mlflow.system_metrics.set_system_metrics_node_id(f"node-{(dist.rank//4)}" + if dist.rank!=1 + else "node-0") + # mlflow.set_system_metrics_sampling_interval(1) + # mlflow.set_system_metrics_samples_before_logging(10) + mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) + with open("run_id.txt", 'r') as f: + run_id = f.read() + mlflow.start_run(run_id=run_id, log_system_metrics=True) \ No newline at end of file From 0f3b0c6749a2fee0b5cca8e557af6da966d3eec0 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Jul 2025 13:54:41 +0200 Subject: [PATCH 093/189] mlflow init bug fix --- src/hirad/training/train.py | 3 +-- src/hirad/utils/train_helpers.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index a60844c..13c4865 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -70,10 +70,9 @@ def main(cfg: DictConfig) -> None: dist = DistributedManager() OmegaConf.resolve(cfg) - cfg_dict = OmegaConf.to_object(cfg) if cfg.logging.method == "mlflow": - init_mlflow(cfg_dict, dist) + init_mlflow(cfg, dist) if dist.world_size > 1: torch.distributed.barrier() elif cfg.logging.method is not None: diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index 39a8127..8adee63 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -18,6 +18,7 @@ import numpy as np import warnings import mlflow +from omegaconf import DictConfig, OmegaConf from hirad.distributed import DistributedManager from hirad.utils.env_info import get_env_info, flatten_dict @@ -114,7 +115,7 @@ def is_time_for_periodic_task( return cur_nimg % freq < batch_size -def init_mlflow(cfg: dict, dist: DistributedManager) -> None: +def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: if dist.rank==0: if dist.world_size>4: mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) @@ -128,7 +129,7 @@ def init_mlflow(cfg: dict, dist: DistributedManager) -> None: with open("run_id.txt", 'w') as f: f.write(run.info.run_id) # log environment info - mlflow.log_params(flatten_dict(cfg)) + mlflow.log_params(flatten_dict(OmegaConf.to_object(cfg))) mlflow.log_dict(cfg, "config.json") python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) mlflow.log_dict(python_environment, "environment.json") From 3bc4caed92f4831a1c6b9131041f4628fee86e29 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Jul 2025 15:10:11 +0200 Subject: [PATCH 094/189] add option to continue same run mlflow --- src/hirad/training/train.py | 2 +- src/hirad/utils/train_helpers.py | 40 +++++++++++++++++--------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 13c4865..12131fc 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -782,7 +782,7 @@ def main(cfg: DictConfig) -> None: torch.distributed.barrier() if is_time_for_periodic_task( cur_nimg, - cfg.training.io.save_checkpoint_freq, + cfg.training.io.visualization_freq, done, cfg.training.hp.total_batch_size, dist.rank, diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index 8adee63..65b4b3c 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -19,6 +19,7 @@ import warnings import mlflow from omegaconf import DictConfig, OmegaConf +import os from hirad.distributed import DistributedManager from hirad.utils.env_info import get_env_info, flatten_dict @@ -117,24 +118,27 @@ def is_time_for_periodic_task( def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: if dist.rank==0: - if dist.world_size>4: - mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) - mlflow.start_run(run_name=cfg.logging.run_name) #, log_system_metrics=True) - else: + run_id = None + if os.path.isfile('run_id.txt'): + with open('run_id.txt','r') as f: + run_id = f.read() + if dist.world_size<=4: mlflow.system_metrics.set_system_metrics_node_id("node-0") - # mlflow.set_system_metrics_sampling_interval(1) - mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) - mlflow.start_run(run_name=cfg.logging.run_name, log_system_metrics=True) - run = mlflow.active_run() - with open("run_id.txt", 'w') as f: - f.write(run.info.run_id) - # log environment info - mlflow.log_params(flatten_dict(OmegaConf.to_object(cfg))) - mlflow.log_dict(cfg, "config.json") - python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) - mlflow.log_dict(python_environment, "environment.json") - if git_diff: - mlflow.log_text(git_diff, "git_diff.txt") + if run_id: + mlflow.start_run(run_id=run_id, log_system_metrics=False if dist.world_size>4 else True) + else: + mlflow.start_run(run_name=cfg.logging.run_name, log_system_metrics=False if dist.world_size>4 else True) + if run_id is None: + run = mlflow.active_run() + with open("run_id.txt", 'w') as f: + f.write(run.info.run_id) + # log environment info if run is not continuing from previous checkpoint + mlflow.log_params(flatten_dict(OmegaConf.to_object(cfg))) + mlflow.log_dict(cfg, "config.json") + python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) + mlflow.log_dict(python_environment, "environment.json") + if git_diff: + mlflow.log_text(git_diff, "git_diff.txt") if dist.world_size > 4: torch.distributed.barrier() @@ -143,8 +147,6 @@ def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: mlflow.system_metrics.set_system_metrics_node_id(f"node-{(dist.rank//4)}" if dist.rank!=1 else "node-0") - # mlflow.set_system_metrics_sampling_interval(1) - # mlflow.set_system_metrics_samples_before_logging(10) mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) with open("run_id.txt", 'r') as f: run_id = f.read() From 6ff06da774e8aa10ef31551118d19bd619e9009c Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Jul 2025 15:11:05 +0200 Subject: [PATCH 095/189] add separate visualization frequency setting --- src/hirad/conf/training/era_cosmo_diffusion.yaml | 14 ++++++++------ src/hirad/conf/training/era_cosmo_regression.yaml | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index 5a14a6a..40e37f3 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,10 +1,10 @@ # Hyperparameters hp: - training_duration: 1024 + training_duration: 20000 # Training duration based on the number of processed samples - total_batch_size: 64 + total_batch_size: "auto" # Total batch size - batch_size_per_gpu: "auto" + batch_size_per_gpu: 22 # Batch size per GPU lr: 0.0002 # Learning rate @@ -31,11 +31,13 @@ perf: io: regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression # Where to load the regression checkpoint - print_progress_freq: 128 + print_progress_freq: 500 # How often to print progress - save_checkpoint_freq: 1024 + save_checkpoint_freq: 10000 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 256 + visualization_freq: 200000 + # how often to visualize network outputs + validation_freq: 2000 # how often to record the validation loss, measured in number of processed samples validation_steps: 2 # how many loss evaluations are used to compute the validation loss per checkpoint diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index b45b206..5cd5d0d 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,10 +1,10 @@ # Hyperparameters hp: - training_duration: 1024 + training_duration: 40000 # Training duration based on the number of processed samples - total_batch_size: 64 + total_batch_size: "auto" # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression. - batch_size_per_gpu: "auto" + batch_size_per_gpu: 22 # Batch size per GPU lr: 0.0002 # Learning rate @@ -31,11 +31,13 @@ perf: # I/O io: - print_progress_freq: 128 + print_progress_freq: 500 # How often to print progress - save_checkpoint_freq: 512 + save_checkpoint_freq: 100000 # How often to save the checkpoints, measured in number of processed samples - validation_freq: 256 + visualization_freq: 200000 + # how often to visualize network output + validation_freq: 2000 # how often to record the validation loss, measured in number of processed samples validation_steps: 2 # how many loss evaluations are used to compute the validation loss per checkpoint From 06c185f1a9b9f7189ba89f428eff807cfd817b69 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Fri, 25 Jul 2025 15:56:53 +0200 Subject: [PATCH 096/189] add plots of temperature and windspeed --- src/hirad/diurnal_cycle.sh | 1 + src/hirad/eval/diurnal_cycle_precip.py | 9 +- src/hirad/eval/diurnal_cycle_temp_wind.py | 142 ++++++++++++++++++++++ 3 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 src/hirad/eval/diurnal_cycle_temp_wind.py diff --git a/src/hirad/diurnal_cycle.sh b/src/hirad/diurnal_cycle.sh index 4a021c0..1a62fcf 100644 --- a/src/hirad/diurnal_cycle.sh +++ b/src/hirad/diurnal_cycle.sh @@ -46,4 +46,5 @@ export OMP_NUM_THREADS=72 srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/eval/diurnal_cycle_precip.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 231a5b3..77af57b 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -16,7 +16,7 @@ # Constants CONV_FACTOR = 100 # Convert meters to mm/h WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h -LOG_INTERVAL = 1 # Log progress every N timesteps +LOG_INTERVAL = 24 # Log progress every N timesteps def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: @@ -109,7 +109,7 @@ def main(cfg: DictConfig): for mode in ['target','baseline'] } wet_cycle = { - mode: [np.mean(wet_stats[mode][h]) for h in sorted(wet_stats[mode])] + mode: [np.mean(wet_stats[mode][h]) * 100. for h in sorted(wet_stats[mode])] for mode in ['target','baseline'] } logger.info("Computed hourly mean and wet-cycle statistics") @@ -117,6 +117,9 @@ def main(cfg: DictConfig): # Ensemble cycles (mean ± std) hrs, pred_mean, pred_std = compute_ensemble(stats['prediction']) _, wet_mean, wet_std = compute_ensemble(wet_stats['prediction']) + # Multiply ensemble wet-hour statistics by 100 for percentage + wet_mean = [v * 100. for v in wet_mean] + wet_std = [v * 100. for v in wet_std] logger.info("Computed ensemble statistics") # Prepare cyclic series @@ -137,7 +140,7 @@ def main(cfg: DictConfig): plot_paths.append(fn1) fn2 = out_root/'diurnal_cycle_precip_wethours.png' - save_plot(hrs_c, wet_lines * 100., ['COSMO-2','ERA5','Pred Mean ± Std'], 'Wet-Hour Fraction [%]', + save_plot(hrs_c, wet_lines, ['COSMO-2','ERA5','Pred Mean ± Std'], 'Wet-Hour Fraction [%]', 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', fn2) plot_paths.append(fn2) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py new file mode 100644 index 0000000..0f769ad --- /dev/null +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -0,0 +1,142 @@ +import logging +from datetime import datetime +from pathlib import Path +from collections import defaultdict + +import hydra +import numpy as np +torch = __import__('torch') +from omegaconf import DictConfig, OmegaConf +import matplotlib.pyplot as plt + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range + +LOG_INTERVAL = 24 + +def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: + return datetime.strptime(dt, fmt).hour + +def compute_ensemble(hourly_values): + hours = sorted(hourly_values) + means = [np.mean(hourly_values[h]) for h in hours] + stds = [np.std(hourly_values[h]) for h in hours] + return hours, means, stds + +def save_plot(hours, lines, labels, ylabel, title, out_path): + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.figure(figsize=(8,4)) + for data, label in zip(lines, labels): + if isinstance(data, tuple): # (mean, std) + mean, std = data + plt.plot(hours, mean, label=label) + plt.fill_between(hours, + np.array(mean)-std, + np.array(mean)+std, + alpha=0.3) + else: + plt.plot(hours, data, label=label) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig(out_path) + plt.close() + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting diurnal cycle computation for 2m temperature and windspeed") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + out_root = Path(cfg.generation.io.output_path or './outputs') + load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + t2m_out = out_ch.get('2t', out_ch.get('t2m')) + t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) + u_out = out_ch.get('10u') + v_out = out_ch.get('10v') + u_in = in_ch.get('10u', u_out) + v_in = in_ch.get('10v', v_out) + logger.info(f"2T channel indices - output: {t2m_out}, input: {t2m_in}") + logger.info(f"10U/10V channel indices - output: {u_out}/{v_out}, input: {u_in}/{v_in}") + + stats_temp = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} + stats_wind = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} + + for idx, ts in enumerate(times, 1): + hr = hour_of(ts) + target = load(ts, f"{ts}-target") + baseline = load(ts, f"{ts}-baseline") + # 2m temperature + stats_temp['target'][hr].append(target[t2m_out].mean()) + stats_temp['baseline'][hr].append(baseline[t2m_in].mean()) + # windspeed using np.hypot + stats_wind['target'][hr].append(np.hypot(target[u_out], target[v_out]).mean()) + stats_wind['baseline'][hr].append(np.hypot(baseline[u_in], baseline[v_in]).mean()) + + preds = load(ts, f"{ts}-predictions") + for member in preds: + stats_temp['prediction'][hr].append(member[t2m_out].mean()) + stats_wind['prediction'][hr].append(np.hypot(member[u_out], member[v_out]).mean()) + + if idx % LOG_INTERVAL == 0 or idx == len(times): + logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") + + # Compute hourly means + mean_cycle_temp = { + mode: [np.mean(stats_temp[mode][h]) for h in sorted(stats_temp[mode])] + for mode in ['target','baseline'] + } + mean_cycle_wind = { + mode: [np.mean(stats_wind[mode][h]) for h in sorted(stats_wind[mode])] + for mode in ['target','baseline'] + } + logger.info("Computed hourly mean statistics") + + # Ensemble cycles (mean ± std) + hrs, pred_mean_temp, pred_std_temp = compute_ensemble(stats_temp['prediction']) + hrs_w, pred_mean_wind, pred_std_wind = compute_ensemble(stats_wind['prediction']) + logger.info("Computed ensemble statistics") + + # Prepare cyclic series + cycle = lambda x: x + [x[0]] + hrs_c = hrs + [24] + hrs_w_c = hrs_w + [24] + temp_lines = [cycle(mean_cycle_temp['target']), cycle(mean_cycle_temp['baseline']), (cycle(pred_mean_temp), cycle(pred_std_temp))] + wind_lines = [cycle(mean_cycle_wind['target']), cycle(mean_cycle_wind['baseline']), (cycle(pred_mean_wind), cycle(pred_std_wind))] + + # Plot + plot_paths = [] + fn1 = out_root/'diurnal_cycle_2t.png' + save_plot(hrs_c, temp_lines, ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], '2m Temperature [K]', + 'Diurnal Cycle of 2m Temperature', fn1) + plot_paths.append(fn1) + + fn2 = out_root/'diurnal_cycle_windspeed.png' + save_plot(hrs_w_c, wind_lines, ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], 'Windspeed [m/s]', + 'Diurnal Cycle of Windspeed', fn2) + plot_paths.append(fn2) + + logger.info(f"Plots saved: {', '.join(str(p) for p in plot_paths)}") + +if __name__ == '__main__': + main() From 02ab40ab8f32ce8f1a5925ae5ca9187fb4b5f742 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Fri, 25 Jul 2025 17:34:39 +0200 Subject: [PATCH 097/189] Add script to plot the 99th all-hour percentile --- src/hirad/diurnal_cycle.sh | 1 + src/hirad/eval/diurnal_cycle_precip.py | 11 +- src/hirad/eval/diurnal_cycle_temp_wind.py | 7 +- src/hirad/eval/percentile99_cycle_precip.py | 153 ++++++++++++++++++++ 4 files changed, 162 insertions(+), 10 deletions(-) create mode 100644 src/hirad/eval/percentile99_cycle_precip.py diff --git a/src/hirad/diurnal_cycle.sh b/src/hirad/diurnal_cycle.sh index 1a62fcf..6643757 100644 --- a/src/hirad/diurnal_cycle.sh +++ b/src/hirad/diurnal_cycle.sh @@ -46,5 +46,6 @@ export OMP_NUM_THREADS=72 srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/eval/diurnal_cycle_precip.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/percentile99_cycle_precip.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 77af57b..8a38078 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -5,7 +5,7 @@ import hydra import numpy as np -torch = __import__('torch') +import torch from omegaconf import DictConfig, OmegaConf import matplotlib.pyplot as plt @@ -34,13 +34,10 @@ def save_plot(hours, lines, labels, ylabel, title, out_path): Path(out_path).parent.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(8,4)) for data, label in zip(lines, labels): - if isinstance(data, tuple): # (mean, std) + if isinstance(data, tuple): mean, std = data - plt.plot(hours, mean, label=label) - plt.fill_between(hours, - np.array(mean)-std, - np.array(mean)+std, - alpha=0.3) + line, = plt.plot(hours, mean, label=label) + plt.fill_between(hours, np.maximum(np.array(mean)-std, 0), np.array(mean)+std, alpha=0.3, color=line.get_color()) else: plt.plot(hours, data, label=label) plt.xlabel('Hour (UTC)') diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 0f769ad..ed35042 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -5,7 +5,7 @@ import hydra import numpy as np -torch = __import__('torch') +import torch from omegaconf import DictConfig, OmegaConf import matplotlib.pyplot as plt @@ -30,11 +30,12 @@ def save_plot(hours, lines, labels, ylabel, title, out_path): for data, label in zip(lines, labels): if isinstance(data, tuple): # (mean, std) mean, std = data - plt.plot(hours, mean, label=label) + line, = plt.plot(hours, mean, label=label) plt.fill_between(hours, np.array(mean)-std, np.array(mean)+std, - alpha=0.3) + alpha=0.3, + color=line.get_color()) else: plt.plot(hours, data, label=label) plt.xlabel('Hour (UTC)') diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py new file mode 100644 index 0000000..f657b21 --- /dev/null +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -0,0 +1,153 @@ +""" +Plots the diurnal cycle of the all-hour 99th percentile of +precipitation, a somewhat reliable measure of the precipitation intensity. + +Each hour, member and type is treaded separately, to conserve memory... but if the +period is long, this can still be a lot of data and thus an OOM error can occur. +""" +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range + +# Constants +CONV_FACTOR = 100 # Convert meters to mm/h +LOG_INTERVAL = 24 # Log progress every N timesteps + + +def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: + return datetime.strptime(dt, fmt).hour + + +def save_plot(hours, lines, labels, ylabel, title, out_path): + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.figure(figsize=(8,4)) + for data, label in zip(lines, labels): + if isinstance(data, tuple): # (mean, std) + mean, std = data + lower = np.maximum(np.array(mean) - std, 0) + upper = np.array(mean) + std + line, = plt.plot(hours, mean, label=label) + plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color()) + else: + plt.plot(hours, data, label=label) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig(out_path) + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting 99th-percentile diurnal cycle computation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root and loader + out_root = Path(cfg.generation.io.output_path or './outputs') + load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + tp_out = out_ch['tp']; tp_in = in_ch.get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Storage for diurnal cycles + pct99_mean = {'target': [], 'baseline': [], 'prediction': []} + pct99_std = {'target': [], 'baseline': [], 'prediction': []} + + # -- Target and Baseline: compute per hour -- + for mode in ['target', 'baseline']: + logger.info(f"Processing mode: {mode}") + for h in list(range(24)): + arrs = [ + load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] + for ts in times if hour_of(ts) == h + ] + stack = np.stack(arrs, axis=0) + f99 = np.percentile(stack, 99, axis=0) + pct99_mean[mode].append(f99.mean()) + pct99_std[mode].append(np.std(f99, axis=None)) + del arrs, stack, f99 + + # -- Predictions: compute per hour per member, then mean+std across members -- + # Determine number of ensemble members + sample = load(times[0], f"{times[0]}-predictions") # [n_members, n_channels, lat, lon] + data_sample = sample[:, tp_out] + n_members = data_sample.shape[0] + + for h in list(range(24)): + logger.info(f"Processing predictions for hour {h}") + mem_f99 = [] + # for each ensemble member, gather its hourly fields + for m in range(n_members): + arrs = [] + for ts in times: + if hour_of(ts) != h: + continue + preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, ...] + arrs.append(preds[m, tp_out]) # one field + # stack over time and compute 99th percentile at each grid point + stack_m = np.stack(arrs, axis=0) + f99_m = np.percentile(stack_m, 99, axis=0) + mem_f99.append(f99_m.mean()) + # ensemble-level mean and std over member-wise percentiles + pct99_mean['prediction'].append(np.mean(mem_f99)) + pct99_std['prediction'].append(np.std(mem_f99, axis=None)) + # clean up per-hour buffers + del mem_f99, stack_m, f99_m + + # Prepare cyclic series + cycle_fn = lambda x: x + [x[0]] + hrs_c = list(range(24)) + [list(range(24))[0] + 24] + pct99_lines = [ + cycle_fn(pct99_mean['target']), + cycle_fn(pct99_mean['baseline']), + ( + cycle_fn(pct99_mean['prediction']), + cycle_fn(pct99_std['prediction']) + ) + ] + + # Plot combined diurnal 99th-percentile cycle + fn = out_root/'diurnal_cycle_precip_99th_percentile.png' + save_plot( + hrs_c, + pct99_lines, + ['COSMO-2','ERA5','CorrDiff 99th Pct ± Std'], + 'Rain Rate (mm/h)', + 'Diurnal Cycle of 99th-Percentile Precipitation', + fn + ) + logger.info(f"Combined plot saved: {fn}") + +if __name__ == '__main__': + main() From c8a3f778d7a69b6a79da837a084c5bf70a6a5c5a Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Jul 2025 17:39:37 +0200 Subject: [PATCH 098/189] add visualization logging option for mlflow --- src/hirad/conf/logging/era_cosmo_diffusion.yaml | 3 ++- src/hirad/conf/logging/era_cosmo_regression.yaml | 3 ++- src/hirad/training/train.py | 9 ++++++++- src/hirad/utils/train_helpers.py | 11 ++++++----- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/hirad/conf/logging/era_cosmo_diffusion.yaml b/src/hirad/conf/logging/era_cosmo_diffusion.yaml index df7f481..c57560c 100644 --- a/src/hirad/conf/logging/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/logging/era_cosmo_diffusion.yaml @@ -2,4 +2,5 @@ method: null experiment_name: hirad-corrdiff-diffusion run_name: era-cosmo-1h -uri: null \ No newline at end of file +uri: null +log_images: false \ No newline at end of file diff --git a/src/hirad/conf/logging/era_cosmo_regression.yaml b/src/hirad/conf/logging/era_cosmo_regression.yaml index ba4095b..ee426e5 100644 --- a/src/hirad/conf/logging/era_cosmo_regression.yaml +++ b/src/hirad/conf/logging/era_cosmo_regression.yaml @@ -2,4 +2,5 @@ method: null experiment_name: hirad-corrdiff-regression run_name: era-cosmo-1h -uri: null \ No newline at end of file +uri: null +log_images: false \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 12131fc..a6aff47 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -796,6 +796,7 @@ def main(cfg: DictConfig) -> None: times = visualization_dataset.time() time_index = -1 + output_paths_list = [] for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( iter(visualization_data_loader) ): @@ -837,6 +838,7 @@ def main(cfg: DictConfig) -> None: if dist.rank == 0: # write out data in a seperate thread so we don't hold up inferencing output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") + output_paths_list.append(output_path) if dist.rank==0 and not os.path.exists(output_path): os.makedirs(output_path) writer_threads.append( @@ -857,7 +859,12 @@ def main(cfg: DictConfig) -> None: thread.result() writer_threads.remove(thread) writer_executor.shutdown() - + if cfg.logging.method == "mlflow" and cfg.logging.log_images: + for output_path in output_paths_list: + mlflow.log_artifacts(output_path, + os.path.join( + 'visualization', + os.path.split(output_path)[-1])) if dist.world_size > 1: diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index 65b4b3c..9b53fc3 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -118,6 +118,7 @@ def is_time_for_periodic_task( def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: if dist.rank==0: + mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) run_id = None if os.path.isfile('run_id.txt'): with open('run_id.txt','r') as f: @@ -134,11 +135,11 @@ def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: f.write(run.info.run_id) # log environment info if run is not continuing from previous checkpoint mlflow.log_params(flatten_dict(OmegaConf.to_object(cfg))) - mlflow.log_dict(cfg, "config.json") - python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) - mlflow.log_dict(python_environment, "environment.json") - if git_diff: - mlflow.log_text(git_diff, "git_diff.txt") + python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__']) + mlflow.log_dict(python_environment, "environment.json") + if git_diff: + mlflow.log_text(git_diff, "git_diff.txt") + mlflow.log_dict(cfg, "config.json") if dist.world_size > 4: torch.distributed.barrier() From 3d1347aca1f177bd467e9f62ba7999597367b697 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Jul 2025 12:41:21 +0200 Subject: [PATCH 099/189] enable mlflow logging to remote server --- .gitignore | 2 ++ src/hirad/conf/logging/era_cosmo_diffusion.yaml | 4 +++- src/hirad/conf/logging/era_cosmo_regression.yaml | 4 +++- src/hirad/utils/train_helpers.py | 6 ++++++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 17c4ea5..7417316 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,5 @@ temp.zarr.sync* src/hirad/eval/__pycache__/* interpolate_basic.log interpolated.torch +mlruns/ +.secrets.env diff --git a/src/hirad/conf/logging/era_cosmo_diffusion.yaml b/src/hirad/conf/logging/era_cosmo_diffusion.yaml index c57560c..86ec7fe 100644 --- a/src/hirad/conf/logging/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/logging/era_cosmo_diffusion.yaml @@ -1,6 +1,8 @@ # set method to mlflow to log with mlflow -method: null +method: mlflow experiment_name: hirad-corrdiff-diffusion run_name: era-cosmo-1h +# change uri to remote mlflow server; if null, it is stored locally +# if uri is remote make sure to have credentials set in ~/.mlflow/credentials uri: null log_images: false \ No newline at end of file diff --git a/src/hirad/conf/logging/era_cosmo_regression.yaml b/src/hirad/conf/logging/era_cosmo_regression.yaml index ee426e5..e7a6287 100644 --- a/src/hirad/conf/logging/era_cosmo_regression.yaml +++ b/src/hirad/conf/logging/era_cosmo_regression.yaml @@ -1,6 +1,8 @@ # set method to mlflow to log with mlflow -method: null +method: mlflow experiment_name: hirad-corrdiff-regression run_name: era-cosmo-1h +# change uri to remote mlflow server; if null, it is stored locally +# if uri is remote make sure to have credentials set in ~/.mlflow/credentials uri: null log_images: false \ No newline at end of file diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index 9b53fc3..f379aa4 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -118,6 +118,9 @@ def is_time_for_periodic_task( def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: if dist.rank==0: + print("Started activating initial mlflow run") + if cfg.logging.uri is not None: + mlflow.set_tracking_uri(cfg.logging.uri) mlflow.set_experiment(experiment_name=cfg.logging.experiment_name) run_id = None if os.path.isfile('run_id.txt'): @@ -145,6 +148,9 @@ def init_mlflow(cfg: DictConfig, dist: DistributedManager) -> None: torch.distributed.barrier() if (dist.rank!=0 and dist._local_rank==0) or (dist.rank==1 and dist.world_size>4): + print("Started actvating sub mlflow run.") + if cfg.logging.uri is not None: + mlflow.set_tracking_uri(cfg.logging.uri) mlflow.system_metrics.set_system_metrics_node_id(f"node-{(dist.rank//4)}" if dist.rank!=1 else "node-0") From 678fbe97924651033ec46c5afd5ace35b025554a Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Jul 2025 15:08:15 +0200 Subject: [PATCH 100/189] fix logging bug for average loss --- src/hirad/training/train.py | 98 +++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 54 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index a6aff47..0573929 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -577,32 +577,11 @@ def main(cfg: DictConfig) -> None: ) average_loss = (loss_sum / dist.world_size).cpu().item() - # update running mean of average loss since last periodic task - average_loss_running_mean += ( - average_loss - average_loss_running_mean - ) / n_average_loss_running_mean - n_average_loss_running_mean += 1 - - if dist.rank == 0 and cfg.logging.method == "mlflow": - mlflow.log_metric("training_loss", average_loss, cur_nimg) - mlflow.log_metric( - "training_loss_running_mean", - average_loss_running_mean, - cur_nimg, - ) - - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 # Update weights. with nvtx.annotate("update weights", color="blue"): @@ -625,38 +604,49 @@ def main(cfg: DictConfig) -> None: cur_nimg += cfg.training.hp.total_batch_size done = cur_nimg >= cfg.training.hp.training_duration - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - if torch.cuda.is_available(): + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" ] fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" ] - torch.cuda.reset_peak_memory_stats() - logger0.info(" ".join(fields)) + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + if cfg.logging.method == "mlflow": + mlflow.log_metric("training_loss", average_loss, cur_nimg) + mlflow.log_metric( + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, + ) + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 with nvtx.annotate("validation", color="red"): # Validation From e7d6c6bc900bad33ca53df02b7e42db659864932 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Jul 2025 16:17:31 +0200 Subject: [PATCH 101/189] update readme --- README.md | 108 ++++++++++++++++++++++++------------------------------ 1 file changed, 47 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index b0dbd2e..3e6e5c5 100644 --- a/README.md +++ b/README.md @@ -2,37 +2,20 @@ HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model. -## Installation (Alps) +[Setup clariden/santis](#setup-claridensantis) +[Regression training - clariden/santis](#run-regression-model-training-alps) +[Diffusion training - clariden/santis](#run-diffusion-model-training-alps) +[Inference - clariden/santis](#running-inference-on-alps) +[Installation - uenv/venv - deprecated](#installation-alps-uenvvenv---deprecated) -To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps: - -1. **Start the PyTorch user environment**: - ```bash - uenv start pytorch/v2.6.0:v1 --view=default - ``` - -2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name): - ```bash - python -m venv ./{env_name} - ``` - -3. **Activate the virtual environment**: - ```bash - source ./{env_name}/bin/activate - ``` - -4. **Install project dependencies**: - ```bash - pip install -e . - ``` - -This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure. +## Setup clariden/santis +Container environment setup needed to run training and inference experiments on clariden/santis is contained in this repository under `ci/edf/modulus_env.toml`. Image squash is on clariden/alps under `/capstor/scratch/cscs/pstamenk/hirad.sqsh`. All the jobs can be run using this environment without additional installations and setup. ## Training ### Run regression model training (Alps) -1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. +1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. Here, you can change the sbatch settings. Inside this script set the following: ```bash ### OUTPUT ### @@ -42,12 +25,6 @@ Inside this script set the following: ```bash #SBATCH -A your_compute_group ``` -```bash -srun bash -c " - . ./{your_env_name}/bin/activate - python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml -" -``` 2. Set up the following config files in `src/hirad/conf`: @@ -55,14 +32,9 @@ srun bash -c " ``` hydra: run: - dir: your_path_to_save_training_output -``` -- In `training/era_cosmo_regression.yaml` set: -``` -hp: - training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4) + dir: your_path_to_save_training_outputs ``` -- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. +- All other parameters for training regression can be changed in the main config file `training_era_cosmo_regression.yaml` and config files the main config is referencing (default values are working for debugging purposes). 3. Submit the job with: ```bash @@ -72,8 +44,7 @@ sbatch src/hirad/train_regression.sh ### Run diffusion model training (Alps) Before training diffusion model, checkpoint for regression model has to exist. -1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. -Inside this script set the following: +1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. Here, you can change the sbatch settings. Inside this script set the following: ```bash ### OUTPUT ### #SBATCH --output=your_path_to_output_log @@ -82,12 +53,6 @@ Inside this script set the following: ```bash #SBATCH -A your_compute_group ``` -```bash -srun bash -c " - . ./{your_env_name}/bin/activate - python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml -" -``` 2. Set up the following config files in `src/hirad/conf`: @@ -97,14 +62,12 @@ hydra: run: dir: your_path_to_save_training_output ``` -- In `training/era_cosmo_regression.yaml` set: +- In `training/era_cosmo_diffusion.yaml` set: ``` -hp: - training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4) io: regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints ``` -- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. +- All other parameters for training regression can be changed in the main config file `training_era_cosmo_diffusion.yaml` and config files the main config is referencing (default values are working for debugging purposes). 3. Submit the job with: ```bash @@ -125,12 +88,6 @@ Inside this script set the following: ```bash #SBATCH -A your_compute_group ``` -```bash -srun bash -c " - . ./{your_env_name}/bin/activate - python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml -" -``` 2. Set up the following config files in `src/hirad/conf`: @@ -155,13 +112,42 @@ Finally, from the dataset, subset of time steps can be chosen to do inference fo One way is to list steps under `times:` in format `%Y%m%d-%H%M` for era5_cosmo dataset. -The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset (6 for era_cosmo). +The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset. -By default, inference is done for one time step `20160101-0000` - -- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. +- In `dataset/era_cosmo_inference.yaml` set the `dataset_path` if different from default. Make sure that specified times or times_range is contained in dataset_path. 3. Submit the job with: ```bash sbatch src/hirad/generate.sh -``` \ No newline at end of file +``` + +## MLflow logging + +During training MLflow can be used to log metrics. +Logging config files for regression and diffusion are located in `src/hirad/conf/logging/`. Set `method` to `mlflow` and specify `uri` if you want to log on remote server, otherwise run will be logged locally in output directory. Other options can also be modified here. + +## Installation (Alps uenv/venv) - deprecated + +To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps: + +1. **Start the PyTorch user environment**: + ```bash + uenv start pytorch/v2.6.0:v1 --view=default + ``` + +2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name): + ```bash + python -m venv ./{env_name} + ``` + +3. **Activate the virtual environment**: + ```bash + source ./{env_name}/bin/activate + ``` + +4. **Install project dependencies**: + ```bash + pip install -e . + ``` + +This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure. \ No newline at end of file From 6a63bcd5d85ceb07cbbf30cbd3e0dc9cef83238e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Jul 2025 16:18:46 +0200 Subject: [PATCH 102/189] fix indents in readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3e6e5c5..0244ba8 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@ HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model. -[Setup clariden/santis](#setup-claridensantis) -[Regression training - clariden/santis](#run-regression-model-training-alps) -[Diffusion training - clariden/santis](#run-diffusion-model-training-alps) -[Inference - clariden/santis](#running-inference-on-alps) +[Setup clariden/santis](#setup-claridensantis) +[Regression training - clariden/santis](#run-regression-model-training-alps) +[Diffusion training - clariden/santis](#run-diffusion-model-training-alps) +[Inference - clariden/santis](#running-inference-on-alps) [Installation - uenv/venv - deprecated](#installation-alps-uenvvenv---deprecated) ## Setup clariden/santis From f9b8c87780b4ffc721abe98deec6af90ba14f5ed Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 28 Jul 2025 16:19:50 +0200 Subject: [PATCH 103/189] simplify cscs.yml to try to get ci/cd to run --- ci/cscs.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index 1668b81..3e50b37 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -19,8 +19,11 @@ test_job: extends: .container-runner-clariden-gh200 image: $PERSIST_IMAGE_NAME script: - - pip install -e . --no-dependencies - - python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml + - echo 'hello world' + # - pip install -e . --no-dependencies + # - python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml variables: SLURM_JOB_NUM_NODES: 2 SLURM_NTASKS: 2 + SLURM_ACCOUNT: a161 + SBATCH_ACCOUNT: a161 From 79e31126963a448d5fdebcb4abd219d292749ded Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 29 Jul 2025 18:37:42 +0200 Subject: [PATCH 104/189] logs --- src/hirad/eval/diurnal_cycle_precip.py | 2 +- src/hirad/eval/diurnal_cycle_temp_wind.py | 2 +- src/hirad/eval/percentile99_cycle_precip.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 8a38078..f6e944a 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -59,7 +59,7 @@ def main(cfg: DictConfig): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting diurnal cycle computation") + logger.info("Starting computations for diurnal cycle of precipitation amount and wet-hours") times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") logger.info(f"Loaded {len(times)} timesteps to process") diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index ed35042..e5e5c0f 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -55,7 +55,7 @@ def main(cfg: DictConfig): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting diurnal cycle computation for 2m temperature and windspeed") + logger.info("Starting computation for diurnal cyles of 2m temperature and windspeed") times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") logger.info(f"Loaded {len(times)} timesteps to process") diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index f657b21..823a2d4 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -59,7 +59,7 @@ def main(cfg: DictConfig): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting 99th-percentile diurnal cycle computation") + logger.info("Starting computation for diurnal cycle of 99th-percentile of precipitation") times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") logger.info(f"Loaded {len(times)} timesteps to process") From 7a117853b6a2573bb7da4677da521c9206644b2c Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 30 Jul 2025 13:00:48 +0200 Subject: [PATCH 105/189] mask out sea points --- src/hirad/eval/diurnal_cycle_precip.py | 17 ++++++++++++++--- src/hirad/eval/diurnal_cycle_temp_wind.py | 9 +++++++++ src/hirad/eval/percentile99_cycle_precip.py | 9 +++++++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index f6e944a..31a22ed 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -18,7 +18,6 @@ WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h LOG_INTERVAL = 24 # Log progress every N timesteps - def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: return datetime.strptime(dt, fmt).hour @@ -81,18 +80,30 @@ def main(cfg: DictConfig): # Prepare data structures stats = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} wet_stats = {mode: defaultdict(list) for mode in stats} + + # Load land mask + lsm_dat = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy') + lsm=np.flip(lsm_dat.reshape(352,544),0) + # Collect data for idx, ts in enumerate(times, 1): hr = hour_of(ts) target = load(ts, f"{ts}-target")[tp_out] baseline = load(ts, f"{ts}-baseline")[tp_in] + + # Mask target, baseline, and preds where lsm < 0.5 + land_mask = lsm >= 0.5 + target = target * land_mask + baseline = baseline * land_mask + stats['target'][hr].append(target) stats['baseline'][hr].append(baseline) wet_stats['target'][hr].append((target > WET_THRESHOLD).mean()) wet_stats['baseline'][hr].append((baseline > WET_THRESHOLD).mean()) preds = load(ts, f"{ts}-predictions")[:, tp_out] + preds = preds * land_mask for member in preds: stats['prediction'][hr].append(member.mean()) wet_stats['prediction'][hr].append((member > WET_THRESHOLD).mean()) @@ -126,8 +137,8 @@ def main(cfg: DictConfig): wet_lines = [cycle(wet_cycle['target']), cycle(wet_cycle['baseline']), (cycle(wet_mean), cycle(wet_std))] # Log the lines to be plotted (debug) - logger.info(f"amount_lines: {amount_lines}") - logger.info(f"wet_lines: {wet_lines}") + # logger.info(f"amount_lines: {amount_lines}") + # logger.info(f"wet_lines: {wet_lines}") # Plot plot_paths = [] diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index e5e5c0f..5a333b1 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -83,10 +83,18 @@ def main(cfg: DictConfig): stats_temp = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} stats_wind = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} + # Load land-sea mask + lsm_dat = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy') + lsm = np.flip(lsm_dat.reshape(352,544), 0) + land_mask = lsm >= 0.5 + for idx, ts in enumerate(times, 1): hr = hour_of(ts) target = load(ts, f"{ts}-target") baseline = load(ts, f"{ts}-baseline") + # Apply land mask + target = target * land_mask + baseline = baseline * land_mask # 2m temperature stats_temp['target'][hr].append(target[t2m_out].mean()) stats_temp['baseline'][hr].append(baseline[t2m_in].mean()) @@ -95,6 +103,7 @@ def main(cfg: DictConfig): stats_wind['baseline'][hr].append(np.hypot(baseline[u_in], baseline[v_in]).mean()) preds = load(ts, f"{ts}-predictions") + preds = preds * land_mask for member in preds: stats_temp['prediction'][hr].append(member[t2m_out].mean()) stats_wind['prediction'][hr].append(np.hypot(member[u_out], member[v_out]).mean()) diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 823a2d4..5013b91 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -80,6 +80,11 @@ def main(cfg: DictConfig): tp_out = out_ch['tp']; tp_in = in_ch.get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + # Load land-sea mask + lsm_dat = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy') + lsm = np.flip(lsm_dat.reshape(352,544), 0) + land_mask = lsm >= 0.5 + # Storage for diurnal cycles pct99_mean = {'target': [], 'baseline': [], 'prediction': []} pct99_std = {'target': [], 'baseline': [], 'prediction': []} @@ -89,7 +94,7 @@ def main(cfg: DictConfig): logger.info(f"Processing mode: {mode}") for h in list(range(24)): arrs = [ - load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] + load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask for ts in times if hour_of(ts) == h ] stack = np.stack(arrs, axis=0) @@ -114,7 +119,7 @@ def main(cfg: DictConfig): if hour_of(ts) != h: continue preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, ...] - arrs.append(preds[m, tp_out]) # one field + arrs.append(preds[m, tp_out] * land_mask) # apply mask # stack over time and compute 99th percentile at each grid point stack_m = np.stack(arrs, axis=0) f99_m = np.percentile(stack_m, 99, axis=0) From e23452856954c4549fadfab03e2dfad137994819 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 30 Jul 2025 18:48:58 +0200 Subject: [PATCH 106/189] Convert to xarray, still buggy --- src/hirad/eval/diurnal_cycle_precip.py | 203 ++++++++--------- src/hirad/eval/diurnal_cycle_temp_wind.py | 234 +++++++++++--------- src/hirad/eval/percentile99_cycle_precip.py | 22 +- 3 files changed, 237 insertions(+), 222 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 31a22ed..989ddc5 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -1,56 +1,23 @@ import logging from datetime import datetime from pathlib import Path -from collections import defaultdict import hydra +import matplotlib.pyplot as plt import numpy as np import torch +import xarray as xr from omegaconf import DictConfig, OmegaConf -import matplotlib.pyplot as plt from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range # Constants -CONV_FACTOR = 100 # Convert meters to mm/h +CONV_FACTOR = 100*24 # Convert meters to mm/day WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h LOG_INTERVAL = 24 # Log progress every N timesteps -def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: - return datetime.strptime(dt, fmt).hour - - -def compute_ensemble(hourly_values): - hours = sorted(hourly_values) - means = [np.mean(hourly_values[h]) for h in hours] - stds = [np.std(hourly_values[h]) for h in hours] - return hours, means, stds - - -def save_plot(hours, lines, labels, ylabel, title, out_path): - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - plt.figure(figsize=(8,4)) - for data, label in zip(lines, labels): - if isinstance(data, tuple): - mean, std = data - line, = plt.plot(hours, mean, label=label) - plt.fill_between(hours, np.maximum(np.array(mean)-std, 0), np.array(mean)+std, alpha=0.3, color=line.get_color()) - else: - plt.plot(hours, data, label=label) - plt.xlabel('Hour (UTC)') - plt.xticks(range(0,25,3)) - plt.xlim(0,24) - plt.ylabel(ylabel) - plt.title(title) - plt.grid(True) - plt.legend() - plt.tight_layout() - plt.savefig(out_path) - plt.close() - - @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): # Setup logging @@ -60,6 +27,7 @@ def main(cfg: DictConfig): logger.info("Starting computations for diurnal cycle of precipitation amount and wet-hours") times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] logger.info(f"Loaded {len(times)} timesteps to process") ds_cfg = OmegaConf.to_container(cfg.dataset) @@ -69,90 +37,113 @@ def main(cfg: DictConfig): logger.info("Dataset and sampler initialized") out_root = Path(cfg.generation.io.output_path or './outputs') - load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp']; tp_in = in_ch.get('tp', tp_out) + tp_out = out_ch['tp'] + tp_in = in_ch.get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - # Prepare data structures - stats = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} - wet_stats = {mode: defaultdict(list) for mode in stats} - - # Load land mask - lsm_dat = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy') - lsm=np.flip(lsm_dat.reshape(352,544),0) + # Land-sea mask + lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) + land_mask = np.where(lsm_data >= 0.5, 1.0, np.nan) + coords = {"lat": np.arange(land_mask.shape[0]), "lon": np.arange(land_mask.shape[1])} + # Prepare lists to collect DataArrays + target_precip, baseline_precip, pred_precip = [], [], [] + target_wet, baseline_wet, pred_wet = [], [], [] # Collect data for idx, ts in enumerate(times, 1): - hr = hour_of(ts) - target = load(ts, f"{ts}-target")[tp_out] - baseline = load(ts, f"{ts}-baseline")[tp_in] - - # Mask target, baseline, and preds where lsm < 0.5 - land_mask = lsm >= 0.5 - target = target * land_mask - baseline = baseline * land_mask - - stats['target'][hr].append(target) - stats['baseline'][hr].append(baseline) - wet_stats['target'][hr].append((target > WET_THRESHOLD).mean()) - wet_stats['baseline'][hr].append((baseline > WET_THRESHOLD).mean()) - - preds = load(ts, f"{ts}-predictions")[:, tp_out] - preds = preds * land_mask - for member in preds: - stats['prediction'][hr].append(member.mean()) - wet_stats['prediction'][hr].append((member > WET_THRESHOLD).mean()) + dt = datetimes[idx-1] + target = load(ts, f"{ts}-target")[tp_out] * land_mask + baseline = load(ts, f"{ts}-baseline")[tp_in] * land_mask / 6 # 6 becasue 1h -> 6h bug in dataset? + preds = load(ts, f"{ts}-predictions")[:, tp_out, :, :] * land_mask + + # DataArrays for spatial mean + da_target = xr.DataArray(target, dims=("lat","lon"), coords=coords) + da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=coords) + da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **coords}) + + # Spatial mean + target_precip.append(da_target.mean(dim=("lat","lon")).assign_coords(time=dt)) + baseline_precip.append(da_baseline.mean(dim=("lat","lon")).assign_coords(time=dt)) + pred_precip.append(da_preds.mean(dim=("lat","lon")).assign_coords(time=dt)) + + # Wet-hour fraction (percentage) + target_wet.append(((da_target / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) + baseline_wet.append(((da_baseline / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) + pred_wet.append(((da_preds / 24> WET_THRESHOLD).mean(dim=("lat","lon")).assign_coords(time=dt))) if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") - # Compute hourly means - mean_cycle = { - mode: [np.mean(stats[mode][h]) for h in sorted(stats[mode])] - for mode in ['target','baseline'] - } - wet_cycle = { - mode: [np.mean(wet_stats[mode][h]) * 100. for h in sorted(wet_stats[mode])] - for mode in ['target','baseline'] - } - logger.info("Computed hourly mean and wet-cycle statistics") - - # Ensemble cycles (mean ± std) - hrs, pred_mean, pred_std = compute_ensemble(stats['prediction']) - _, wet_mean, wet_std = compute_ensemble(wet_stats['prediction']) - # Multiply ensemble wet-hour statistics by 100 for percentage - wet_mean = [v * 100. for v in wet_mean] - wet_std = [v * 100. for v in wet_std] - logger.info("Computed ensemble statistics") - - # Prepare cyclic series - cycle = lambda x: x + [x[0]] - hrs_c = hrs + [24] - amount_lines = [cycle(mean_cycle['target']), cycle(mean_cycle['baseline']), (cycle(pred_mean), cycle(pred_std))] - wet_lines = [cycle(wet_cycle['target']), cycle(wet_cycle['baseline']), (cycle(wet_mean), cycle(wet_std))] - - # Log the lines to be plotted (debug) - # logger.info(f"amount_lines: {amount_lines}") - # logger.info(f"wet_lines: {wet_lines}") - - # Plot - plot_paths = [] - fn1 = out_root/'diurnal_cycle_precip_amount.png' - save_plot(hrs_c, amount_lines, ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], 'Rain Rate (mm/h)', - 'Diurnal Cycle of Precip Amount', fn1) - plot_paths.append(fn1) - - fn2 = out_root/'diurnal_cycle_precip_wethours.png' - save_plot(hrs_c, wet_lines, ['COSMO-2','ERA5','Pred Mean ± Std'], 'Wet-Hour Fraction [%]', - 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', fn2) - plot_paths.append(fn2) - - logger.info(f"Plots saved: {', '.join(str(p) for p in plot_paths)}") + # Helper to concat and compute diurnal stats + def concat_and_group(list_of_da, is_member=False, scale=1.0): + da = xr.concat(list_of_da, dim="time").groupby("time.hour") + if is_member: + mean = da.mean(dim=[d for d in da.dims if d in ['time', 'member']]) * scale + std = da.std(dim=[d for d in da.dims if d in ['time', 'member']]) * scale + else: + mean = da.mean(dim="time") * scale + std = None + return mean, std + + # Compute diurnal means and stds + amount_target_mean, _ = concat_and_group(target_precip) + amount_baseline_mean, _ = concat_and_group(baseline_precip) + amount_pred_mean, amount_pred_std = concat_and_group(pred_precip, is_member=True) + + wet_target_mean, _ = concat_and_group(target_wet, scale=100.0) + wet_baseline_mean, _ = concat_and_group(baseline_wet, scale=100.0) + wet_pred_mean, wet_pred_std = concat_and_group(pred_wet, is_member=True, scale=100.0) + + # Plot helper + def save_plot(hour, means, stds, labels, ylabel, title, out_path): + hrs = np.concatenate([hour.values, [24]]) + plt.figure(figsize=(8,4)) + for mean, std, label in zip(means, stds, labels): + vals = np.append(mean.values, mean.values[0]) + line, = plt.plot(hrs, vals, label=label) + if std is not None: + stdv = np.append(std.values, std.values[0]) + plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + + # Generate plots + save_plot( + amount_target_mean.hour, + [amount_target_mean, amount_baseline_mean, amount_pred_mean], + [None, None, amount_pred_std], + ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], + 'Precipitation (mm/day)', + 'Diurnal Cycle of Precip Amount', + out_root / 'diurnal_cycle_precip_amount.png' + ) + save_plot( + wet_target_mean.hour, + [wet_target_mean, wet_baseline_mean, wet_pred_mean], + [None, None, wet_pred_std], + ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], + 'Wet-Hour Fraction [%]', + 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', + out_root / 'diurnal_cycle_precip_wethours.png' + ) + + logger.info("Plots saved.") if __name__ == '__main__': main() diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 5a333b1..d03e62c 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -1,13 +1,13 @@ import logging from datetime import datetime from pathlib import Path -from collections import defaultdict import hydra +import matplotlib.pyplot as plt import numpy as np import torch +import xarray as xr from omegaconf import DictConfig, OmegaConf -import matplotlib.pyplot as plt from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager @@ -15,60 +15,26 @@ LOG_INTERVAL = 24 -def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: - return datetime.strptime(dt, fmt).hour - -def compute_ensemble(hourly_values): - hours = sorted(hourly_values) - means = [np.mean(hourly_values[h]) for h in hours] - stds = [np.std(hourly_values[h]) for h in hours] - return hours, means, stds - -def save_plot(hours, lines, labels, ylabel, title, out_path): - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - plt.figure(figsize=(8,4)) - for data, label in zip(lines, labels): - if isinstance(data, tuple): # (mean, std) - mean, std = data - line, = plt.plot(hours, mean, label=label) - plt.fill_between(hours, - np.array(mean)-std, - np.array(mean)+std, - alpha=0.3, - color=line.get_color()) - else: - plt.plot(hours, data, label=label) - plt.xlabel('Hour (UTC)') - plt.xticks(range(0,25,3)) - plt.xlim(0,24) - plt.ylabel(ylabel) - plt.title(title) - plt.grid(True) - plt.legend() - plt.tight_layout() - plt.savefig(out_path) - plt.close() - @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): + # Initialize DistributedManager.initialize() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - logger.info("Starting computation for diurnal cyles of 2m temperature and windspeed") + # Load times + logger.info("Starting computation for diurnal cycles of 2m temperature and windspeed") times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] logger.info(f"Loaded {len(times)} timesteps to process") + # Dataset ds_cfg = OmegaConf.to_container(cfg.dataset) dataset, _ = get_dataset_and_sampler_inference( ds_cfg, times, cfg.generation.get('has_lead_time', False) ) - logger.info("Dataset and sampler initialized") - out_root = Path(cfg.generation.io.output_path or './outputs') - load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) - - # Find channel indices + # Indices for channels out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} t2m_out = out_ch.get('2t', out_ch.get('t2m')) @@ -77,76 +43,136 @@ def main(cfg: DictConfig): v_out = out_ch.get('10v') u_in = in_ch.get('10u', u_out) v_in = in_ch.get('10v', v_out) - logger.info(f"2T channel indices - output: {t2m_out}, input: {t2m_in}") - logger.info(f"10U/10V channel indices - output: {u_out}/{v_out}, input: {u_in}/{v_in}") - stats_temp = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} - stats_wind = {mode: defaultdict(list) for mode in ['target','baseline','prediction']} + # Output path + out_root = Path(cfg.generation.io.output_path or './outputs') + load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) + + # Land-sea mask + lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) + land_mask = np.where(lsm_data >= 0.5, 1.0, np.nan) + coords = {"lat": np.arange(land_mask.shape[0]), "lon": np.arange(land_mask.shape[1])} - # Load land-sea mask - lsm_dat = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy') - lsm = np.flip(lsm_dat.reshape(352,544), 0) - land_mask = lsm >= 0.5 + # Prepare lists to collect DataArrays + target_temp, baseline_temp, pred_temp = [], [], [] + target_wind, baseline_wind, pred_wind = [], [], [] + # Loop over timestamps for idx, ts in enumerate(times, 1): - hr = hour_of(ts) - target = load(ts, f"{ts}-target") - baseline = load(ts, f"{ts}-baseline") - # Apply land mask - target = target * land_mask - baseline = baseline * land_mask - # 2m temperature - stats_temp['target'][hr].append(target[t2m_out].mean()) - stats_temp['baseline'][hr].append(baseline[t2m_in].mean()) - # windspeed using np.hypot - stats_wind['target'][hr].append(np.hypot(target[u_out], target[v_out]).mean()) - stats_wind['baseline'][hr].append(np.hypot(baseline[u_in], baseline[v_in]).mean()) - - preds = load(ts, f"{ts}-predictions") - preds = preds * land_mask - for member in preds: - stats_temp['prediction'][hr].append(member[t2m_out].mean()) - stats_wind['prediction'][hr].append(np.hypot(member[u_out], member[v_out]).mean()) + dt = datetimes[idx-1] + + # Load and apply land mask + target = load(ts, f"{ts}-target") * land_mask + baseline = load(ts, f"{ts}-baseline") * land_mask + predictions = load(ts, f"{ts}-predictions") * land_mask + + # Wrap into DataArrays (convert temperature to Celsius inline) + da_tgt_temp = xr.DataArray( + target[t2m_out] - 273.15, dims=("lat","lon"), coords=coords + ) + da_bsl_temp = xr.DataArray( + baseline[t2m_in] - 273.15, dims=("lat","lon"), coords=coords + ) + tgt_wind = np.hypot(target[u_out], target[v_out]) + bsl_wind = np.hypot(baseline[u_in], baseline[v_in]) + da_tgt_wind = xr.DataArray(tgt_wind, dims=("lat","lon"), coords=coords) + da_bsl_wind = xr.DataArray(bsl_wind, dims=("lat","lon"), coords=coords) + + da_pred_members_temp = xr.DataArray( + predictions[:, t2m_out, :, :] - 273.15, dims=("member","lat","lon"), + coords={"member": np.arange(predictions.shape[0]), **coords} + ) + da_pred_members_wind = xr.DataArray( + np.hypot(predictions[:, u_out, :, :], predictions[:, v_out, :, :]), + dims=("member","lat","lon"), coords={"member": np.arange(predictions.shape[0]), **coords} + ) + + # Compute spatial mean and assign time coordinate + target_temp.append( + da_tgt_temp.mean(dim=("lat","lon")).assign_coords(time=dt) + ) + baseline_temp.append( + da_bsl_temp.mean(dim=("lat","lon")).assign_coords(time=dt) + ) + pred_temp.append( + da_pred_members_temp.mean(dim=("lat","lon")).assign_coords(time=dt) + ) + target_wind.append( + da_tgt_wind.mean(dim=("lat","lon")).assign_coords(time=dt) + ) + baseline_wind.append( + da_bsl_wind.mean(dim=("lat","lon")).assign_coords(time=dt) + ) + pred_wind.append( + da_pred_members_wind.mean(dim=("lat","lon")).assign_coords(time=dt) + ) if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") - # Compute hourly means - mean_cycle_temp = { - mode: [np.mean(stats_temp[mode][h]) for h in sorted(stats_temp[mode])] - for mode in ['target','baseline'] - } - mean_cycle_wind = { - mode: [np.mean(stats_wind[mode][h]) for h in sorted(stats_wind[mode])] - for mode in ['target','baseline'] - } - logger.info("Computed hourly mean statistics") - - # Ensemble cycles (mean ± std) - hrs, pred_mean_temp, pred_std_temp = compute_ensemble(stats_temp['prediction']) - hrs_w, pred_mean_wind, pred_std_wind = compute_ensemble(stats_wind['prediction']) - logger.info("Computed ensemble statistics") - - # Prepare cyclic series - cycle = lambda x: x + [x[0]] - hrs_c = hrs + [24] - hrs_w_c = hrs_w + [24] - temp_lines = [cycle(mean_cycle_temp['target']), cycle(mean_cycle_temp['baseline']), (cycle(pred_mean_temp), cycle(pred_std_temp))] - wind_lines = [cycle(mean_cycle_wind['target']), cycle(mean_cycle_wind['baseline']), (cycle(pred_mean_wind), cycle(pred_std_wind))] - - # Plot - plot_paths = [] - fn1 = out_root/'diurnal_cycle_2t.png' - save_plot(hrs_c, temp_lines, ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], '2m Temperature [K]', - 'Diurnal Cycle of 2m Temperature', fn1) - plot_paths.append(fn1) - - fn2 = out_root/'diurnal_cycle_windspeed.png' - save_plot(hrs_w_c, wind_lines, ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], 'Windspeed [m/s]', - 'Diurnal Cycle of Windspeed', fn2) - plot_paths.append(fn2) - - logger.info(f"Plots saved: {', '.join(str(p) for p in plot_paths)}") + # Helper to concat and compute diurnal stats + def concat_and_group(list_of_da): + da = xr.concat(list_of_da, dim="time").groupby("time.hour") + mean = da.mean(dim=[d for d in da.dims if d in ['time', 'member']]) + std = da.std(dim=[d for d in da.dims if d in ['time', 'member']]) + return mean, std + + # Compute diurnal means and stds + temp_target_mean, _ = concat_and_group(target_temp) + temp_baseline_mean, _ = concat_and_group(baseline_temp) + temp_pred_mean, temp_pred_std = concat_and_group(pred_temp) + + wind_target_mean, _ = concat_and_group(target_wind) + wind_baseline_mean, _ = concat_and_group(baseline_wind) + wind_pred_mean, wind_pred_std = concat_and_group(pred_wind) + + # Plot helper + def save_plot(hour, means, stds, labels, ylabel, title, out_path): + hrs = np.concatenate([hour.values, [24]]) + plt.figure(figsize=(8,4)) + for mean, std, label in zip(means, stds, labels): + vals = np.append(mean.values, mean.values[0]) + line, = plt.plot(hrs, vals, label=label) + if std is not None: + stdv = np.append(std.values, std.values[0]) + plt.fill_between(hrs, vals - stdv, vals + stdv, color=line.get_color(), alpha=0.3) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + + # Generate plots + save_plot( + temp_target_mean.hour, + [temp_target_mean, temp_baseline_mean, temp_pred_mean], + [None, None, temp_pred_std], + ['COSMO-2', 'ERA5', 'CorrDiff ± Std(Members)'], + '2m Temperature [°C]', + 'Diurnal Cycle of 2m Temperature', + out_root / 'diurnal_cycle_2t.png' + ) + save_plot( + wind_target_mean.hour, + [wind_target_mean, wind_baseline_mean, wind_pred_mean], + [None, None, wind_pred_std], + ['COSMO-2', 'ERA5', 'CorrDiff ± Std(Members)'], + 'Windspeed [m/s]', + 'Diurnal Cycle of Windspeed', + out_root / 'diurnal_cycle_windspeed.png' + ) + + logger.info("Plots saved.") + + plt.imshow(target[t2m_out], cmap='viridis') + plt.colorbar(label='2m Temperature [°C]') + plt.savefig('lsm.png') if __name__ == '__main__': main() diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 5013b91..1f4e479 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -20,7 +20,7 @@ from hirad.utils.function_utils import get_time_from_range # Constants -CONV_FACTOR = 100 # Convert meters to mm/h +CONV_FACTOR = 100 * 24 # Convert meters to mm/day LOG_INTERVAL = 24 # Log progress every N timesteps @@ -80,10 +80,9 @@ def main(cfg: DictConfig): tp_out = out_ch['tp']; tp_in = in_ch.get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - # Load land-sea mask - lsm_dat = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy') - lsm = np.flip(lsm_dat.reshape(352,544), 0) - land_mask = lsm >= 0.5 + # Land-sea mask + lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) + land_mask = np.where(lsm_data >= 0.5, 1.0, np.nan) # Storage for diurnal cycles pct99_mean = {'target': [], 'baseline': [], 'prediction': []} @@ -99,8 +98,7 @@ def main(cfg: DictConfig): ] stack = np.stack(arrs, axis=0) f99 = np.percentile(stack, 99, axis=0) - pct99_mean[mode].append(f99.mean()) - pct99_std[mode].append(np.std(f99, axis=None)) + pct99_mean[mode].append(np.nanmean(f99) if mode == 'target' else np.nanmean(f99) / 6.0) # / 6 because bug in dataset? del arrs, stack, f99 # -- Predictions: compute per hour per member, then mean+std across members -- @@ -123,10 +121,10 @@ def main(cfg: DictConfig): # stack over time and compute 99th percentile at each grid point stack_m = np.stack(arrs, axis=0) f99_m = np.percentile(stack_m, 99, axis=0) - mem_f99.append(f99_m.mean()) + mem_f99.append(np.nanmean(f99_m)) # mean over grid points # ensemble-level mean and std over member-wise percentiles - pct99_mean['prediction'].append(np.mean(mem_f99)) - pct99_std['prediction'].append(np.std(mem_f99, axis=None)) + pct99_mean['prediction'].append(np.nanmean(mem_f99)) + pct99_std['prediction'].append(np.std(pct99_mean['prediction'])) # clean up per-hour buffers del mem_f99, stack_m, f99_m @@ -135,7 +133,7 @@ def main(cfg: DictConfig): hrs_c = list(range(24)) + [list(range(24))[0] + 24] pct99_lines = [ cycle_fn(pct99_mean['target']), - cycle_fn(pct99_mean['baseline']), + cycle_fn(pct99_mean['baseline']), # 6 becasue bug in dataset? ( cycle_fn(pct99_mean['prediction']), cycle_fn(pct99_std['prediction']) @@ -148,7 +146,7 @@ def main(cfg: DictConfig): hrs_c, pct99_lines, ['COSMO-2','ERA5','CorrDiff 99th Pct ± Std'], - 'Rain Rate (mm/h)', + 'Precipitation (mm/day)', 'Diurnal Cycle of 99th-Percentile Precipitation', fn ) From 12702b67ec24cae514b9ca586c9505efa9f6ddee Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 31 Jul 2025 11:55:52 +0200 Subject: [PATCH 107/189] comment --- src/hirad/eval/diurnal_cycle_precip.py | 5 ++--- src/hirad/eval/percentile99_cycle_precip.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 989ddc5..e5c9560 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -14,7 +14,7 @@ from hirad.utils.function_utils import get_time_from_range # Constants -CONV_FACTOR = 100*24 # Convert meters to mm/day +CONV_FACTOR = 100*24 # Convert meters to mm/day WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h LOG_INTERVAL = 24 # Log progress every N timesteps @@ -34,7 +34,6 @@ def main(cfg: DictConfig): dataset, _ = get_dataset_and_sampler_inference( ds_cfg, times, cfg.generation.get('has_lead_time', False) ) - logger.info("Dataset and sampler initialized") out_root = Path(cfg.generation.io.output_path or './outputs') def load(ts, fn): @@ -60,7 +59,7 @@ def load(ts, fn): for idx, ts in enumerate(times, 1): dt = datetimes[idx-1] target = load(ts, f"{ts}-target")[tp_out] * land_mask - baseline = load(ts, f"{ts}-baseline")[tp_in] * land_mask / 6 # 6 becasue 1h -> 6h bug in dataset? + baseline = load(ts, f"{ts}-baseline")[tp_in] * land_mask / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset preds = load(ts, f"{ts}-predictions")[:, tp_out, :, :] * land_mask # DataArrays for spatial mean diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 1f4e479..51fc28b 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -98,7 +98,7 @@ def main(cfg: DictConfig): ] stack = np.stack(arrs, axis=0) f99 = np.percentile(stack, 99, axis=0) - pct99_mean[mode].append(np.nanmean(f99) if mode == 'target' else np.nanmean(f99) / 6.0) # / 6 because bug in dataset? + pct99_mean[mode].append(np.nanmean(f99) if mode == 'target' else np.nanmean(f99) / 6.0) # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset del arrs, stack, f99 # -- Predictions: compute per hour per member, then mean+std across members -- From b537630c09969e38209f18eca331abd8e42c108d Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 31 Jul 2025 14:01:48 +0200 Subject: [PATCH 108/189] fix bug in std --- src/hirad/eval/diurnal_cycle_precip.py | 13 +++++++------ src/hirad/eval/diurnal_cycle_temp_wind.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index e5c9560..2bc31c0 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -62,7 +62,7 @@ def load(ts, fn): baseline = load(ts, f"{ts}-baseline")[tp_in] * land_mask / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset preds = load(ts, f"{ts}-predictions")[:, tp_out, :, :] * land_mask - # DataArrays for spatial mean + # DataArrays for spatial means at each timestep da_target = xr.DataArray(target, dims=("lat","lon"), coords=coords) da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=coords) da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **coords}) @@ -72,7 +72,7 @@ def load(ts, fn): baseline_precip.append(da_baseline.mean(dim=("lat","lon")).assign_coords(time=dt)) pred_precip.append(da_preds.mean(dim=("lat","lon")).assign_coords(time=dt)) - # Wet-hour fraction (percentage) + # Wet-hour fraction, i.e., freq(precip) > WET_THRESHOLD target_wet.append(((da_target / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) baseline_wet.append(((da_baseline / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) pred_wet.append(((da_preds / 24> WET_THRESHOLD).mean(dim=("lat","lon")).assign_coords(time=dt))) @@ -84,10 +84,11 @@ def load(ts, fn): def concat_and_group(list_of_da, is_member=False, scale=1.0): da = xr.concat(list_of_da, dim="time").groupby("time.hour") if is_member: - mean = da.mean(dim=[d for d in da.dims if d in ['time', 'member']]) * scale - std = da.std(dim=[d for d in da.dims if d in ['time', 'member']]) * scale + timmean = da.mean(dim='time') * scale + mean = timmean.mean(dim='member') + std = timmean.std(dim='member') else: - mean = da.mean(dim="time") * scale + mean = da.mean(dim='time') * scale std = None return mean, std @@ -96,7 +97,7 @@ def concat_and_group(list_of_da, is_member=False, scale=1.0): amount_baseline_mean, _ = concat_and_group(baseline_precip) amount_pred_mean, amount_pred_std = concat_and_group(pred_precip, is_member=True) - wet_target_mean, _ = concat_and_group(target_wet, scale=100.0) + wet_target_mean, _ = concat_and_group(target_wet, scale=100.0) # scale to percentage wet_baseline_mean, _ = concat_and_group(baseline_wet, scale=100.0) wet_pred_mean, wet_pred_std = concat_and_group(pred_wet, is_member=True, scale=100.0) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index d03e62c..43a33b8 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -111,10 +111,15 @@ def main(cfg: DictConfig): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") # Helper to concat and compute diurnal stats - def concat_and_group(list_of_da): + def concat_and_group(list_of_da, is_member=False, scale=1.0): da = xr.concat(list_of_da, dim="time").groupby("time.hour") - mean = da.mean(dim=[d for d in da.dims if d in ['time', 'member']]) - std = da.std(dim=[d for d in da.dims if d in ['time', 'member']]) + if is_member: + timmean = da.mean(dim='time') * scale + mean = timmean.mean(dim='member') + std = timmean.std(dim='member') + else: + mean = da.mean(dim='time') * scale + std = None return mean, std # Compute diurnal means and stds From 1873b5833fde39d677c53918d73eb2404144dfff Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 31 Jul 2025 16:43:03 +0200 Subject: [PATCH 109/189] add patched diffusion configs --- .../era_cosmo_training_patched.yaml | 24 ++++ .../model/era_cosmo_diffusion_patched.yaml | 13 ++ .../training/era_cosmo_diffusion_patched.yaml | 53 +++++++ .../training_era_cosmo_diffusion_patched.yaml | 27 ++++ src/hirad/inference/generate.py | 2 +- src/hirad/inference/stochastic_sampler.py | 3 +- src/hirad/training/train.py | 130 +++++++++--------- 7 files changed, 182 insertions(+), 70 deletions(-) create mode 100644 src/hirad/conf/generation/era_cosmo_training_patched.yaml create mode 100644 src/hirad/conf/model/era_cosmo_diffusion_patched.yaml create mode 100644 src/hirad/conf/training/era_cosmo_diffusion_patched.yaml create mode 100644 src/hirad/conf/training_era_cosmo_diffusion_patched.yaml diff --git a/src/hirad/conf/generation/era_cosmo_training_patched.yaml b/src/hirad/conf/generation/era_cosmo_training_patched.yaml new file mode 100644 index 0000000..e6e2d7c --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo_training_patched.yaml @@ -0,0 +1,24 @@ +defaults: + - ../sampler@sampler: stochastic + - ../dataset@dataset: era_cosmo_inference + +num_ensembles: 16 + # Number of ensembles to generate per input + +patching: True +# Use patch-based sampling +overlap_pix: 4 +# Number of overlapping pixels between adjacent patches +boundary_pix: 2 +# Number of boundary pixels to be cropped out. 2 is recommended to address the boundary +# artifact. +patch_shape_x: 128 +patch_shape_y: 128 + +times_range: null +times: + - 20200926-1800 + # - 20200927-0000 + +perf: + num_writer_workers: 10 \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml b/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml new file mode 100644 index 0000000..9362932 --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml @@ -0,0 +1,13 @@ +name: patched_diffusion + # Name of the preconditioner +hr_mean_conditioning: True + # High-res mean (regression's output) as additional condition + +# Standard model parameters. +# Standard model parameters. +model_args: + gridtype: "learnable" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 100 + # Number of channels for positional grid embeddings \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml b/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml new file mode 100644 index 0000000..679afd6 --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml @@ -0,0 +1,53 @@ +# Hyperparameters +hp: + training_duration: 1000000 + # Training duration based on the number of processed samples + total_batch_size: "auto" + # Total batch size + batch_size_per_gpu: 10 + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: 1e6 + # no gradient clipping for defualt non-patch-based training + lr_decay: 0.7 + # LR decay rate + lr_rampup: 1000000 + # Rampup for learning rate, in number of samples + lr_decay_rate: 5e5 + # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples. + patch_shape_x: 128 + patch_shape_y: 128 + # Patch size. Patch training is used if these dimensions differ from + # img_shape_x and img_shape_y. + patch_num: 7 + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. + max_patch_per_gpu: 100 + # Maximum number of pataches a gpu can hold + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 10 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression + # Where to load the regression checkpoint + print_progress_freq: 1000 + # How often to print progress + save_checkpoint_freq: 500000 + # How often to save the checkpoints, measured in number of processed samples + visualization_freq: 250000 + # how often to visualize network outputs + validation_freq: 50000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 100 + # how many loss evaluations are used to compute the validation loss per checkpoint + checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml b/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml new file mode 100644 index 0000000..38baba0 --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml @@ -0,0 +1,27 @@ +hydra: + job: + chdir: true + name: diffusion_era5_cosmo_patched + run: + dir: /capstor/scratch/cscs/pstamenk/outputs/training/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_diffusion_patched + + - model_size/normal + + # Training + - training/era_cosmo_diffusion_patched + + # Inference visualization + - generation/era_cosmo_training_patched + + # Logging + - logging/era_cosmo_diffusion \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 2320656..a553af8 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -131,7 +131,7 @@ def main(cfg: DictConfig) -> None: else: net_reg = None - # Reset since we are using a different mode. + # Reset since we are using a different mode. if cfg.generation.perf.use_torch_compile: torch._dynamo.reset() # Only compile residual network diff --git a/src/hirad/inference/stochastic_sampler.py b/src/hirad/inference/stochastic_sampler.py index 606c911..24c5f7a 100644 --- a/src/hirad/inference/stochastic_sampler.py +++ b/src/hirad/inference/stochastic_sampler.py @@ -180,10 +180,11 @@ def stochastic_sampler( # input and position padding + patching if patching: + # print(f"Input for generator beofre patching {x_lr.shape}") # Patched conditioning [x_lr, mean_hr] # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) x_lr = patching.apply(input=x_lr, additional_input=img_lr) - + # print(f"Input for generator after patching {x_lr.shape}") # Function to select the correct positional embedding for each patch def patch_embedding_selector(emb): # emb: (N_pe, image_shape_y, image_shape_x) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 0573929..519845b 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -82,7 +82,7 @@ def main(cfg: DictConfig) -> None: logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger dataset_cfg = OmegaConf.to_container(cfg.dataset) - if hasattr(cfg.dataset, "validation_path"): + if hasattr(cfg.dataset, "validation_path") and cfg.dataset.validation_path is not None: train_test_split = True else: train_test_split = False @@ -222,19 +222,12 @@ def main(cfg: DictConfig) -> None: if hasattr(cfg.model, "model_args"): # override defaults from config file model_args.update(OmegaConf.to_container(cfg.model.model_args)) - use_torch_compile = False - use_apex_gn = False - profile_mode = False + use_torch_compile = getattr(cfg.training.perf, "torch_compile", False) + use_apex_gn = getattr(cfg.training.perf, "use_apex_gn", False) + profile_mode = getattr(cfg.training.perf, "profile_mode", False) - if hasattr(cfg.training.perf, "torch_compile"): - use_torch_compile = cfg.training.perf.torch_compile - if hasattr(cfg.training.perf, "use_apex_gn"): - use_apex_gn = cfg.training.perf.use_apex_gn - model_args["use_apex_gn"] = use_apex_gn - - if hasattr(cfg.training.perf, "profile_mode"): - profile_mode = cfg.training.perf.profile_mode - model_args["profile_mode"] = profile_mode + model_args["use_apex_gn"] = use_apex_gn + model_args["profile_mode"] = profile_mode if enable_amp: model_args["amp_mode"] = enable_amp @@ -622,7 +615,7 @@ def main(cfg: DictConfig) -> None: fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.4f}" ] fields += [ f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" @@ -787,62 +780,63 @@ def main(cfg: DictConfig) -> None: times = visualization_dataset.time() time_index = -1 output_paths_list = [] - for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( - iter(visualization_data_loader) - ): - time_index += 1 - logger0.info(f"starting index: {time_index}") + with torch.no_grad(): + for index, (img_clean_viz, img_lr_viz, *lead_time_label_viz) in enumerate( + iter(visualization_data_loader) + ): + time_index += 1 + logger0.info(f"starting index: {time_index}") + + # continue + if lead_time_label_viz: + lead_time_label_viz = lead_time_label_viz[0].to(dist.device).contiguous() + else: + lead_time_label_viz = None - # continue - if lead_time_label_viz: - lead_time_label_viz = lead_time_label_viz[0].to(dist.device).contiguous() - else: - lead_time_label_viz = None - - if use_apex_gn: - img_clean_viz = img_clean_viz.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - img_lr_viz = img_lr_viz.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - else: - img_clean_viz = ( - img_clean_viz.to(dist.device) - .to(input_dtype) - .contiguous() - ) - img_lr_viz = ( - img_lr_viz.to(dist.device) - .to(input_dtype) - .contiguous() - ) - with torch.autocast( - "cuda", dtype=amp_dtype, enabled=enable_amp - ): - image_pred_viz, image_reg_viz = generator.generate(img_lr_viz, lead_time_label_viz) - if dist.rank == 0: - # write out data in a seperate thread so we don't hold up inferencing - output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") - output_paths_list.append(output_path) - if dist.rank==0 and not os.path.exists(output_path): - os.makedirs(output_path) - writer_threads.append( - writer_executor.submit( - save_images, - output_path, - times[visualization_sampler[time_index]], - visualization_dataset, - image_pred_viz.cpu().numpy(), - img_clean_viz.cpu().numpy(), - img_lr_viz.cpu().numpy(), - image_reg_viz.cpu().numpy() if image_reg_viz is not None else None, + if use_apex_gn: + img_clean_viz = img_clean_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_viz = img_lr_viz.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean_viz = ( + img_clean_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_viz = ( + img_lr_viz.to(dist.device) + .to(input_dtype) + .contiguous() + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + image_pred_viz, image_reg_viz = generator.generate(img_lr_viz, lead_time_label_viz) + if dist.rank == 0: + # write out data in a seperate thread so we don't hold up inferencing + output_path = os.path.join(visualization_dir, f"{cur_nimg}_{times[visualization_sampler[time_index]]}") + output_paths_list.append(output_path) + if dist.rank==0 and not os.path.exists(output_path): + os.makedirs(output_path) + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + times[visualization_sampler[time_index]], + visualization_dataset, + image_pred_viz.cpu().numpy(), + img_clean_viz.cpu().numpy(), + img_lr_viz.cpu().numpy(), + image_reg_viz.cpu().numpy() if image_reg_viz is not None else None, + ) ) - ) # make sure all the workers are done writing if dist.rank == 0: for thread in list(writer_threads): From 79c6f6c4777b105b229aa41afe3a03bdfdd545b4 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 31 Jul 2025 18:39:46 +0200 Subject: [PATCH 110/189] small fixes --- src/hirad/eval/diurnal_cycle_precip.py | 1 - src/hirad/eval/diurnal_cycle_temp_wind.py | 12 ++++-------- src/hirad/eval/percentile99_cycle_precip.py | 4 ++-- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 2bc31c0..50961e3 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -101,7 +101,6 @@ def concat_and_group(list_of_da, is_member=False, scale=1.0): wet_baseline_mean, _ = concat_and_group(baseline_wet, scale=100.0) wet_pred_mean, wet_pred_std = concat_and_group(pred_wet, is_member=True, scale=100.0) - # Plot helper def save_plot(hour, means, stds, labels, ylabel, title, out_path): hrs = np.concatenate([hour.values, [24]]) plt.figure(figsize=(8,4)) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 43a33b8..8e1ff30 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -125,13 +125,12 @@ def concat_and_group(list_of_da, is_member=False, scale=1.0): # Compute diurnal means and stds temp_target_mean, _ = concat_and_group(target_temp) temp_baseline_mean, _ = concat_and_group(baseline_temp) - temp_pred_mean, temp_pred_std = concat_and_group(pred_temp) + temp_pred_mean, temp_pred_std = concat_and_group(pred_temp, is_member=True) wind_target_mean, _ = concat_and_group(target_wind) wind_baseline_mean, _ = concat_and_group(baseline_wind) - wind_pred_mean, wind_pred_std = concat_and_group(pred_wind) + wind_pred_mean, wind_pred_std = concat_and_group(pred_wind, is_member=True) - # Plot helper def save_plot(hour, means, stds, labels, ylabel, title, out_path): hrs = np.concatenate([hour.values, [24]]) plt.figure(figsize=(8,4)) @@ -140,7 +139,7 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): line, = plt.plot(hrs, vals, label=label) if std is not None: stdv = np.append(std.values, std.values[0]) - plt.fill_between(hrs, vals - stdv, vals + stdv, color=line.get_color(), alpha=0.3) + plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) plt.xlabel('Hour (UTC)') plt.xticks(range(0,25,3)) plt.xlim(0,24) @@ -163,6 +162,7 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): 'Diurnal Cycle of 2m Temperature', out_root / 'diurnal_cycle_2t.png' ) + save_plot( wind_target_mean.hour, [wind_target_mean, wind_baseline_mean, wind_pred_mean], @@ -175,9 +175,5 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): logger.info("Plots saved.") - plt.imshow(target[t2m_out], cmap='viridis') - plt.colorbar(label='2m Temperature [°C]') - plt.savefig('lsm.png') - if __name__ == '__main__': main() diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 51fc28b..bd6991d 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -124,7 +124,7 @@ def main(cfg: DictConfig): mem_f99.append(np.nanmean(f99_m)) # mean over grid points # ensemble-level mean and std over member-wise percentiles pct99_mean['prediction'].append(np.nanmean(mem_f99)) - pct99_std['prediction'].append(np.std(pct99_mean['prediction'])) + pct99_std['prediction'].append(np.nanstd(mem_f99)) # clean up per-hour buffers del mem_f99, stack_m, f99_m @@ -133,7 +133,7 @@ def main(cfg: DictConfig): hrs_c = list(range(24)) + [list(range(24))[0] + 24] pct99_lines = [ cycle_fn(pct99_mean['target']), - cycle_fn(pct99_mean['baseline']), # 6 becasue bug in dataset? + cycle_fn(pct99_mean['baseline']), ( cycle_fn(pct99_mean['prediction']), cycle_fn(pct99_std['prediction']) From 001d5e2b8e2f7ce5eabe48a43cccf7e4b461ce66 Mon Sep 17 00:00:00 2001 From: David Leutwyler <14977216+leuty@users.noreply.github.com> Date: Mon, 4 Aug 2025 14:46:14 +0200 Subject: [PATCH 111/189] New script to plot maps (#18) --- src/hirad/eval/__init__.py | 2 +- src/hirad/eval/plot_maps.py | 265 ++++++++++++++++++++++++++++++++ src/hirad/eval/plotting.py | 96 +++++++++++- src/hirad/inference/generate.py | 15 +- src/hirad/maps.sh | 49 ++++++ 5 files changed, 407 insertions(+), 20 deletions(-) create mode 100644 src/hirad/eval/plot_maps.py create mode 100644 src/hirad/maps.sh diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py index 82e9073..31b49dd 100644 --- a/src/hirad/eval/__init__.py +++ b/src/hirad/eval/__init__.py @@ -1,2 +1,2 @@ from .metrics import absolute_error, compute_mae, average_power_spectrum, crps -from .plotting import plot_error_projection, plot_power_spectra, plot_scores_vs_t \ No newline at end of file +from .plotting import plot_map, plot_error_projection, plot_power_spectra, plot_scores_vs_t \ No newline at end of file diff --git a/src/hirad/eval/plot_maps.py b/src/hirad/eval/plot_maps.py new file mode 100644 index 0000000..e68f956 --- /dev/null +++ b/src/hirad/eval/plot_maps.py @@ -0,0 +1,265 @@ +"""Generates maps of precipitation, temperature, and wind components/speed/direction.""" +import logging +from dataclasses import dataclass, field, replace +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.eval import compute_mae, plot_map +from hirad.eval.plotting import plot_map_precipitation, wind_direction +from hirad.utils.function_utils import get_time_from_range +from hirad.utils.inference_utils import calculate_bounds + +@dataclass +class ChannelMeta: + """Metadata for a channel.""" + name: str + cmap: str = "viridis" + me_cmap: str | None = None + unit: str = "" + norm: any = None + err_vmin: float = None + err_vmax: float = None + vmin: float = None + vmax: float = None + extend: str = "both" + precip_kwargs: dict = field(default_factory=lambda: {"threshold": 0.1, "rfac": 100.0}) + + @classmethod + def get(cls, ch_or_name: "ChannelMeta | str | None", *, vmin=None, vmax=None) -> "ChannelMeta": + name = getattr(ch_or_name, "name", ch_or_name or "") + base = CHANNELS.get(name) or cls(name=name) + if vmin is not None or vmax is not None: + return replace(base, vmin=vmin, vmax=vmax) + return base + +CHANNELS = { + "tp": ChannelMeta(name="tp", cmap=None, unit="mm/h", extend="max", precip_kwargs={"threshold": 0.1, "rfac": 100.0}), + "2t": ChannelMeta(name="2t", cmap="RdYlBu_r", me_cmap="RdBu", unit="K", err_vmin=-4.5, err_vmax=4.5), + "10u": ChannelMeta(name="10u", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10), + "10v": ChannelMeta(name="10v", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10), +} + +def format_time_str(dt_str, input_fmt="%Y%m%d-%H%M", output_fmt="%d-%m-%Y %H:%M"): + """Convert time string from input_fmt to output_fmt.""" + dt = datetime.strptime(dt_str, input_fmt) + return dt.strftime(output_fmt) + +class FileRepository: + def __init__(self, root_path): + self.root = Path(root_path) + + def load(self, time, filename): + return torch.load(self.root / time / filename, weights_only=False) + + def _ensure_dir(self, *subdirs): + """Make (and return) root_path/subdir1/subdir2/….""" + d = self.root.joinpath(*subdirs) + d.mkdir(parents=True, exist_ok=True) + return d + + def _make_fname(self, curr_time, prefix, suffix, member_idx): + """Build a filename like '20250724-1230-prefix-suffix[_member]'.""" + base = f"{curr_time}-{prefix}-{suffix}" + if member_idx is not None: + base += f"_{member_idx}" + return base + + def output_file(self, channel, curr_time, suffix, member_idx=None): + # decide on the folder name: e.g. 'tp_100m' or just 'tp' + folder = f"{channel.name}_{channel.level}" if getattr(channel, "level", None) else channel.name + fname = self._make_fname(curr_time, channel.name, suffix, member_idx) + return self._ensure_dir(folder) / fname + + def wind_file(self, wind_type, curr_time, suffix, member_idx=None): + # e.g. wind_type = "FF10m" or "DD10m" + fname = self._make_fname(curr_time, wind_type, suffix, member_idx) + return self._ensure_dir(wind_type) / fname + +def map_output_to_input_channels(output_channels, input_channels): + """ + Maps output channels to input channels based on their names. + """ + return { + j: next((k for k, input_channel in enumerate(input_channels) if input_channel.name == output_channel.name), -1) + for j, output_channel in enumerate(output_channels) + } + +def save_field(name, data, meta, files, channel, t, member=None, kind=None, cmap=None, vmin=None, vmax=None, custom_path=None, plot_func=None, title=None, **plot_kwargs): + """Save a field by plotting it with the appropriate function and parameters. Handles precipitation specially and supports custom paths.""" + # Determine output path + suffix = f"{name}-{kind}" if kind else f"{name}" + out_path = custom_path or files.output_file(channel, t, suffix, member) + + # Choose plotting function + plot = plot_func or plot_map + + # Case dependent plot parameters + title = title or meta.name + extend = 'max' if kind == 'mae' else meta.extend + common = {'title': title, 'norm': meta.norm, 'extend': extend, **plot_kwargs} + + # Precipitation case + if plot.__name__ == 'plot_map_precipitation': + precip_args = {**meta.precip_kwargs, **{k: plot_kwargs[k] for k in meta.precip_kwargs if k in plot_kwargs}} + plot(data, out_path, title=title, **precip_args) + else: + plot(data, out_path, vmin=vmin if vmin is not None else meta.vmin, vmax=vmax if vmax is not None else meta.vmax, cmap=cmap or meta.cmap, label=meta.unit, **common) + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + + # Initialize distributed manager + DistributedManager.initialize() + + # Initialize logger + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger("plot_maps") + + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + + dataset_cfg = OmegaConf.to_container(cfg.dataset) + has_lead_time = cfg.generation.get("has_lead_time", False) + dataset, sampler = get_dataset_and_sampler_inference( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + + output_path = getattr(cfg.generation.io, "output_path", "./outputs") + files = FileRepository(output_path) + + for curr_time in times: + prediction = files.load(curr_time, f'{curr_time}-predictions') + baseline = files.load(curr_time, f'{curr_time}-baseline') + target = files.load(curr_time, f'{curr_time}-target') + + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + output_to_input_channel_map = map_output_to_input_channels(output_channels, input_channels) + + for idx, channel in enumerate(output_channels): + input_channel_idx = output_to_input_channel_map[idx] + plot_title = f"{format_time_str(curr_time)}: {getattr(channel, 'title', channel.name)}" + vmin, vmax = calculate_bounds( + target[idx,:,:], + prediction[:,idx,:,:], + baseline[input_channel_idx,:,:] + ) + metadata = ChannelMeta.get(channel, vmin=vmin, vmax=vmax) + + if channel.name == "tp": + save_field( + "target", target[idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + save_field( + "baseline", baseline[input_channel_idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + if prediction.shape[0] > 1: + for member_idx in range(prediction.shape[0]): + save_field( + "prediction", prediction[member_idx, idx, :, :], metadata, files, channel, curr_time, + member=member_idx, plot_func=plot_map_precipitation, title=plot_title + ) + else: + save_field( + "prediction", prediction[0, idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) + continue + + # Plot target and baseline + save_field("target", target[idx, :, :], metadata, files, channel, curr_time, title=plot_title) + save_field("baseline", baseline[input_channel_idx, :, :], metadata, files, channel, curr_time, title=plot_title) + + # Baseline MAE and ME + _, baseline_mae = compute_mae(baseline[input_channel_idx, :, :], target[idx, :, :]) + baseline_me = (baseline[input_channel_idx, :, :] - target[idx, :, :]) + save_field("baseline", baseline_mae.reshape(baseline[input_channel_idx, :, :].shape), metadata, files, channel, curr_time, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) + save_field("baseline", baseline_me, metadata, files, channel, curr_time, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + + # Ensemble predictions + for member_idx in range(prediction.shape[0]): + member = prediction[member_idx, idx, :, :] + save_field("prediction", member, metadata, files, channel, curr_time, member=member_idx, title=plot_title) + _, prediction_mae = compute_mae(member, target[idx, :, :]) + save_field("prediction", prediction_mae.reshape(member.shape), metadata, files, channel, curr_time, member=member_idx, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) + prediction_me = (member - target[idx, :, :]) + save_field("prediction", prediction_me, metadata, files, channel, curr_time, member=member_idx, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + + # Plot Windspeed and direction + wind_channels = {ch.name: idx for idx, ch in enumerate(output_channels) if ch.name in ("10u", "10v")} + if "10u" in wind_channels and "10v" in wind_channels: + idx_10u = wind_channels["10u"] + idx_10v = wind_channels["10v"] + input_idx_10u = output_to_input_channel_map[idx_10u] + input_idx_10v = output_to_input_channel_map[idx_10v] + + # Compute windspeed and direction for target, baseline, prediction + target_wind_speed = np.hypot(target[idx_10u, :, :], target[idx_10v, :, :]) + target_wind_dir = wind_direction(target[idx_10u, :, :], target[idx_10v, :, :]) + baseline_wind_speed = np.hypot(baseline[input_idx_10u, :, :], baseline[input_idx_10v, :, :]) + baseline_wind_dir = wind_direction(baseline[input_idx_10u, :, :], baseline[input_idx_10v, :, :]) + prediction_wind_speed = np.hypot(prediction[:, idx_10u, :, :], prediction[:, idx_10v, :, :]) + prediction_wind_dir = wind_direction(prediction[:, idx_10u, :, :], prediction[:, idx_10v, :, :]) + + plot_title_speed = f"{format_time_str(curr_time)}: FF10m" + plot_title_dir = f"{format_time_str(curr_time)}: DD10m" + + wind_meta = ChannelMeta.get("10u", vmin=0, vmax=10) + dir_meta = ChannelMeta.get("10u", vmin=0, vmax=360) + + # Save windspeed plots + save_field( + "FF10m-target", target_wind_speed, wind_meta, files, None, curr_time, + cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-target"), + plot_func=plot_map, title=plot_title_speed + ) + save_field( + "FF10m-baseline", baseline_wind_speed, wind_meta, files, None, curr_time, + cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-baseline"), + plot_func=plot_map, title=plot_title_speed + ) + for member_idx in range(prediction.shape[0]): + save_field( + "FF10m-prediction", prediction_wind_speed[member_idx], wind_meta, files, None, curr_time, + member=member_idx, cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-prediction", member_idx), + plot_func=plot_map, title=plot_title_speed + ) + + # Save wind direction plots + save_field( + "DD10m-target", target_wind_dir, dir_meta, files, None, curr_time, + cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-target"), + plot_func=plot_map, title=plot_title_dir + ) + save_field( + "DD10m-baseline", baseline_wind_dir, dir_meta, files, None, curr_time, + cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-baseline"), + plot_func=plot_map, title=plot_title_dir + ) + for member_idx in range(prediction.shape[0]): + save_field( + "DD10m-prediction", prediction_wind_dir[member_idx], dir_meta, files, None, curr_time, + member=member_idx, cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-prediction", member_idx), + plot_func=plot_map, title=plot_title_dir + ) + + logger.info("Image loading and plotting completed.") + +if __name__ == "__main__": + main() diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index a7603b7..2eb1c62 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -1,13 +1,99 @@ import logging -import os - -from hirad.eval import crps, absolute_error import cartopy.crs as ccrs +import cartopy.feature as cfeature import matplotlib.pyplot as plt import numpy as np -import torch +from matplotlib.colors import BoundaryNorm, ListedColormap + + +# COSMO‑2 GRID: TODO: Add to dataset config +LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) +LON = np.arange(-6.82, 4.80 + 0.02, 0.02) +RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) + +def plot_map(values: np.array, + filename: str, + label='', + title='', + vmin=None, + vmax=None, + cmap=None, + extend='neither', + norm=None, + ticks=None): + """Plot observed or interpolated data in a scatter plot.""" + logging.info(f'Creating map: {filename}') + latitudes = LAT[RELAX_ZONE : RELAX_ZONE + 352] + longitudes = LON[RELAX_ZONE : RELAX_ZONE + 544] + lon2d, lat2d = np.meshgrid(longitudes, latitudes) + + fig, ax = plt.subplots( + figsize=(8, 6), + subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0, + pole_latitude= 43.0)} + ) + contour = ax.pcolormesh( + lon2d, lat2d, values, + cmap=cmap, shading="auto", + norm=norm if norm else None, + vmin=None if norm else vmin, + vmax=None if norm else vmax, + ) + ax.coastlines() + ax.add_feature(cfeature.BORDERS, linewidth=1) + ax.gridlines(visible=False) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.title(title) + cbar = plt.colorbar( + contour, + label=label, + orientation="horizontal", + extend=extend, + shrink=0.75, + pad=0.02 + ) + if ticks is not None: + cbar.set_ticks(ticks) + + plt.tight_layout() + fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") + plt.close(fig) + +def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=100.0): + """Plot precipitation data with specific colormap and thresholds.""" + # Scale and mask values below threshold + values = rfac * values # m/h --> mm/h + values = np.ma.masked_where(values <= threshold, values) + + # Predefined colors and bounds specific for precipitation + colors = ['none', 'powderblue', 'dodgerblue', 'mediumblue', + 'forestgreen', 'limegreen', 'lawngreen', + 'yellow', 'gold', 'darkorange', 'red', + 'darkviolet', 'violet', 'thistle'] + bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 30, 50, 70, 100, 150, 200] + + cmap = ListedColormap(colors) + norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) + + plot_map( + values, filename, + cmap=cmap, + norm=norm, + ticks=bounds, + title=title, + label='mm/h', + extend='max' + ) + +def wind_direction(u, v): + """Compute wind direction from u and v components.""" + return(np.arctan2(-u, -v) * 180 / np.pi) % 360 + +@DeprecationWarning def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label='', title='', vmin=None, vmax=None): """Plot observed or interpolated data in a scatter plot.""" fig = plt.figure() @@ -56,4 +142,4 @@ def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): #plt.psd(x) logging.info(f'plotting values to {filename}') plt.savefig(filename) - plt.close('all') + plt.close('all') \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index a553af8..bd4d2e1 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -13,7 +13,7 @@ from hirad.models import EDMPrecondSuperResolution, UNet from hirad.inference import Generator -from hirad.utils.inference_utils import save_images, save_results_as_torch +from hirad.utils.inference_utils import save_results_as_torch from hirad.utils.function_utils import get_time_from_range from hirad.utils.checkpoint import load_checkpoint @@ -250,19 +250,6 @@ def elapsed_time(self, _): batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing - if not cfg.generation.times_range: - writer_threads.append( - writer_executor.submit( - save_images, - savedir, - times[sampler[time_index]], - dataset, - image_out.cpu().numpy(), - image_tar.cpu().numpy(), - image_lr.cpu().numpy(), - image_reg.cpu().numpy() if image_reg is not None else None, - ) - ) writer_threads.append( writer_executor.submit( save_results_as_torch, diff --git a/src/hirad/maps.sh b/src/hirad/maps.sh new file mode 100644 index 0000000..bcf4008 --- /dev/null +++ b/src/hirad/maps.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="plot" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:10:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plot_maps.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/eval/plot_maps.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file From 8efcedd1076b8c03b3c2a16151ab8f015657e513 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Mon, 4 Aug 2025 15:34:16 +0200 Subject: [PATCH 112/189] use more xarray --- src/hirad/eval/percentile99_cycle_precip.py | 99 ++++++++++++--------- 1 file changed, 56 insertions(+), 43 deletions(-) diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index bd6991d..abfa548 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -14,6 +14,7 @@ import numpy as np import torch from omegaconf import DictConfig, OmegaConf +import xarray as xr from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager @@ -82,55 +83,67 @@ def main(cfg: DictConfig): # Land-sea mask lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) - land_mask = np.where(lsm_data >= 0.5, 1.0, np.nan) + land_mask = xr.DataArray( + np.where(lsm_data >= 0.5, 1.0, np.nan), + dims=['lat', 'lon'] + ) # Storage for diurnal cycles - pct99_mean = {'target': [], 'baseline': [], 'prediction': []} - pct99_std = {'target': [], 'baseline': [], 'prediction': []} - - # -- Target and Baseline: compute per hour -- + pct99_mean = {} + pct99_std = {} + + # -- Process target and baseline -- for mode in ['target', 'baseline']: logger.info(f"Processing mode: {mode}") - for h in list(range(24)): - arrs = [ - load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask - for ts in times if hour_of(ts) == h - ] - stack = np.stack(arrs, axis=0) - f99 = np.percentile(stack, 99, axis=0) - pct99_mean[mode].append(np.nanmean(f99) if mode == 'target' else np.nanmean(f99) / 6.0) # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset - del arrs, stack, f99 + + data_list = [] + for ts in times: + data = load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask + data_list.append(data) + + da = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Group by hour and compute 99th percentile + hourly_p99 = da.groupby('time.hour').quantile(0.99, dim='time') + + # Apply scaling factor for baseline + if mode == 'baseline': + hourly_p99 = hourly_p99 / 6.0 + + pct99_mean[mode] = hourly_p99.mean(dim=['lat', 'lon']) # -- Predictions: compute per hour per member, then mean+std across members -- - # Determine number of ensemble members - sample = load(times[0], f"{times[0]}-predictions") # [n_members, n_channels, lat, lon] - data_sample = sample[:, tp_out] - n_members = data_sample.shape[0] - - for h in list(range(24)): - logger.info(f"Processing predictions for hour {h}") - mem_f99 = [] - # for each ensemble member, gather its hourly fields - for m in range(n_members): - arrs = [] - for ts in times: - if hour_of(ts) != h: - continue - preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, ...] - arrs.append(preds[m, tp_out] * land_mask) # apply mask - # stack over time and compute 99th percentile at each grid point - stack_m = np.stack(arrs, axis=0) - f99_m = np.percentile(stack_m, 99, axis=0) - mem_f99.append(np.nanmean(f99_m)) # mean over grid points - # ensemble-level mean and std over member-wise percentiles - pct99_mean['prediction'].append(np.nanmean(mem_f99)) - pct99_std['prediction'].append(np.nanstd(mem_f99)) - # clean up per-hour buffers - del mem_f99, stack_m, f99_m - - # Prepare cyclic series - cycle_fn = lambda x: x + [x[0]] - hrs_c = list(range(24)) + [list(range(24))[0] + 24] + logger.info("Processing predictions") + + # Load all prediction data at once into xarray + pred_data_list = [] + for ts in times: + preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + pred_data_list.append(preds[:, tp_out] * land_mask) # apply mask + + pred_da = xr.DataArray( + np.stack(pred_data_list, axis=1), # [n_members, time, lat, lon] + dims=['member', 'time', 'lat', 'lon'], + coords={ + 'member': range(len(pred_data_list[0])), + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + } + ) + + # Group by hour, compute 99th percentile across time, then spatial mean + hourly_p99_by_member = pred_da.groupby('time.hour').quantile(0.99, dim='time').mean(dim=['lat', 'lon']) + + # Store ensemble statistics as xarray DataArrays + pct99_mean['prediction'] = hourly_p99_by_member.mean(dim='member') + pct99_std['prediction'] = hourly_p99_by_member.std(dim='member') + + # Prepare cyclic lists for plotting + cycle_fn = lambda x: x.values.tolist() + [x.values.tolist()[0]] + hrs_c = list(range(24)) + [0 + 24] pct99_lines = [ cycle_fn(pct99_mean['target']), cycle_fn(pct99_mean['baseline']), From 99a066116af67f0cb1c77ccb9278c8b76125e15b Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Mon, 4 Aug 2025 19:15:27 +0200 Subject: [PATCH 113/189] add histograms --- src/hirad/eval/hist.py | 279 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 src/hirad/eval/hist.py diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py new file mode 100644 index 0000000..804096d --- /dev/null +++ b/src/hirad/eval/hist.py @@ -0,0 +1,279 @@ +""" +Plots the domain-mean precipitation distribution over land. + +This script computes and visualizes the distribution of precipitation values +across the land domain for different data sources (target, baseline, predictions). +""" +import logging +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range + +# Constants +CONV_FACTOR = 100 # Convert meters to mm/h +LOG_INTERVAL = 24 # Log progress every N timesteps + + +def save_distribution_plot(hist_data_dict, bin_edges, labels, colors, title, ylabel, out_path, percentiles_data=None): + """Save distribution plot with pre-computed histograms.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + plt.figure(figsize=(10, 6)) + + # Plot histograms from pre-computed bin counts + for (key, hist_data), label, color in zip(hist_data_dict.items(), labels, colors): + if isinstance(hist_data, tuple): # Handle ensemble data + # Plot individual members with transparency + for i, member_hist in enumerate(hist_data): + alpha = 0.5 if i > 0 else 0.7 + label_member = label if i == 0 else None + # Plot histogram from bin counts + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + plt.plot(bin_centers, member_hist, alpha=alpha, color=color, + label=label_member, drawstyle='steps-mid') + else: + # Plot histogram from bin counts + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + plt.plot(bin_centers, hist_data, alpha=0.7, color=color, + label=label, linewidth=2, drawstyle='steps-mid') + + plt.xscale('log') + plt.yscale('log') + plt.xlabel(ylabel) + plt.ylabel('Probability Density') + plt.title(title) + plt.grid(True, alpha=0.3) + + # Add percentile lines if provided + if percentiles_data: + # Define colors for different datasets + percentile_colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green'} + # Track if we've added legend labels for line styles + legend_added = {'99': False, '99.9': False, '99.99': False} + + for dataset_name, percentiles in percentiles_data.items(): + if dataset_name in ['target', 'baseline']: + # Plot percentiles for target and baseline + color = percentile_colors[dataset_name] + for percentile, value in percentiles.items(): + if percentile == 99: + linestyle = '--' + legend_label = '99th percentiles' if not legend_added['99'] else None + legend_added['99'] = True + elif percentile == 99.9: + linestyle = ':' + legend_label = '99.9th percentiles' if not legend_added['99.9'] else None + legend_added['99.9'] = True + elif percentile == 99.99: + linestyle = '-.' + legend_label = '99.99th percentiles' if not legend_added['99.99'] else None + legend_added['99.99'] = True + + plt.vlines(x=value, colors=color, + linestyles=linestyle, alpha=0.8, label=legend_label) + + elif dataset_name == 'predictions': + # Plot percentiles for ensemble members + color = percentile_colors[dataset_name] + for member_name, member_percentiles in percentiles.items(): + for percentile, value in member_percentiles.items(): + if percentile == 99: + linestyle = '--' + legend_label = '99th percentiles' if not legend_added['99'] else None + legend_added['99'] = True + elif percentile == 99.9: + linestyle = ':' + legend_label = '99.9th percentiles' if not legend_added['99.9'] else None + legend_added['99.9'] = True + elif percentile == 99.99: + linestyle = '-.' + legend_label = '99.99th percentiles' if not legend_added['99.99'] else None + legend_added['99.99'] = True + + plt.vlines(x=value, colors=color, + linestyles=linestyle, alpha=0.6, label=legend_label) + + plt.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for domain-mean precipitation distribution over land") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root and loader + out_root = Path(cfg.generation.io.output_path or './outputs') + + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + tp_out = out_ch['tp'] + tp_in = in_ch.get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) + land_mask = xr.DataArray( + np.where(lsm_data >= 0.5, 1.0, np.nan), + dims=['lat', 'lon'] + ) + + # Define histogram bins + bins = np.logspace(-1, 1, 50) # Log-spaced bins for precipitation + + # Storage for histogram data + hist_data = {} + # Store all land values for percentile calculation + all_land_values = {} + + # -- Process target and baseline -- + for mode in ['target', 'baseline']: + logger.info(f"Processing mode: {mode}") + + # Initialize histogram accumulator and collect all values + hist_counts = np.zeros(len(bins) - 1) + total_samples = 0 + all_values = [] + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask + + # Apply scaling factor for baseline + if mode == 'baseline': + data = data / 6.0 + + # Extract land values (remove NaN values) + land_values = data.values[~np.isnan(data.values)] + all_values.extend(land_values) + + # Accumulate histogram counts + counts, _ = np.histogram(land_values, bins=bins) + hist_counts += counts + total_samples += len(land_values) + + # Normalize to probability density + bin_widths = np.diff(bins) + hist_data[mode] = hist_counts / (total_samples * bin_widths) + all_land_values[mode] = np.array(all_values) + logger.info(f"Processed {total_samples} land values for {mode}") + + # -- Process predictions: compute histogram for each ensemble member -- + logger.info("Processing predictions") + + n_members = None + member_hist_data = [] + all_member_values = [] # Store all values for each member + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + + if n_members is None: + n_members = preds.shape[0] + # Initialize histogram accumulators for each member + member_hist_data = [np.zeros(len(bins) - 1) for _ in range(n_members)] + member_sample_counts = [0 for _ in range(n_members)] + all_member_values = [[] for _ in range(n_members)] # Initialize value storage + + for member_idx in range(n_members): + data = preds[member_idx, tp_out] * land_mask + # Extract land values (remove NaN values) + land_values = data.values[~np.isnan(data.values)] + all_member_values[member_idx].extend(land_values) # Store values for percentiles + + # Accumulate histogram counts for this member + counts, _ = np.histogram(land_values, bins=bins) + member_hist_data[member_idx] += counts + member_sample_counts[member_idx] += len(land_values) + + # Normalize member histograms to probability density + bin_widths = np.diff(bins) + normalized_member_hists = [] + for member_idx in range(n_members): + normalized_hist = member_hist_data[member_idx] / (member_sample_counts[member_idx] * bin_widths) + normalized_member_hists.append(normalized_hist) + + hist_data['predictions'] = tuple(normalized_member_hists) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # Compute percentiles for all datasets + percentiles_data = {} + + # Target percentiles + target_data_array = xr.DataArray(all_land_values['target']) + target_p99 = target_data_array.quantile(0.99).item() + target_p999 = target_data_array.quantile(0.999).item() + target_p9999 = target_data_array.quantile(0.9999).item() + percentiles_data['target'] = {99: target_p99, 99.9: target_p999, 99.99: target_p9999} + + # Baseline percentiles + baseline_data_array = xr.DataArray(all_land_values['baseline']) + baseline_p99 = baseline_data_array.quantile(0.99).item() + baseline_p999 = baseline_data_array.quantile(0.999).item() + baseline_p9999 = baseline_data_array.quantile(0.9999).item() + percentiles_data['baseline'] = {99: baseline_p99, 99.9: baseline_p999, 99.99: baseline_p9999} + + # Ensemble member percentiles + percentiles_data['predictions'] = {} + for member_idx in range(n_members): + member_data_array = xr.DataArray(all_member_values[member_idx]) + member_p99 = member_data_array.quantile(0.99).item() + member_p999 = member_data_array.quantile(0.999).item() + member_p9999 = member_data_array.quantile(0.9999).item() + percentiles_data['predictions'][f'member_{member_idx}'] = {99: member_p99, 99.9: member_p999, 99.99: member_p9999} + + + # Create distribution plots + labels = ['COSMO-2', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'green'] + + fn = out_root / 'precipitation_distribution_over_land.png' + save_distribution_plot( + hist_data, + bins, + labels, + colors, + 'Domain-Mean Precip. Over Land (Pooled Data)', + 'Precipitation (mm/h)', + fn, + percentiles_data + ) + logger.info(f"Distribution plot saved: {fn}") + + +if __name__ == '__main__': + main() \ No newline at end of file From 477140b5e59395b5386d827c26f92fb6d0cb0979 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 5 Aug 2025 10:08:09 +0200 Subject: [PATCH 114/189] add percentile lines --- src/hirad/eval/hist.py | 81 ++++++++++----------- src/hirad/eval/percentile99_cycle_precip.py | 29 +++++--- 2 files changed, 55 insertions(+), 55 deletions(-) diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index 804096d..59ee2ed 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -2,7 +2,7 @@ Plots the domain-mean precipitation distribution over land. This script computes and visualizes the distribution of precipitation values -across the land domain for different data sources (target, baseline, predictions). +over land. """ import logging from pathlib import Path @@ -50,57 +50,52 @@ def save_distribution_plot(hist_data_dict, bin_edges, labels, colors, title, yla plt.yscale('log') plt.xlabel(ylabel) plt.ylabel('Probability Density') + plt.ylim(1e-8, 1) + plt.xlim(bin_edges[1], bin_edges[-1]) plt.title(title) plt.grid(True, alpha=0.3) # Add percentile lines if provided if percentiles_data: - # Define colors for different datasets - percentile_colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green'} - # Track if we've added legend labels for line styles - legend_added = {'99': False, '99.9': False, '99.99': False} + # Calculate y-range for percentile lines (lowest 10% of log scale) + y_bottom, y_top = plt.ylim() + log_bottom, log_top = np.log10(y_bottom), np.log10(y_top) + vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom)) + vline_ymin = y_bottom - for dataset_name, percentiles in percentiles_data.items(): + # Define line styles for percentiles + percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} + percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} + colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green'} + legend_added = set() + + # Plot all percentile lines + for dataset_name, data in percentiles_data.items(): + color = colors[dataset_name] + if dataset_name in ['target', 'baseline']: - # Plot percentiles for target and baseline - color = percentile_colors[dataset_name] - for percentile, value in percentiles.items(): - if percentile == 99: - linestyle = '--' - legend_label = '99th percentiles' if not legend_added['99'] else None - legend_added['99'] = True - elif percentile == 99.9: - linestyle = ':' - legend_label = '99.9th percentiles' if not legend_added['99.9'] else None - legend_added['99.9'] = True - elif percentile == 99.99: - linestyle = '-.' - legend_label = '99.99th percentiles' if not legend_added['99.99'] else None - legend_added['99.99'] = True + # Single dataset + for percentile, value in data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries - plt.vlines(x=value, colors=color, - linestyles=linestyle, alpha=0.8, label=legend_label) - - elif dataset_name == 'predictions': - # Plot percentiles for ensemble members - color = percentile_colors[dataset_name] - for member_name, member_percentiles in percentiles.items(): - for percentile, value in member_percentiles.items(): - if percentile == 99: - linestyle = '--' - legend_label = '99th percentiles' if not legend_added['99'] else None - legend_added['99'] = True - elif percentile == 99.9: - linestyle = ':' - legend_label = '99.9th percentiles' if not legend_added['99.9'] else None - legend_added['99.9'] = True - elif percentile == 99.99: - linestyle = '-.' - legend_label = '99.99th percentiles' if not legend_added['99.99'] else None - legend_added['99.99'] = True + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.8) # No label here + else: + # Ensemble members + for member_data in data.values(): + for percentile, value in member_data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries - plt.vlines(x=value, colors=color, - linestyles=linestyle, alpha=0.6, label=legend_label) + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.6) # No label here + + # Add black legend entries for percentiles (override the colored ones) + for percentile in [99, 99.9, 99.99]: + if percentile in legend_added: + plt.plot([], [], color='black', linestyle=percentile_styles[percentile], + label=percentile_labels[percentile]) plt.legend() plt.tight_layout() diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index abfa548..7f97e21 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -73,12 +73,14 @@ def main(cfg: DictConfig): # Output root and loader out_root = Path(cfg.generation.io.output_path or './outputs') - load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp']; tp_in = in_ch.get('tp', tp_out) + tp_out = out_ch['tp'] + tp_in = in_ch.get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask @@ -123,16 +125,17 @@ def main(cfg: DictConfig): pred_data_list = [] for ts in times: preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] - pred_data_list.append(preds[:, tp_out] * land_mask) # apply mask + # Extract precipitation channel and convert to xarray for proper broadcasting + tp_data = preds[:, tp_out] # [n_members, lat, lon] + tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) + pred_data_list.append(tp_da * land_mask) # apply mask - pred_da = xr.DataArray( - np.stack(pred_data_list, axis=1), # [n_members, time, lat, lon] - dims=['member', 'time', 'lat', 'lon'], - coords={ - 'member': range(len(pred_data_list[0])), - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - } - ) + pred_da = xr.concat(pred_data_list, dim='time') # [n_members, time, lat, lon] + pred_da = pred_da.assign_coords({ + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + }) + # Transpose to get the expected dimension order: [member, time, lat, lon] + pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') # Group by hour, compute 99th percentile across time, then spatial mean hourly_p99_by_member = pred_da.groupby('time.hour').quantile(0.99, dim='time').mean(dim=['lat', 'lon']) @@ -142,7 +145,9 @@ def main(cfg: DictConfig): pct99_std['prediction'] = hourly_p99_by_member.std(dim='member') # Prepare cyclic lists for plotting - cycle_fn = lambda x: x.values.tolist() + [x.values.tolist()[0]] + def cycle_fn(x): + return x.values.tolist() + [x.values.tolist()[0]] + hrs_c = list(range(24)) + [0 + 24] pct99_lines = [ cycle_fn(pct99_mean['target']), From f79f95f3707d6c167b3ee983ad041c1d1c944be1 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 5 Aug 2025 13:46:39 +0200 Subject: [PATCH 115/189] add plot for probability of exceedance --- src/hirad/eval/probability_of_exceedance.py | 262 ++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 src/hirad/eval/probability_of_exceedance.py diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py new file mode 100644 index 0000000..e5da6b9 --- /dev/null +++ b/src/hirad/eval/probability_of_exceedance.py @@ -0,0 +1,262 @@ +""" +Plots the probability of exceedance for precipitation over land. + +This script computes and visualizes the complementary cumulative distribution +(probability of exceeding x mm/h) over land). +""" +import logging +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range + +# Constants +CONV_FACTOR = 100 # Convert meters to mm/h +LOG_INTERVAL = 24 # Log progress every N timesteps + + +def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): + """Save probability of exceedance plot with pre-computed data.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + plt.figure(figsize=(10, 6)) + + # Plot exceedance curves + for (key, exceedance_data), label, color in zip(exceedance_data_dict.items(), labels, colors): + if isinstance(exceedance_data, tuple): # Handle ensemble data + # Plot individual members with transparency + for i, member_exceedance in enumerate(exceedance_data): + alpha = 0.5 if i > 0 else 0.7 + label_member = label if i == 0 else None + plt.plot(thresholds, member_exceedance, alpha=alpha, color=color, + label=label_member, linewidth=1) + else: + # Plot single dataset + plt.plot(thresholds, exceedance_data, alpha=0.7, color=color, + label=label, linewidth=2) + + plt.xscale('log') + plt.yscale('log') + plt.xlabel(ylabel) + plt.ylabel('Probability of Exceedance') + plt.ylim(1e-8, 1) + plt.xlim(thresholds[1], thresholds[-1]) + plt.title(title) + plt.grid(True, alpha=0.3) + + # Add percentile lines if provided + if percentiles_data: + # Calculate y-range for percentile lines (lowest 10% of log scale) + y_bottom, y_top = plt.ylim() + log_bottom, log_top = np.log10(y_bottom), np.log10(y_top) + vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom)) + vline_ymin = y_bottom + + # Define line styles for percentiles + percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} + percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} + colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green'} + legend_added = set() + + # Plot all percentile lines + for dataset_name, data in percentiles_data.items(): + color = colors_perc[dataset_name] + + if dataset_name in ['target', 'baseline']: + # Single dataset + for percentile, value in data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.8) # No label here + else: + # Ensemble members + for member_data in data.values(): + for percentile, value in member_data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.6) # No label here + + # Add black legend entries for percentiles (override the colored ones) + for percentile in [99, 99.9, 99.99]: + if percentile in legend_added: + plt.plot([], [], color='black', linestyle=percentile_styles[percentile], + label=percentile_labels[percentile]) + + plt.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for probability of exceedance over land") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root and loader + out_root = Path(cfg.generation.io.output_path or './outputs') + + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + tp_out = out_ch['tp'] + tp_in = in_ch.get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # Land-sea mask + lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) + land_mask = xr.DataArray( + np.where(lsm_data >= 0.5, 1.0, np.nan), + dims=['lat', 'lon'] + ) + + # Define thresholds for exceedance calculation + thresholds = np.logspace(-2, 2, 200) # From 0.01 to 100 mm/h + + # Storage for exceedance data + exceedance_data = {} + # Store all land values for percentile calculation + all_land_values = {} + + # -- Process target and baseline -- + for mode in ['target', 'baseline']: + logger.info(f"Processing mode: {mode}") + + all_values = [] + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask + + # Apply scaling factor for baseline + if mode == 'baseline': + data = data / 6.0 + + # Extract land values (remove NaN values) + land_values = data.values[~np.isnan(data.values)] + all_values.extend(land_values) + + # Compute exceedance probabilities + all_values = np.array(all_values) + exceedance_probs = [] + for threshold in thresholds: + prob_exceed = np.mean(all_values > threshold) + exceedance_probs.append(prob_exceed) + + exceedance_data[mode] = np.array(exceedance_probs) + all_land_values[mode] = all_values + logger.info(f"Processed {len(all_values)} land values for {mode}") + + # -- Process predictions: compute exceedance for each ensemble member -- + logger.info("Processing predictions") + + n_members = None + all_member_values = [] # Store all values for each member + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + + if n_members is None: + n_members = preds.shape[0] + all_member_values = [[] for _ in range(n_members)] # Initialize value storage + + for member_idx in range(n_members): + data = preds[member_idx, tp_out] * land_mask + # Extract land values (remove NaN values) + land_values = data.values[~np.isnan(data.values)] + all_member_values[member_idx].extend(land_values) + + # Compute exceedance probabilities for each ensemble member + member_exceedance_data = [] + for member_idx in range(n_members): + member_values = np.array(all_member_values[member_idx]) + member_exceedance = [] + for threshold in thresholds: + prob_exceed = np.mean(member_values > threshold) + member_exceedance.append(prob_exceed) + member_exceedance_data.append(np.array(member_exceedance)) + + exceedance_data['predictions'] = tuple(member_exceedance_data) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # Compute percentiles for all datasets + percentiles_data = {} + + # Target percentiles + target_data_array = xr.DataArray(all_land_values['target']) + target_p99 = target_data_array.quantile(0.99).item() + target_p999 = target_data_array.quantile(0.999).item() + target_p9999 = target_data_array.quantile(0.9999).item() + percentiles_data['target'] = {99: target_p99, 99.9: target_p999, 99.99: target_p9999} + + # Baseline percentiles + baseline_data_array = xr.DataArray(all_land_values['baseline']) + baseline_p99 = baseline_data_array.quantile(0.99).item() + baseline_p999 = baseline_data_array.quantile(0.999).item() + baseline_p9999 = baseline_data_array.quantile(0.9999).item() + percentiles_data['baseline'] = {99: baseline_p99, 99.9: baseline_p999, 99.99: baseline_p9999} + + # Ensemble member percentiles + percentiles_data['predictions'] = {} + for member_idx in range(n_members): + member_data_array = xr.DataArray(all_member_values[member_idx]) + member_p99 = member_data_array.quantile(0.99).item() + member_p999 = member_data_array.quantile(0.999).item() + member_p9999 = member_data_array.quantile(0.9999).item() + percentiles_data['predictions'][f'member_{member_idx}'] = {99: member_p99, 99.9: member_p999, 99.99: member_p9999} + + + # Create exceedance plots + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'green'] + + fn = out_root / 'precipitation_exceedance_over_land.png' + save_exceedance_plot( + exceedance_data, + thresholds, + labels, + colors, + 'Probability of Exceedance', + 'All-hour Precipitation Over Land [mm/h] (Pooled Data)', + fn, + percentiles_data + ) + logger.info(f"Exceedance plot saved: {fn}") + + +if __name__ == '__main__': + main() From e631235252813a192d56dc4ac0b74b1c011be44a Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 5 Aug 2025 13:46:54 +0200 Subject: [PATCH 116/189] label --- src/hirad/eval/diurnal_cycle_precip.py | 4 ++-- src/hirad/eval/diurnal_cycle_temp_wind.py | 4 ++-- src/hirad/eval/hist.py | 2 +- src/hirad/eval/percentile99_cycle_precip.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 50961e3..ef57155 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -127,7 +127,7 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): amount_target_mean.hour, [amount_target_mean, amount_baseline_mean, amount_pred_mean], [None, None, amount_pred_std], - ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], + ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], 'Precipitation (mm/day)', 'Diurnal Cycle of Precip Amount', out_root / 'diurnal_cycle_precip_amount.png' @@ -136,7 +136,7 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): wet_target_mean.hour, [wet_target_mean, wet_baseline_mean, wet_pred_mean], [None, None, wet_pred_std], - ['COSMO-2','ERA5','CorrDiff ± Std(Members)'], + ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], 'Wet-Hour Fraction [%]', 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', out_root / 'diurnal_cycle_precip_wethours.png' diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 8e1ff30..90d178f 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -157,7 +157,7 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): temp_target_mean.hour, [temp_target_mean, temp_baseline_mean, temp_pred_mean], [None, None, temp_pred_std], - ['COSMO-2', 'ERA5', 'CorrDiff ± Std(Members)'], + ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'], '2m Temperature [°C]', 'Diurnal Cycle of 2m Temperature', out_root / 'diurnal_cycle_2t.png' @@ -167,7 +167,7 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): wind_target_mean.hour, [wind_target_mean, wind_baseline_mean, wind_pred_mean], [None, None, wind_pred_std], - ['COSMO-2', 'ERA5', 'CorrDiff ± Std(Members)'], + ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'], 'Windspeed [m/s]', 'Diurnal Cycle of Windspeed', out_root / 'diurnal_cycle_windspeed.png' diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index 59ee2ed..49b6bd2 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -253,7 +253,7 @@ def load(ts, fn): # Create distribution plots - labels = ['COSMO-2', 'ERA5', 'CorrDiff Ensemble'] + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] colors = ['blue', 'orange', 'green'] fn = out_root / 'precipitation_distribution_over_land.png' diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 7f97e21..3ee9d1c 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -163,7 +163,7 @@ def cycle_fn(x): save_plot( hrs_c, pct99_lines, - ['COSMO-2','ERA5','CorrDiff 99th Pct ± Std'], + ['COSMO-2 Analysis','ERA5','CorrDiff 99th Pct ± Std'], 'Precipitation (mm/day)', 'Diurnal Cycle of 99th-Percentile Precipitation', fn From 4904a539e9075bda78057960f3c2fe934d42ceea Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 5 Aug 2025 14:45:37 +0200 Subject: [PATCH 117/189] rename --- src/hirad/eval/{plot_maps.py => snapshots.py} | 0 src/hirad/{diurnal_cycle.sh => eval_precip.sh} | 14 ++++++++++---- src/hirad/maps.sh | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) rename src/hirad/eval/{plot_maps.py => snapshots.py} (100%) rename src/hirad/{diurnal_cycle.sh => eval_precip.sh} (80%) diff --git a/src/hirad/eval/plot_maps.py b/src/hirad/eval/snapshots.py similarity index 100% rename from src/hirad/eval/plot_maps.py rename to src/hirad/eval/snapshots.py diff --git a/src/hirad/diurnal_cycle.sh b/src/hirad/eval_precip.sh similarity index 80% rename from src/hirad/diurnal_cycle.sh rename to src/hirad/eval_precip.sh index 6643757..bb4b810 100644 --- a/src/hirad/diurnal_cycle.sh +++ b/src/hirad/eval_precip.sh @@ -1,19 +1,19 @@ #!/bin/bash -#SBATCH --job-name="plot_diurnal_cycle" +#SBATCH --job-name="eval_precip" ### HARDWARE ### -#SBATCH --partition=debug +#SBATCH --partition=normal #SBATCH --nodes=1 #SBATCH --ntasks-per-node=2 #SBATCH --gpus-per-node=2 #SBATCH --cpus-per-task=72 -##SBATCH --time=00:10:00 +#SBATCH --time=05:00:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=./logs/plot_diurnal_cycle.log +#SBATCH --output=./logs/plots_precipe.log ### ENVIRONMENT #### #SBATCH -A a161 @@ -48,4 +48,10 @@ srun --environment=./ci/edf/modulus_env.toml bash -c " python src/hirad/eval/diurnal_cycle_precip.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/percentile99_cycle_precip.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml +" + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + python src/hirad/eval/hist.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/probability_of_exceedance.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/maps.sh b/src/hirad/maps.sh index bcf4008..db6722b 100644 --- a/src/hirad/maps.sh +++ b/src/hirad/maps.sh @@ -45,5 +45,5 @@ export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies - python src/hirad/eval/plot_maps.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file From 58c6607ccdb3c5f77188eca906645ffdd54f120c Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 5 Aug 2025 15:15:31 +0200 Subject: [PATCH 118/189] plot map of the 99th all-hour percentile --- src/hirad/eval/map_99pctl.py | 173 +++++++++++++++++++++++++++++++++++ src/hirad/maps.sh | 1 + 2 files changed, 174 insertions(+) create mode 100644 src/hirad/eval/map_99pctl.py diff --git a/src/hirad/eval/map_99pctl.py b/src/hirad/eval/map_99pctl.py new file mode 100644 index 0000000..547a81e --- /dev/null +++ b/src/hirad/eval/map_99pctl.py @@ -0,0 +1,173 @@ +""" +Plots maps of the 99th percentile of precipitation over the entire time period. + +This script computes the all-time 99th percentile of precipitation for each grid point +and creates maps for target (COSMO-2), baseline (ERA5), and predictions (CorrDiff). +""" +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import plot_map_precipitation + +# Constants +CONV_FACTOR = 100 * 24 # Convert meters to mm/day +LOG_INTERVAL = 24 # Log progress every N timesteps + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for 99th percentile precipitation maps") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root and loader + out_root = Path(cfg.generation.io.output_path or './outputs') + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + tp_out = out_ch['tp'] + tp_in = in_ch.get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # -- Process target -- + logger.info("Processing target (COSMO-2)") + target_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing target timestep {i+1}/{len(times)}: {ts}") + data = load(ts, f"{ts}-target")[tp_out] + target_data_list.append(data) + + target_da = xr.DataArray( + np.stack(target_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Compute 99th percentile across all time steps + target_p99 = target_da.quantile(0.99, dim='time') + logger.info("Target 99th percentile computed") + + # -- Process baseline -- + logger.info("Processing baseline (ERA5)") + baseline_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing baseline timestep {i+1}/{len(times)}: {ts}") + data = load(ts, f"{ts}-baseline")[tp_in] + baseline_data_list.append(data) + + baseline_da = xr.DataArray( + np.stack(baseline_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Apply scaling factor for baseline and compute 99th percentile + baseline_p99 = (baseline_da / 6.0).quantile(0.99, dim='time') + logger.info("Baseline 99th percentile computed") + + # -- Process predictions -- + logger.info("Processing predictions (CorrDiff)") + pred_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing predictions timestep {i+1}/{len(times)}: {ts}") + preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + # Extract precipitation channel + tp_data = preds[:, tp_out] # [n_members, lat, lon] + tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) + pred_data_list.append(tp_da) + + pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] + pred_da = pred_da.assign_coords({ + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + }) + pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') + + # Compute 99th percentile across time for each member, then ensemble mean + pred_p99_by_member = pred_da.quantile(0.99, dim='time') + pred_p99_mean = pred_p99_by_member.mean(dim='member') + logger.info("Predictions 99th percentile computed") + + # Create output directory + map_output_dir = out_root / 'maps_99th_percentile' + map_output_dir.mkdir(parents=True, exist_ok=True) + + # Plot maps using the precipitation-specific plotting function + logger.info("Creating precipitation maps") + + # Target map + plot_map_precipitation( + target_p99.values, + str(map_output_dir / 'target_99th_percentile'), + title='COSMO-2 Analysis: 99th Percentile Precipitation', + threshold=0.1, + rfac=1.0 # Already converted to mm/day + ) + logger.info("Target map saved") + + # Baseline map + plot_map_precipitation( + baseline_p99.values, + str(map_output_dir / 'baseline_99th_percentile'), + title='ERA5: 99th Percentile Precipitation', + threshold=0.1, + rfac=1.0 # Already converted to mm/day + ) + logger.info("Baseline map saved") + + # Prediction ensemble mean map + plot_map_precipitation( + pred_p99_mean.values, + str(map_output_dir / 'prediction_ensmean_99th_percentile'), + title='CorrDiff Ensemble Mean: 99th Percentile Precipitation', + threshold=0.1, + rfac=1.0 # Already converted to mm/day + ) + logger.info("Prediction ensemble mean map saved") + + # Individual ensemble member maps + logger.info("Creating individual ensemble member maps") + n_members = pred_p99_by_member.shape[0] + for member_idx in range(n_members): + plot_map_precipitation( + pred_p99_by_member[member_idx].values, + str(map_output_dir / f'prediction_member_{member_idx:02d}_99th_percentile'), + title=f'CorrDiff Member {member_idx+1}: 99th Percentile Precipitation', + threshold=0.1, + rfac=1.0 # Already converted to mm/day + ) + logger.info(f"Individual ensemble member maps saved ({n_members} members)") + + logger.info(f"All maps saved to {map_output_dir}") + logger.info("99th percentile precipitation mapping completed successfully") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/hirad/maps.sh b/src/hirad/maps.sh index db6722b..aa1bcbd 100644 --- a/src/hirad/maps.sh +++ b/src/hirad/maps.sh @@ -46,4 +46,5 @@ export OMP_NUM_THREADS=72 srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/ap_99pctl.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file From 588681a197754deebe4ca96e2cf8eef24e5c6cd1 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Tue, 5 Aug 2025 15:57:16 +0200 Subject: [PATCH 119/189] plot mean --- src/hirad/eval/map_mean.py | 173 +++++++++++++++++++++++++++++++++++++ src/hirad/maps.sh | 7 +- 2 files changed, 177 insertions(+), 3 deletions(-) create mode 100644 src/hirad/eval/map_mean.py diff --git a/src/hirad/eval/map_mean.py b/src/hirad/eval/map_mean.py new file mode 100644 index 0000000..2eae774 --- /dev/null +++ b/src/hirad/eval/map_mean.py @@ -0,0 +1,173 @@ +""" +Plots maps of the mean precipitation over the entire time period. + +This script computes the temporal mean of precipitation for each grid point +and creates maps for target (COSMO-2), baseline (ERA5), and predictions (CorrDiff). +""" +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import plot_map_precipitation + +# Constants +CONV_FACTOR = 100 * 24 # Convert meters to mm/day +LOG_INTERVAL = 24 # Log progress every N timesteps + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for mean precipitation maps") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root and loader + out_root = Path(cfg.generation.io.output_path or './outputs') + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + + # Find channel indices + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + tp_out = out_ch['tp'] + tp_in = in_ch.get('tp', tp_out) + logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") + + # -- Process target -- + logger.info("Processing target (COSMO-2)") + target_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing target timestep {i+1}/{len(times)}: {ts}") + data = load(ts, f"{ts}-target")[tp_out] + target_data_list.append(data) + + target_da = xr.DataArray( + np.stack(target_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Compute mean across all time steps + target_mean = target_da.mean(dim='time') + logger.info("Target mean computed") + + # -- Process baseline -- + logger.info("Processing baseline (ERA5)") + baseline_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing baseline timestep {i+1}/{len(times)}: {ts}") + data = load(ts, f"{ts}-baseline")[tp_in] + baseline_data_list.append(data) + + baseline_da = xr.DataArray( + np.stack(baseline_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Apply scaling factor for baseline and compute mean + baseline_mean = (baseline_da / 6.0).mean(dim='time') + logger.info("Baseline mean computed") + + # -- Process predictions -- + logger.info("Processing predictions (CorrDiff)") + pred_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing predictions timestep {i+1}/{len(times)}: {ts}") + preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + # Extract precipitation channel + tp_data = preds[:, tp_out] # [n_members, lat, lon] + tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) + pred_data_list.append(tp_da) + + pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] + pred_da = pred_da.assign_coords({ + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + }) + pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') + + # Compute mean across time for each member, then ensemble mean + pred_mean_by_member = pred_da.mean(dim='time') + pred_mean_mean = pred_mean_by_member.mean(dim='member') + logger.info("Predictions mean computed") + + # Create output directory + map_output_dir = out_root / 'maps_mean' + map_output_dir.mkdir(parents=True, exist_ok=True) + + # Plot maps using the precipitation-specific plotting function + logger.info("Creating precipitation maps") + + # Target map + plot_map_precipitation( + target_mean.values, + str(map_output_dir / 'target_mean'), + title='COSMO-2 Analysis: Mean Precipitation', + threshold=0.01, # Lower threshold for mean precipitation + rfac=1.0 # Already converted to mm/day + ) + logger.info("Target map saved") + + # Baseline map + plot_map_precipitation( + baseline_mean.values, + str(map_output_dir / 'baseline_mean'), + title='ERA5: Mean Precipitation', + threshold=0.01, # Lower threshold for mean precipitation + rfac=1.0 # Already converted to mm/day + ) + logger.info("Baseline map saved") + + # Prediction ensemble mean map + plot_map_precipitation( + pred_mean_mean.values, + str(map_output_dir / 'prediction_ensmean_mean'), + title='CorrDiff Ensemble Mean: Mean Precipitation', + threshold=0.01, # Lower threshold for mean precipitation + rfac=1.0 # Already converted to mm/day + ) + logger.info("Prediction ensemble mean map saved") + + # Individual ensemble member maps + logger.info("Creating individual ensemble member maps") + n_members = pred_mean_by_member.shape[0] + for member_idx in range(n_members): + plot_map_precipitation( + pred_mean_by_member[member_idx].values, + str(map_output_dir / f'prediction_member_{member_idx:02d}_mean'), + title=f'CorrDiff Member {member_idx+1}: Mean Precipitation', + threshold=0.01, # Lower threshold for mean precipitation + rfac=1.0 # Already converted to mm/day + ) + logger.info(f"Individual ensemble member maps saved ({n_members} members)") + + logger.info(f"All maps saved to {map_output_dir}") + logger.info("Mean precipitation mapping completed successfully") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/maps.sh b/src/hirad/maps.sh index aa1bcbd..543de5a 100644 --- a/src/hirad/maps.sh +++ b/src/hirad/maps.sh @@ -3,12 +3,12 @@ #SBATCH --job-name="plot" ### HARDWARE ### -#SBATCH --partition=debug +#SBATCH --partition=normal #SBATCH --nodes=1 #SBATCH --ntasks-per-node=2 #SBATCH --gpus-per-node=2 #SBATCH --cpus-per-task=72 -#SBATCH --time=00:10:00 +#SBATCH --time=01:00:00 #SBATCH --no-requeue #SBATCH --exclusive @@ -46,5 +46,6 @@ export OMP_NUM_THREADS=72 srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/ap_99pctl.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/map_99pctl.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/map_mean.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file From 3ca921c8d8f206e4ef5f97d248835465318ad487 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 6 Aug 2025 14:11:18 +0200 Subject: [PATCH 120/189] Update model compilation - Adjusted model compilation logic to compile models before DistributedDataParallel is called. This reduces overhead of DDP that causes torch compile to issue warnings - Added device_id parameter to DistributedManager initialization. - Removed 'reduce-overhead' mode from torch.compile to fix CUDA graph compilation issues during inference. - Changed logging from error to warning when model file is not found. --- src/hirad/distributed/manager.py | 1 + src/hirad/inference/generate.py | 4 ++-- src/hirad/training/train.py | 5 ++++- src/hirad/utils/checkpoint.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index eca46c6..9695818 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -551,6 +551,7 @@ def setup( backend, rank=manager.rank, world_size=manager.world_size, + device_id=manager.device, ) # rank=manager.rank, # world_size=manager.world_size, diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index a553af8..4ef2970 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -137,8 +137,8 @@ def main(cfg: DictConfig) -> None: # Only compile residual network # Overhead of compiling regression network outweights any benefits if net_res: - net_res = torch.compile(net_res, mode="reduce-overhead") - + net_res = torch.compile(net_res) #, mode="reduce-overhead") + # removed reduce-overhead because it was breaking cuda graph compilation generator = Generator( net_reg=net_reg, net_res=net_res, diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 519845b..ba1b975 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -282,6 +282,8 @@ def main(cfg: DictConfig) -> None: # Enable distributed data parallel if applicable if dist.world_size > 1: + if use_torch_compile: + model = torch.compile(model) model = DistributedDataParallel( model, device_ids=[dist.local_rank], @@ -332,7 +334,8 @@ def main(cfg: DictConfig) -> None: # Compile the model and regression net if applicable if use_torch_compile: - model = torch.compile(model) + if dist.world_size==1: + model = torch.compile(model) if regression_net: regression_net = torch.compile(regression_net) diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index a346b16..070f8c3 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -279,7 +279,7 @@ def load_checkpoint( path, name, index=epoch, ) if not Path(file_name).exists(): - checkpoint_logging.error( + checkpoint_logging.warning( f"Could not find valid model file {file_name}, skipping load" ) else: From f4453421e51d8b89de388443563081e1659fd74d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 6 Aug 2025 14:26:03 +0200 Subject: [PATCH 121/189] turn on optimizations --- src/hirad/conf/training/era_cosmo_diffusion.yaml | 13 ++++++++----- src/hirad/conf/training/era_cosmo_regression.yaml | 13 +++++++------ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index 40e37f3..2a2fd33 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,10 +1,10 @@ # Hyperparameters hp: - training_duration: 20000 + training_duration: 10000 # Training duration based on the number of processed samples total_batch_size: "auto" # Total batch size - batch_size_per_gpu: 22 + batch_size_per_gpu: 20 # Batch size per GPU lr: 0.0002 # Learning rate @@ -26,16 +26,19 @@ perf: # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint - + use_apex_gn: True + torch_compile: True + profile_mode: False # I/O io: - regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression + # regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression + regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo_perf_apex_gn/checkpoints_regression # Where to load the regression checkpoint print_progress_freq: 500 # How often to print progress save_checkpoint_freq: 10000 # How often to save the checkpoints, measured in number of processed samples - visualization_freq: 200000 + visualization_freq: 10000 # how often to visualize network outputs validation_freq: 2000 # how often to record the validation loss, measured in number of processed samples diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 5cd5d0d..40e9005 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,10 +1,10 @@ # Hyperparameters hp: - training_duration: 40000 + training_duration: 10000 # Training duration based on the number of processed samples total_batch_size: "auto" # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression. - batch_size_per_gpu: 22 + batch_size_per_gpu: 20 # Batch size per GPU lr: 0.0002 # Learning rate @@ -26,16 +26,17 @@ perf: # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint - # torch_compile: True - # use_apex_gn: True + use_apex_gn: True + torch_compile: True + profile_mode: False # I/O io: print_progress_freq: 500 # How often to print progress - save_checkpoint_freq: 100000 + save_checkpoint_freq: 10000 # How often to save the checkpoints, measured in number of processed samples - visualization_freq: 200000 + visualization_freq: 10000 # how often to visualize network output validation_freq: 2000 # how often to record the validation loss, measured in number of processed samples From 062c8d8de2fbfcb0c09099f4e9e942885bb1cbeb Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 15:58:26 +0200 Subject: [PATCH 122/189] add some common untility functions --- src/hirad/eval/plotting.py | 68 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 2eb1c62..39b3208 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -4,7 +4,11 @@ import cartopy.feature as cfeature import matplotlib.pyplot as plt import numpy as np +import torch +import xarray as xr from matplotlib.colors import BoundaryNorm, ListedColormap +from pathlib import Path +from datetime import datetime # COSMO‑2 GRID: TODO: Add to dataset config @@ -12,6 +16,68 @@ LON = np.arange(-6.82, 4.80 + 0.02, 0.02) RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) +# Constants for data processing +CONV_FACTOR_HOURLY = 100 # Convert precip meters to mm/h +CONV_FACTOR = CONV_FACTOR_HOURLY * 24 # Convert precip from meters to mm/day +WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h + +def get_channel_indices(dataset, channels=None): + """ + Get channel indices for input and output channels from dataset. + + Args: + dataset: Dataset object with input_channels() and output_channels() methods + channels: Optional list of channel names to look up. If None, returns all channel mappings. + + Returns: + dict: Dictionary with 'input' and 'output' keys, each containing channel name -> index mapping + + Example: + indices = get_channel_indices(dataset, ['tp', '2t', '10u', '10v']) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) # Fallback to output index if not in input + """ + out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} + in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + + if channels is None: + return {'input': in_ch, 'output': out_ch} + + # Filter to requested channels only + filtered_out = {ch: out_ch[ch] for ch in channels if ch in out_ch} + filtered_in = {ch: in_ch[ch] for ch in channels if ch in in_ch} + + return {'input': filtered_in, 'output': filtered_out} + +def load_land_sea_mask(path='/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy'): + """Load and retrun a land-sea mask as xarray DataArray.""" + lsm_data = np.load(path).reshape(352, 544) + return xr.DataArray( + np.where(lsm_data >= 0.5, 1.0, np.nan), + dims=['lat', 'lon'], + coords={"lat": np.arange(352), "lon": np.arange(544)} + ) + +def load_prediction_data_torch(out_root, times, filename_pattern, conv_factor=CONV_FACTOR): + """Generic loader for prediction data with conversion factor.""" + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) * conv_factor + + return lambda ts, fn: load(ts, filename_pattern.format(ts=ts, fn=fn)) + +def concat_and_group_diurnal(list_of_da, is_member=False, scale=1.0): + """Helper to concatenate DataArrays and compute diurnal statistics.""" + da = xr.concat(list_of_da, dim="time").groupby("time.hour") + if is_member: + timmean = da.mean(dim='time') * scale + mean = timmean.mean(dim='member') + std = timmean.std(dim='member') + else: + mean = da.mean(dim='time') * scale + std = None + return mean, std + + def plot_map(values: np.array, filename: str, label='', @@ -108,7 +174,6 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str, xlabel='', ylabel='', title=''): - fig = plt.figure() ax = plt.subplot() colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'] # TODO, add more i=0 @@ -131,7 +196,6 @@ def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: st plt.close('all') def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): - fig = plt.figure() for k in freqs.keys(): plt.loglog(freqs[k], spec[k], label=k) plt.title(channel_name) From 392637dd0bfe9e14a2f6d001588456d36a95a0a9 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 15:59:36 +0200 Subject: [PATCH 123/189] use get_channel_indices function --- src/hirad/eval/diurnal_cycle_precip.py | 8 ++++---- src/hirad/eval/diurnal_cycle_temp_wind.py | 17 ++++++++++++----- src/hirad/eval/hist.py | 8 ++++---- src/hirad/eval/map_99pctl.py | 9 ++++----- src/hirad/eval/map_mean.py | 9 ++++----- src/hirad/eval/percentile99_cycle_precip.py | 8 ++++---- src/hirad/eval/probability_of_exceedance.py | 8 ++++---- src/hirad/eval_precip.sh | 2 +- 8 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index ef57155..56e19a5 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -12,6 +12,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices # Constants CONV_FACTOR = 100*24 # Convert meters to mm/day @@ -40,10 +41,9 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp'] - tp_in = in_ch.get('tp', tp_out) + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 90d178f..1859ad7 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -12,6 +12,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices LOG_INTERVAL = 24 @@ -35,18 +36,24 @@ def main(cfg: DictConfig): ) # Indices for channels - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} + indices = get_channel_indices(dataset) + out_ch = indices['output'] + in_ch = indices['input'] + + # Temperature channel (try '2t' first, fallback to 't2m') t2m_out = out_ch.get('2t', out_ch.get('t2m')) t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out)) - u_out = out_ch.get('10u') - v_out = out_ch.get('10v') + + # Wind channels + u_out = out_ch['10u'] u_in = in_ch.get('10u', u_out) + v_out = out_ch['10v'] v_in = in_ch.get('10v', v_out) # Output path out_root = Path(cfg.generation.io.output_path or './outputs') - load = lambda ts, fn: torch.load(out_root/ts/fn, weights_only=False) + def load(ts, fn): + return torch.load(out_root/ts/fn, weights_only=False) # Land-sea mask lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index 49b6bd2..2a9f8a7 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -17,6 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices # Constants CONV_FACTOR = 100 # Convert meters to mm/h @@ -128,10 +129,9 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp'] - tp_in = in_ch.get('tp', tp_out) + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask diff --git a/src/hirad/eval/map_99pctl.py b/src/hirad/eval/map_99pctl.py index 547a81e..4e7395a 100644 --- a/src/hirad/eval/map_99pctl.py +++ b/src/hirad/eval/map_99pctl.py @@ -17,7 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation +from hirad.eval.plotting import plot_map_precipitation, get_channel_indices # Constants CONV_FACTOR = 100 * 24 # Convert meters to mm/day @@ -48,10 +48,9 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp'] - tp_in = in_ch.get('tp', tp_out) + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # -- Process target -- diff --git a/src/hirad/eval/map_mean.py b/src/hirad/eval/map_mean.py index 2eae774..40f80c9 100644 --- a/src/hirad/eval/map_mean.py +++ b/src/hirad/eval/map_mean.py @@ -17,7 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation +from hirad.eval.plotting import plot_map_precipitation, get_channel_indices # Constants CONV_FACTOR = 100 * 24 # Convert meters to mm/day @@ -48,10 +48,9 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp'] - tp_in = in_ch.get('tp', tp_out) + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # -- Process target -- diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 3ee9d1c..15aec33 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -19,6 +19,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices # Constants CONV_FACTOR = 100 * 24 # Convert meters to mm/day @@ -77,10 +78,9 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp'] - tp_in = in_ch.get('tp', tp_out) + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py index e5da6b9..52452c1 100644 --- a/src/hirad/eval/probability_of_exceedance.py +++ b/src/hirad/eval/probability_of_exceedance.py @@ -17,6 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices # Constants CONV_FACTOR = 100 # Convert meters to mm/h @@ -125,10 +126,9 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices - out_ch = {c.name: i for i, c in enumerate(dataset.output_channels())} - in_ch = {c.name: i for i, c in enumerate(dataset.input_channels())} - tp_out = out_ch['tp'] - tp_in = in_ch.get('tp', tp_out) + indices = get_channel_indices(dataset) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index bb4b810..44eb3bc 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -13,7 +13,7 @@ #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=./logs/plots_precipe.log +#SBATCH --output=./logs/plots_precip.log ### ENVIRONMENT #### #SBATCH -A a161 From 4dbaae3fbfff9c6d1d69d9afdc7e99ca9b37e390 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 6 Aug 2025 16:55:33 +0200 Subject: [PATCH 124/189] remove slurm account variables --- ci/cscs.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index 3e50b37..c3ac15c 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -25,5 +25,5 @@ test_job: variables: SLURM_JOB_NUM_NODES: 2 SLURM_NTASKS: 2 - SLURM_ACCOUNT: a161 - SBATCH_ACCOUNT: a161 + #SLURM_ACCOUNT: a161 + #SBATCH_ACCOUNT: a161 From 63bb3cbce268407d9d52013af84bef34d9682125 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 17:04:34 +0200 Subject: [PATCH 125/189] put more constants in plotting, use land-mask from plotting --- src/hirad/eval/diurnal_cycle_precip.py | 11 +++-------- src/hirad/eval/diurnal_cycle_temp_wind.py | 8 +++----- src/hirad/eval/hist.py | 14 +++----------- src/hirad/eval/map_99pctl.py | 6 +----- src/hirad/eval/map_mean.py | 6 +----- src/hirad/eval/percentile99_cycle_precip.py | 12 ++---------- src/hirad/eval/plotting.py | 5 +++-- src/hirad/eval/probability_of_exceedance.py | 14 +++----------- 8 files changed, 19 insertions(+), 57 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index 56e19a5..d6f0eaf 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -12,12 +12,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices - -# Constants -CONV_FACTOR = 100*24 # Convert meters to mm/day -WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h -LOG_INTERVAL = 24 # Log progress every N timesteps +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR, WET_THRESHOLD, LOG_INTERVAL @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): @@ -47,8 +42,8 @@ def load(ts, fn): logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask - lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) - land_mask = np.where(lsm_data >= 0.5, 1.0, np.nan) + land_mask_da = load_land_sea_mask() + land_mask = land_mask_da.values coords = {"lat": np.arange(land_mask.shape[0]), "lon": np.arange(land_mask.shape[1])} # Prepare lists to collect DataArrays diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 1859ad7..85ba9cf 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -12,9 +12,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices - -LOG_INTERVAL = 24 +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, LOG_INTERVAL @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): @@ -56,8 +54,8 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) # Land-sea mask - lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) - land_mask = np.where(lsm_data >= 0.5, 1.0, np.nan) + land_mask_da = load_land_sea_mask() + land_mask = land_mask_da.values coords = {"lat": np.arange(land_mask.shape[0]), "lon": np.arange(land_mask.shape[1])} # Prepare lists to collect DataArrays diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index 2a9f8a7..8f2da49 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -17,11 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices - -# Constants -CONV_FACTOR = 100 # Convert meters to mm/h -LOG_INTERVAL = 24 # Log progress every N timesteps +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR_HOURLY, LOG_INTERVAL def save_distribution_plot(hist_data_dict, bin_edges, labels, colors, title, ylabel, out_path, percentiles_data=None): @@ -126,7 +122,7 @@ def main(cfg: DictConfig): out_root = Path(cfg.generation.io.output_path or './outputs') def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR_HOURLY # Find channel indices indices = get_channel_indices(dataset) @@ -135,11 +131,7 @@ def load(ts, fn): logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask - lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) - land_mask = xr.DataArray( - np.where(lsm_data >= 0.5, 1.0, np.nan), - dims=['lat', 'lon'] - ) + land_mask = load_land_sea_mask() # Define histogram bins bins = np.logspace(-1, 1, 50) # Log-spaced bins for precipitation diff --git a/src/hirad/eval/map_99pctl.py b/src/hirad/eval/map_99pctl.py index 4e7395a..a399a82 100644 --- a/src/hirad/eval/map_99pctl.py +++ b/src/hirad/eval/map_99pctl.py @@ -17,11 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation, get_channel_indices - -# Constants -CONV_FACTOR = 100 * 24 # Convert meters to mm/day -LOG_INTERVAL = 24 # Log progress every N timesteps +from hirad.eval.plotting import plot_map_precipitation, get_channel_indices, CONV_FACTOR, LOG_INTERVAL @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") diff --git a/src/hirad/eval/map_mean.py b/src/hirad/eval/map_mean.py index 40f80c9..e7c33ae 100644 --- a/src/hirad/eval/map_mean.py +++ b/src/hirad/eval/map_mean.py @@ -17,11 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation, get_channel_indices - -# Constants -CONV_FACTOR = 100 * 24 # Convert meters to mm/day -LOG_INTERVAL = 24 # Log progress every N timesteps +from hirad.eval.plotting import plot_map_precipitation, get_channel_indices, CONV_FACTOR, LOG_INTERVAL @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/percentile99_cycle_precip.py index 15aec33..3ed38a0 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/percentile99_cycle_precip.py @@ -19,11 +19,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices - -# Constants -CONV_FACTOR = 100 * 24 # Convert meters to mm/day -LOG_INTERVAL = 24 # Log progress every N timesteps +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: @@ -84,11 +80,7 @@ def load(ts, fn): logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask - lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) - land_mask = xr.DataArray( - np.where(lsm_data >= 0.5, 1.0, np.nan), - dims=['lat', 'lon'] - ) + land_mask = load_land_sea_mask() # Storage for diurnal cycles pct99_mean = {} diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 39b3208..9c94e30 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -17,9 +17,10 @@ RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) # Constants for data processing -CONV_FACTOR_HOURLY = 100 # Convert precip meters to mm/h -CONV_FACTOR = CONV_FACTOR_HOURLY * 24 # Convert precip from meters to mm/day +CONV_FACTOR_HOURLY = 100 # Convert precip of ERA5 from meters to mm/h +CONV_FACTOR = CONV_FACTOR_HOURLY * 24 # Convert precip of ERA5 from from meters to mm/day WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h +LOG_INTERVAL = 24 # Log progress every N timesteps def get_channel_indices(dataset, channels=None): """ diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py index 52452c1..3ab645a 100644 --- a/src/hirad/eval/probability_of_exceedance.py +++ b/src/hirad/eval/probability_of_exceedance.py @@ -17,11 +17,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices - -# Constants -CONV_FACTOR = 100 # Convert meters to mm/h -LOG_INTERVAL = 24 # Log progress every N timesteps +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR_HOURLY, LOG_INTERVAL def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): @@ -123,7 +119,7 @@ def main(cfg: DictConfig): out_root = Path(cfg.generation.io.output_path or './outputs') def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR + return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR_HOURLY # Find channel indices indices = get_channel_indices(dataset) @@ -132,11 +128,7 @@ def load(ts, fn): logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask - lsm_data = np.load('/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy').reshape(352,544) - land_mask = xr.DataArray( - np.where(lsm_data >= 0.5, 1.0, np.nan), - dims=['lat', 'lon'] - ) + land_mask = load_land_sea_mask() # Define thresholds for exceedance calculation thresholds = np.logspace(-2, 2, 200) # From 0.01 to 100 mm/h From 6236e7e594ce49042e78741e7a3a3d5cd43e55f9 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 17:09:23 +0200 Subject: [PATCH 126/189] re-use concat_and_group_diurnal --- src/hirad/eval/diurnal_cycle_precip.py | 26 ++++++----------------- src/hirad/eval/diurnal_cycle_temp_wind.py | 26 ++++++----------------- 2 files changed, 14 insertions(+), 38 deletions(-) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip.py index d6f0eaf..2270db9 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip.py @@ -12,7 +12,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR, WET_THRESHOLD, LOG_INTERVAL +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR, WET_THRESHOLD, LOG_INTERVAL, concat_and_group_diurnal @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): @@ -75,26 +75,14 @@ def load(ts, fn): if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") - # Helper to concat and compute diurnal stats - def concat_and_group(list_of_da, is_member=False, scale=1.0): - da = xr.concat(list_of_da, dim="time").groupby("time.hour") - if is_member: - timmean = da.mean(dim='time') * scale - mean = timmean.mean(dim='member') - std = timmean.std(dim='member') - else: - mean = da.mean(dim='time') * scale - std = None - return mean, std - # Compute diurnal means and stds - amount_target_mean, _ = concat_and_group(target_precip) - amount_baseline_mean, _ = concat_and_group(baseline_precip) - amount_pred_mean, amount_pred_std = concat_and_group(pred_precip, is_member=True) + amount_target_mean, _ = concat_and_group_diurnal(target_precip) + amount_baseline_mean, _ = concat_and_group_diurnal(baseline_precip) + amount_pred_mean, amount_pred_std = concat_and_group_diurnal(pred_precip, is_member=True) - wet_target_mean, _ = concat_and_group(target_wet, scale=100.0) # scale to percentage - wet_baseline_mean, _ = concat_and_group(baseline_wet, scale=100.0) - wet_pred_mean, wet_pred_std = concat_and_group(pred_wet, is_member=True, scale=100.0) + wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to percentage + wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0) + wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0) def save_plot(hour, means, stds, labels, ylabel, title, out_path): hrs = np.concatenate([hour.values, [24]]) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index 85ba9cf..e624641 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -12,7 +12,7 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, LOG_INTERVAL +from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, LOG_INTERVAL, concat_and_group_diurnal @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): @@ -115,26 +115,14 @@ def load(ts, fn): if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") - # Helper to concat and compute diurnal stats - def concat_and_group(list_of_da, is_member=False, scale=1.0): - da = xr.concat(list_of_da, dim="time").groupby("time.hour") - if is_member: - timmean = da.mean(dim='time') * scale - mean = timmean.mean(dim='member') - std = timmean.std(dim='member') - else: - mean = da.mean(dim='time') * scale - std = None - return mean, std - # Compute diurnal means and stds - temp_target_mean, _ = concat_and_group(target_temp) - temp_baseline_mean, _ = concat_and_group(baseline_temp) - temp_pred_mean, temp_pred_std = concat_and_group(pred_temp, is_member=True) + temp_target_mean, _ = concat_and_group_diurnal(target_temp) + temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp) + temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True) - wind_target_mean, _ = concat_and_group(target_wind) - wind_baseline_mean, _ = concat_and_group(baseline_wind) - wind_pred_mean, wind_pred_std = concat_and_group(pred_wind, is_member=True) + wind_target_mean, _ = concat_and_group_diurnal(target_wind) + wind_baseline_mean, _ = concat_and_group_diurnal(baseline_wind) + wind_pred_mean, wind_pred_std = concat_and_group_diurnal(pred_wind, is_member=True) def save_plot(hour, means, stds, labels, ylabel, title, out_path): hrs = np.concatenate([hour.values, [24]]) From 489d0a6e340c9ca4ae67ec2b55198c6bd08162f7 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 17:23:11 +0200 Subject: [PATCH 127/189] remove unneeded function --- src/hirad/eval/plotting.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 9c94e30..edbcbc7 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -59,13 +59,6 @@ def load_land_sea_mask(path='/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy'): coords={"lat": np.arange(352), "lon": np.arange(544)} ) -def load_prediction_data_torch(out_root, times, filename_pattern, conv_factor=CONV_FACTOR): - """Generic loader for prediction data with conversion factor.""" - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * conv_factor - - return lambda ts, fn: load(ts, filename_pattern.format(ts=ts, fn=fn)) - def concat_and_group_diurnal(list_of_da, is_member=False, scale=1.0): """Helper to concatenate DataArrays and compute diurnal statistics.""" da = xr.concat(list_of_da, dim="time").groupby("time.hour") From e2b1aecc0c2ef609095bbfd265a85933caa6e285 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 17:41:24 +0200 Subject: [PATCH 128/189] cleanup dcycles --- ... => diurnal_cycle_precip_mean_wet-hour.py} | 68 ++++++++--------- ..._precip.py => diurnal_cycle_precip_p99.py} | 12 +-- src/hirad/eval/diurnal_cycle_temp_wind.py | 73 +++++++------------ src/hirad/eval_precip.sh | 4 +- 4 files changed, 66 insertions(+), 91 deletions(-) rename src/hirad/eval/{diurnal_cycle_precip.py => diurnal_cycle_precip_mean_wet-hour.py} (72%) rename src/hirad/eval/{percentile99_cycle_precip.py => diurnal_cycle_precip_p99.py} (93%) diff --git a/src/hirad/eval/diurnal_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py similarity index 72% rename from src/hirad/eval/diurnal_cycle_precip.py rename to src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py index 2270db9..f271588 100644 --- a/src/hirad/eval/diurnal_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py @@ -14,6 +14,27 @@ from hirad.utils.function_utils import get_time_from_range from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR, WET_THRESHOLD, LOG_INTERVAL, concat_and_group_diurnal +def save_plot(hour, means, stds, labels, ylabel, title, out_path): + hrs = np.concatenate([hour.values, [24]]) + plt.figure(figsize=(8,4)) + for mean, std, label in zip(means, stds, labels): + vals = np.append(mean.values, mean.values[0]) + line, = plt.plot(hrs, vals, label=label) + if std is not None: + stdv = np.append(std.values, std.values[0]) + plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) + plt.xlabel('Hour (UTC)') + plt.xticks(range(0,25,3)) + plt.xlim(0,24) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(True) + plt.legend() + plt.tight_layout() + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path) + plt.close() + @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): # Setup logging @@ -31,9 +52,8 @@ def main(cfg: DictConfig): ds_cfg, times, cfg.generation.get('has_lead_time', False) ) + # Location of the output from inference out_root = Path(cfg.generation.io.output_path or './outputs') - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices indices = get_channel_indices(dataset) @@ -42,9 +62,7 @@ def load(ts, fn): logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") # Land-sea mask - land_mask_da = load_land_sea_mask() - land_mask = land_mask_da.values - coords = {"lat": np.arange(land_mask.shape[0]), "lon": np.arange(land_mask.shape[1])} + land_mask = load_land_sea_mask() # Prepare lists to collect DataArrays target_precip, baseline_precip, pred_precip = [], [], [] @@ -53,14 +71,19 @@ def load(ts, fn): # Collect data for idx, ts in enumerate(times, 1): dt = datetimes[idx-1] - target = load(ts, f"{ts}-target")[tp_out] * land_mask - baseline = load(ts, f"{ts}-baseline")[tp_in] * land_mask / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset - preds = load(ts, f"{ts}-predictions")[:, tp_out, :, :] * land_mask + target = torch.load(out_root/ts/f"{ts}-target", weights_only=False)[tp_out] * CONV_FACTOR + baseline = torch.load(out_root/ts/f"{ts}-baseline", weights_only=False)[tp_in] * CONV_FACTOR / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False)[:, tp_out, :, :] * CONV_FACTOR # DataArrays for spatial means at each timestep - da_target = xr.DataArray(target, dims=("lat","lon"), coords=coords) - da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=coords) - da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **coords}) + da_target = xr.DataArray(target, dims=("lat","lon"), coords=land_mask.coords) + da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=land_mask.coords) + da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **land_mask.coords}) + + # Apply land mask after conversion to xarray + da_target = da_target * land_mask + da_baseline = da_baseline * land_mask + da_preds = da_preds * land_mask # Spatial mean target_precip.append(da_target.mean(dim=("lat","lon")).assign_coords(time=dt)) @@ -80,31 +103,10 @@ def load(ts, fn): amount_baseline_mean, _ = concat_and_group_diurnal(baseline_precip) amount_pred_mean, amount_pred_std = concat_and_group_diurnal(pred_precip, is_member=True) - wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to percentage + wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to obtain percentages wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0) wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0) - def save_plot(hour, means, stds, labels, ylabel, title, out_path): - hrs = np.concatenate([hour.values, [24]]) - plt.figure(figsize=(8,4)) - for mean, std, label in zip(means, stds, labels): - vals = np.append(mean.values, mean.values[0]) - line, = plt.plot(hrs, vals, label=label) - if std is not None: - stdv = np.append(std.values, std.values[0]) - plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3) - plt.xlabel('Hour (UTC)') - plt.xticks(range(0,25,3)) - plt.xlim(0,24) - plt.ylabel(ylabel) - plt.title(title) - plt.grid(True) - plt.legend() - plt.tight_layout() - Path(out_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(out_path) - plt.close() - # Generate plots save_plot( amount_target_mean.hour, diff --git a/src/hirad/eval/percentile99_cycle_precip.py b/src/hirad/eval/diurnal_cycle_precip_p99.py similarity index 93% rename from src/hirad/eval/percentile99_cycle_precip.py rename to src/hirad/eval/diurnal_cycle_precip_p99.py index 3ed38a0..bb64153 100644 --- a/src/hirad/eval/percentile99_cycle_precip.py +++ b/src/hirad/eval/diurnal_cycle_precip_p99.py @@ -22,10 +22,6 @@ from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, CONV_FACTOR -def hour_of(dt: str, fmt: str = "%Y%m%d-%H%M") -> int: - return datetime.strptime(dt, fmt).hour - - def save_plot(hours, lines, labels, ylabel, title, out_path): Path(out_path).parent.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(8,4)) @@ -68,10 +64,8 @@ def main(cfg: DictConfig): ) logger.info("Dataset and sampler initialized") - # Output root and loader + # Output root out_root = Path(cfg.generation.io.output_path or './outputs') - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR # Find channel indices indices = get_channel_indices(dataset) @@ -92,7 +86,7 @@ def load(ts, fn): data_list = [] for ts in times: - data = load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode == 'target' else tp_in] * CONV_FACTOR * land_mask data_list.append(data) da = xr.DataArray( @@ -116,7 +110,7 @@ def load(ts, fn): # Load all prediction data at once into xarray pred_data_list = [] for ts in times: - preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR # [n_members, n_channels, lat, lon] # Extract precipitation channel and convert to xarray for proper broadcasting tp_data = preds[:, tp_out] # [n_members, lat, lon] tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index e624641..c0c97fe 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -54,63 +54,42 @@ def load(ts, fn): return torch.load(out_root/ts/fn, weights_only=False) # Land-sea mask - land_mask_da = load_land_sea_mask() - land_mask = land_mask_da.values - coords = {"lat": np.arange(land_mask.shape[0]), "lon": np.arange(land_mask.shape[1])} + land_mask = load_land_sea_mask() # Prepare lists to collect DataArrays target_temp, baseline_temp, pred_temp = [], [], [] target_wind, baseline_wind, pred_wind = [], [], [] + def mean_over_land(data, dims, coords, time_coord): + da = xr.DataArray(data, dims=dims, coords=coords) * land_mask + return da.mean(dim=("lat","lon")).assign_coords(time=time_coord) + # Loop over timestamps for idx, ts in enumerate(times, 1): dt = datetimes[idx-1] - # Load and apply land mask - target = load(ts, f"{ts}-target") * land_mask - baseline = load(ts, f"{ts}-baseline") * land_mask - predictions = load(ts, f"{ts}-predictions") * land_mask - - # Wrap into DataArrays (convert temperature to Celsius inline) - da_tgt_temp = xr.DataArray( - target[t2m_out] - 273.15, dims=("lat","lon"), coords=coords - ) - da_bsl_temp = xr.DataArray( - baseline[t2m_in] - 273.15, dims=("lat","lon"), coords=coords - ) - tgt_wind = np.hypot(target[u_out], target[v_out]) - bsl_wind = np.hypot(baseline[u_in], baseline[v_in]) - da_tgt_wind = xr.DataArray(tgt_wind, dims=("lat","lon"), coords=coords) - da_bsl_wind = xr.DataArray(bsl_wind, dims=("lat","lon"), coords=coords) - - da_pred_members_temp = xr.DataArray( - predictions[:, t2m_out, :, :] - 273.15, dims=("member","lat","lon"), - coords={"member": np.arange(predictions.shape[0]), **coords} - ) - da_pred_members_wind = xr.DataArray( + # Load data + target = load(ts, f"{ts}-target") + baseline = load(ts, f"{ts}-baseline") + predictions = load(ts, f"{ts}-predictions") + + # Process temperature (convert to Celsius) + target_temp.append(mean_over_land( + target[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) + baseline_temp.append(mean_over_land( + baseline[t2m_in] - 273.15, ("lat","lon"), land_mask.coords, dt)) + pred_temp.append(mean_over_land( + predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"), + {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) + + # Process wind speed + target_wind.append(mean_over_land( + np.hypot(target[u_out], target[v_out]), ("lat","lon"), land_mask.coords, dt)) + baseline_wind.append(mean_over_land( + np.hypot(baseline[u_in], baseline[v_in]), ("lat","lon"), land_mask.coords, dt)) + pred_wind.append(mean_over_land( np.hypot(predictions[:, u_out, :, :], predictions[:, v_out, :, :]), - dims=("member","lat","lon"), coords={"member": np.arange(predictions.shape[0]), **coords} - ) - - # Compute spatial mean and assign time coordinate - target_temp.append( - da_tgt_temp.mean(dim=("lat","lon")).assign_coords(time=dt) - ) - baseline_temp.append( - da_bsl_temp.mean(dim=("lat","lon")).assign_coords(time=dt) - ) - pred_temp.append( - da_pred_members_temp.mean(dim=("lat","lon")).assign_coords(time=dt) - ) - target_wind.append( - da_tgt_wind.mean(dim=("lat","lon")).assign_coords(time=dt) - ) - baseline_wind.append( - da_bsl_wind.mean(dim=("lat","lon")).assign_coords(time=dt) - ) - pred_wind.append( - da_pred_members_wind.mean(dim=("lat","lon")).assign_coords(time=dt) - ) + ("member","lat","lon"), {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index 44eb3bc..52232c4 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -45,8 +45,8 @@ export OMP_NUM_THREADS=72 # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies - python src/hirad/eval/diurnal_cycle_precip.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/percentile99_cycle_precip.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/diurnal_cycle_precip_p99.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml " From 3928d6e9ea0e1e921f698445df56645495884c99 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 6 Aug 2025 18:10:25 +0200 Subject: [PATCH 129/189] cleanup --- src/hirad/eval/hist.py | 54 ++++++++------------- src/hirad/eval/probability_of_exceedance.py | 48 +++++++----------- 2 files changed, 37 insertions(+), 65 deletions(-) diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index 8f2da49..bdd241b 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -118,11 +118,8 @@ def main(cfg: DictConfig): ) logger.info("Dataset and sampler initialized") - # Output root and loader + # Output root out_root = Path(cfg.generation.io.output_path or './outputs') - - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR_HOURLY # Find channel indices indices = get_channel_indices(dataset) @@ -136,16 +133,14 @@ def load(ts, fn): # Define histogram bins bins = np.logspace(-1, 1, 50) # Log-spaced bins for precipitation - # Storage for histogram data + # Storage for histogram data and land values hist_data = {} - # Store all land values for percentile calculation all_land_values = {} # -- Process target and baseline -- for mode in ['target', 'baseline']: logger.info(f"Processing mode: {mode}") - # Initialize histogram accumulator and collect all values hist_counts = np.zeros(len(bins) - 1) total_samples = 0 all_values = [] @@ -154,17 +149,15 @@ def load(ts, fn): if i % LOG_INTERVAL == 0: logger.info(f"Processing timestep {i+1}/{len(times)}") - data = load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode == 'target' else tp_in] * CONV_FACTOR_HOURLY * land_mask # Apply scaling factor for baseline if mode == 'baseline': data = data / 6.0 - # Extract land values (remove NaN values) land_values = data.values[~np.isnan(data.values)] all_values.extend(land_values) - # Accumulate histogram counts counts, _ = np.histogram(land_values, bins=bins) hist_counts += counts total_samples += len(land_values) @@ -180,28 +173,25 @@ def load(ts, fn): n_members = None member_hist_data = [] - all_member_values = [] # Store all values for each member + all_member_values = [] for i, ts in enumerate(times): if i % LOG_INTERVAL == 0: logger.info(f"Processing timestep {i+1}/{len(times)}") - preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR_HOURLY # [n_members, n_channels, lat, lon] if n_members is None: n_members = preds.shape[0] - # Initialize histogram accumulators for each member member_hist_data = [np.zeros(len(bins) - 1) for _ in range(n_members)] member_sample_counts = [0 for _ in range(n_members)] - all_member_values = [[] for _ in range(n_members)] # Initialize value storage + all_member_values = [[] for _ in range(n_members)] for member_idx in range(n_members): data = preds[member_idx, tp_out] * land_mask - # Extract land values (remove NaN values) land_values = data.values[~np.isnan(data.values)] - all_member_values[member_idx].extend(land_values) # Store values for percentiles + all_member_values[member_idx].extend(land_values) - # Accumulate histogram counts for this member counts, _ = np.histogram(land_values, bins=bins) member_hist_data[member_idx] += counts member_sample_counts[member_idx] += len(land_values) @@ -219,30 +209,24 @@ def load(ts, fn): # Compute percentiles for all datasets percentiles_data = {} + percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} - # Target percentiles - target_data_array = xr.DataArray(all_land_values['target']) - target_p99 = target_data_array.quantile(0.99).item() - target_p999 = target_data_array.quantile(0.999).item() - target_p9999 = target_data_array.quantile(0.9999).item() - percentiles_data['target'] = {99: target_p99, 99.9: target_p999, 99.99: target_p9999} - - # Baseline percentiles - baseline_data_array = xr.DataArray(all_land_values['baseline']) - baseline_p99 = baseline_data_array.quantile(0.99).item() - baseline_p999 = baseline_data_array.quantile(0.999).item() - baseline_p9999 = baseline_data_array.quantile(0.9999).item() - percentiles_data['baseline'] = {99: baseline_p99, 99.9: baseline_p999, 99.99: baseline_p9999} + # Target and baseline percentiles + for mode in ['target', 'baseline']: + data_array = xr.DataArray(all_land_values[mode]) + percentiles_data[mode] = { + key: data_array.quantile(p).item() + for key, p in percentiles.items() + } # Ensemble member percentiles percentiles_data['predictions'] = {} for member_idx in range(n_members): member_data_array = xr.DataArray(all_member_values[member_idx]) - member_p99 = member_data_array.quantile(0.99).item() - member_p999 = member_data_array.quantile(0.999).item() - member_p9999 = member_data_array.quantile(0.9999).item() - percentiles_data['predictions'][f'member_{member_idx}'] = {99: member_p99, 99.9: member_p999, 99.99: member_p9999} - + percentiles_data['predictions'][f'member_{member_idx}'] = { + key: member_data_array.quantile(p).item() + for key, p in percentiles.items() + } # Create distribution plots labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py index 3ab645a..febda1f 100644 --- a/src/hirad/eval/probability_of_exceedance.py +++ b/src/hirad/eval/probability_of_exceedance.py @@ -115,11 +115,8 @@ def main(cfg: DictConfig): ) logger.info("Dataset and sampler initialized") - # Output root and loader + # Output root out_root = Path(cfg.generation.io.output_path or './outputs') - - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR_HOURLY # Find channel indices indices = get_channel_indices(dataset) @@ -133,9 +130,8 @@ def load(ts, fn): # Define thresholds for exceedance calculation thresholds = np.logspace(-2, 2, 200) # From 0.01 to 100 mm/h - # Storage for exceedance data + # Storage for exceedance data and land values exceedance_data = {} - # Store all land values for percentile calculation all_land_values = {} # -- Process target and baseline -- @@ -148,13 +144,12 @@ def load(ts, fn): if i % LOG_INTERVAL == 0: logger.info(f"Processing timestep {i+1}/{len(times)}") - data = load(ts, f"{ts}-{mode}")[tp_out if mode == 'target' else tp_in] * land_mask + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode == 'target' else tp_in] * CONV_FACTOR_HOURLY * land_mask # Apply scaling factor for baseline if mode == 'baseline': data = data / 6.0 - # Extract land values (remove NaN values) land_values = data.values[~np.isnan(data.values)] all_values.extend(land_values) @@ -173,21 +168,20 @@ def load(ts, fn): logger.info("Processing predictions") n_members = None - all_member_values = [] # Store all values for each member + all_member_values = [] for i, ts in enumerate(times): if i % LOG_INTERVAL == 0: logger.info(f"Processing timestep {i+1}/{len(times)}") - preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR_HOURLY # [n_members, n_channels, lat, lon] if n_members is None: n_members = preds.shape[0] - all_member_values = [[] for _ in range(n_members)] # Initialize value storage + all_member_values = [[] for _ in range(n_members)] for member_idx in range(n_members): data = preds[member_idx, tp_out] * land_mask - # Extract land values (remove NaN values) land_values = data.values[~np.isnan(data.values)] all_member_values[member_idx].extend(land_values) @@ -207,30 +201,24 @@ def load(ts, fn): # Compute percentiles for all datasets percentiles_data = {} + percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} - # Target percentiles - target_data_array = xr.DataArray(all_land_values['target']) - target_p99 = target_data_array.quantile(0.99).item() - target_p999 = target_data_array.quantile(0.999).item() - target_p9999 = target_data_array.quantile(0.9999).item() - percentiles_data['target'] = {99: target_p99, 99.9: target_p999, 99.99: target_p9999} - - # Baseline percentiles - baseline_data_array = xr.DataArray(all_land_values['baseline']) - baseline_p99 = baseline_data_array.quantile(0.99).item() - baseline_p999 = baseline_data_array.quantile(0.999).item() - baseline_p9999 = baseline_data_array.quantile(0.9999).item() - percentiles_data['baseline'] = {99: baseline_p99, 99.9: baseline_p999, 99.99: baseline_p9999} + # Target and baseline percentiles + for mode in ['target', 'baseline']: + data_array = xr.DataArray(all_land_values[mode]) + percentiles_data[mode] = { + key: data_array.quantile(p).item() + for key, p in percentiles.items() + } # Ensemble member percentiles percentiles_data['predictions'] = {} for member_idx in range(n_members): member_data_array = xr.DataArray(all_member_values[member_idx]) - member_p99 = member_data_array.quantile(0.99).item() - member_p999 = member_data_array.quantile(0.999).item() - member_p9999 = member_data_array.quantile(0.9999).item() - percentiles_data['predictions'][f'member_{member_idx}'] = {99: member_p99, 99.9: member_p999, 99.99: member_p9999} - + percentiles_data['predictions'][f'member_{member_idx}'] = { + key: member_data_array.quantile(p).item() + for key, p in percentiles.items() + } # Create exceedance plots labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] From 0738df23afc7d211b5463da4757ad150019ce416 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 7 Aug 2025 11:13:05 +0200 Subject: [PATCH 130/189] rename --- src/hirad/eval/{map_99pctl.py => map_precip_99pctl.py} | 0 src/hirad/eval/{map_mean.py => map_precip_mean.py} | 0 src/hirad/maps.sh | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/hirad/eval/{map_99pctl.py => map_precip_99pctl.py} (100%) rename src/hirad/eval/{map_mean.py => map_precip_mean.py} (100%) diff --git a/src/hirad/eval/map_99pctl.py b/src/hirad/eval/map_precip_99pctl.py similarity index 100% rename from src/hirad/eval/map_99pctl.py rename to src/hirad/eval/map_precip_99pctl.py diff --git a/src/hirad/eval/map_mean.py b/src/hirad/eval/map_precip_mean.py similarity index 100% rename from src/hirad/eval/map_mean.py rename to src/hirad/eval/map_precip_mean.py diff --git a/src/hirad/maps.sh b/src/hirad/maps.sh index 543de5a..770074f 100644 --- a/src/hirad/maps.sh +++ b/src/hirad/maps.sh @@ -46,6 +46,6 @@ export OMP_NUM_THREADS=72 srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/map_99pctl.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/map_mean.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/map_precip_99pctl.py --config-name=generate_era_cosmo.yaml + python src/hirad/eval/map_precip_mean.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file From ecb686fb7473f0e029c79a094504b46c72262e4f Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 7 Aug 2025 12:01:27 +0200 Subject: [PATCH 131/189] unify maps in one script and more stats --- src/hirad/eval/map_precip_99pctl.py | 168 ---------------------- src/hirad/eval/map_precip_mean.py | 168 ---------------------- src/hirad/eval/map_precip_stats.py | 207 ++++++++++++++++++++++++++++ src/hirad/eval_precip.sh | 14 +- src/hirad/maps.sh | 51 ------- 5 files changed, 216 insertions(+), 392 deletions(-) delete mode 100644 src/hirad/eval/map_precip_99pctl.py delete mode 100644 src/hirad/eval/map_precip_mean.py create mode 100644 src/hirad/eval/map_precip_stats.py delete mode 100644 src/hirad/maps.sh diff --git a/src/hirad/eval/map_precip_99pctl.py b/src/hirad/eval/map_precip_99pctl.py deleted file mode 100644 index a399a82..0000000 --- a/src/hirad/eval/map_precip_99pctl.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -Plots maps of the 99th percentile of precipitation over the entire time period. - -This script computes the all-time 99th percentile of precipitation for each grid point -and creates maps for target (COSMO-2), baseline (ERA5), and predictions (CorrDiff). -""" -import logging -from datetime import datetime -from pathlib import Path - -import hydra -import numpy as np -import torch -from omegaconf import DictConfig, OmegaConf -import xarray as xr - -from hirad.datasets import get_dataset_and_sampler_inference -from hirad.distributed import DistributedManager -from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation, get_channel_indices, CONV_FACTOR, LOG_INTERVAL - - -@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") -def main(cfg: DictConfig): - # Setup logging - DistributedManager.initialize() - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - logger.info("Starting computation for 99th percentile precipitation maps") - times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") - logger.info(f"Loaded {len(times)} timesteps to process") - - # Initialize dataset - ds_cfg = OmegaConf.to_container(cfg.dataset) - dataset, _ = get_dataset_and_sampler_inference( - ds_cfg, times, cfg.generation.get('has_lead_time', False) - ) - logger.info("Dataset and sampler initialized") - - # Output root and loader - out_root = Path(cfg.generation.io.output_path or './outputs') - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR - - # Find channel indices - indices = get_channel_indices(dataset) - tp_out = indices['output']['tp'] - tp_in = indices['input'].get('tp', tp_out) - logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - - # -- Process target -- - logger.info("Processing target (COSMO-2)") - target_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing target timestep {i+1}/{len(times)}: {ts}") - data = load(ts, f"{ts}-target")[tp_out] - target_data_list.append(data) - - target_da = xr.DataArray( - np.stack(target_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - - # Compute 99th percentile across all time steps - target_p99 = target_da.quantile(0.99, dim='time') - logger.info("Target 99th percentile computed") - - # -- Process baseline -- - logger.info("Processing baseline (ERA5)") - baseline_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing baseline timestep {i+1}/{len(times)}: {ts}") - data = load(ts, f"{ts}-baseline")[tp_in] - baseline_data_list.append(data) - - baseline_da = xr.DataArray( - np.stack(baseline_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - - # Apply scaling factor for baseline and compute 99th percentile - baseline_p99 = (baseline_da / 6.0).quantile(0.99, dim='time') - logger.info("Baseline 99th percentile computed") - - # -- Process predictions -- - logger.info("Processing predictions (CorrDiff)") - pred_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing predictions timestep {i+1}/{len(times)}: {ts}") - preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] - # Extract precipitation channel - tp_data = preds[:, tp_out] # [n_members, lat, lon] - tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) - pred_data_list.append(tp_da) - - pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] - pred_da = pred_da.assign_coords({ - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - }) - pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') - - # Compute 99th percentile across time for each member, then ensemble mean - pred_p99_by_member = pred_da.quantile(0.99, dim='time') - pred_p99_mean = pred_p99_by_member.mean(dim='member') - logger.info("Predictions 99th percentile computed") - - # Create output directory - map_output_dir = out_root / 'maps_99th_percentile' - map_output_dir.mkdir(parents=True, exist_ok=True) - - # Plot maps using the precipitation-specific plotting function - logger.info("Creating precipitation maps") - - # Target map - plot_map_precipitation( - target_p99.values, - str(map_output_dir / 'target_99th_percentile'), - title='COSMO-2 Analysis: 99th Percentile Precipitation', - threshold=0.1, - rfac=1.0 # Already converted to mm/day - ) - logger.info("Target map saved") - - # Baseline map - plot_map_precipitation( - baseline_p99.values, - str(map_output_dir / 'baseline_99th_percentile'), - title='ERA5: 99th Percentile Precipitation', - threshold=0.1, - rfac=1.0 # Already converted to mm/day - ) - logger.info("Baseline map saved") - - # Prediction ensemble mean map - plot_map_precipitation( - pred_p99_mean.values, - str(map_output_dir / 'prediction_ensmean_99th_percentile'), - title='CorrDiff Ensemble Mean: 99th Percentile Precipitation', - threshold=0.1, - rfac=1.0 # Already converted to mm/day - ) - logger.info("Prediction ensemble mean map saved") - - # Individual ensemble member maps - logger.info("Creating individual ensemble member maps") - n_members = pred_p99_by_member.shape[0] - for member_idx in range(n_members): - plot_map_precipitation( - pred_p99_by_member[member_idx].values, - str(map_output_dir / f'prediction_member_{member_idx:02d}_99th_percentile'), - title=f'CorrDiff Member {member_idx+1}: 99th Percentile Precipitation', - threshold=0.1, - rfac=1.0 # Already converted to mm/day - ) - logger.info(f"Individual ensemble member maps saved ({n_members} members)") - - logger.info(f"All maps saved to {map_output_dir}") - logger.info("99th percentile precipitation mapping completed successfully") - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/hirad/eval/map_precip_mean.py b/src/hirad/eval/map_precip_mean.py deleted file mode 100644 index e7c33ae..0000000 --- a/src/hirad/eval/map_precip_mean.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -Plots maps of the mean precipitation over the entire time period. - -This script computes the temporal mean of precipitation for each grid point -and creates maps for target (COSMO-2), baseline (ERA5), and predictions (CorrDiff). -""" -import logging -from datetime import datetime -from pathlib import Path - -import hydra -import numpy as np -import torch -from omegaconf import DictConfig, OmegaConf -import xarray as xr - -from hirad.datasets import get_dataset_and_sampler_inference -from hirad.distributed import DistributedManager -from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation, get_channel_indices, CONV_FACTOR, LOG_INTERVAL - - -@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") -def main(cfg: DictConfig): - # Setup logging - DistributedManager.initialize() - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - logger.info("Starting computation for mean precipitation maps") - times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") - logger.info(f"Loaded {len(times)} timesteps to process") - - # Initialize dataset - ds_cfg = OmegaConf.to_container(cfg.dataset) - dataset, _ = get_dataset_and_sampler_inference( - ds_cfg, times, cfg.generation.get('has_lead_time', False) - ) - logger.info("Dataset and sampler initialized") - - # Output root and loader - out_root = Path(cfg.generation.io.output_path or './outputs') - def load(ts, fn): - return torch.load(out_root/ts/fn, weights_only=False) * CONV_FACTOR - - # Find channel indices - indices = get_channel_indices(dataset) - tp_out = indices['output']['tp'] - tp_in = indices['input'].get('tp', tp_out) - logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}") - - # -- Process target -- - logger.info("Processing target (COSMO-2)") - target_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing target timestep {i+1}/{len(times)}: {ts}") - data = load(ts, f"{ts}-target")[tp_out] - target_data_list.append(data) - - target_da = xr.DataArray( - np.stack(target_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - - # Compute mean across all time steps - target_mean = target_da.mean(dim='time') - logger.info("Target mean computed") - - # -- Process baseline -- - logger.info("Processing baseline (ERA5)") - baseline_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing baseline timestep {i+1}/{len(times)}: {ts}") - data = load(ts, f"{ts}-baseline")[tp_in] - baseline_data_list.append(data) - - baseline_da = xr.DataArray( - np.stack(baseline_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - - # Apply scaling factor for baseline and compute mean - baseline_mean = (baseline_da / 6.0).mean(dim='time') - logger.info("Baseline mean computed") - - # -- Process predictions -- - logger.info("Processing predictions (CorrDiff)") - pred_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing predictions timestep {i+1}/{len(times)}: {ts}") - preds = load(ts, f"{ts}-predictions") # [n_members, n_channels, lat, lon] - # Extract precipitation channel - tp_data = preds[:, tp_out] # [n_members, lat, lon] - tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon']) - pred_data_list.append(tp_da) - - pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon] - pred_da = pred_da.assign_coords({ - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - }) - pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') - - # Compute mean across time for each member, then ensemble mean - pred_mean_by_member = pred_da.mean(dim='time') - pred_mean_mean = pred_mean_by_member.mean(dim='member') - logger.info("Predictions mean computed") - - # Create output directory - map_output_dir = out_root / 'maps_mean' - map_output_dir.mkdir(parents=True, exist_ok=True) - - # Plot maps using the precipitation-specific plotting function - logger.info("Creating precipitation maps") - - # Target map - plot_map_precipitation( - target_mean.values, - str(map_output_dir / 'target_mean'), - title='COSMO-2 Analysis: Mean Precipitation', - threshold=0.01, # Lower threshold for mean precipitation - rfac=1.0 # Already converted to mm/day - ) - logger.info("Target map saved") - - # Baseline map - plot_map_precipitation( - baseline_mean.values, - str(map_output_dir / 'baseline_mean'), - title='ERA5: Mean Precipitation', - threshold=0.01, # Lower threshold for mean precipitation - rfac=1.0 # Already converted to mm/day - ) - logger.info("Baseline map saved") - - # Prediction ensemble mean map - plot_map_precipitation( - pred_mean_mean.values, - str(map_output_dir / 'prediction_ensmean_mean'), - title='CorrDiff Ensemble Mean: Mean Precipitation', - threshold=0.01, # Lower threshold for mean precipitation - rfac=1.0 # Already converted to mm/day - ) - logger.info("Prediction ensemble mean map saved") - - # Individual ensemble member maps - logger.info("Creating individual ensemble member maps") - n_members = pred_mean_by_member.shape[0] - for member_idx in range(n_members): - plot_map_precipitation( - pred_mean_by_member[member_idx].values, - str(map_output_dir / f'prediction_member_{member_idx:02d}_mean'), - title=f'CorrDiff Member {member_idx+1}: Mean Precipitation', - threshold=0.01, # Lower threshold for mean precipitation - rfac=1.0 # Already converted to mm/day - ) - logger.info(f"Individual ensemble member maps saved ({n_members} members)") - - logger.info(f"All maps saved to {map_output_dir}") - logger.info("Mean precipitation mapping completed successfully") - - -if __name__ == '__main__': - main() diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py new file mode 100644 index 0000000..62c574d --- /dev/null +++ b/src/hirad/eval/map_precip_stats.py @@ -0,0 +1,207 @@ +""" +Generates: + - outputs/maps_mean/ + - outputs/maps_p99/ + - outputs/maps_p99.9/ + - outputs/maps_p99.99/ + - outputs/maps_Rx1hr/ + - outputs/maps_Rx1day/ + - outputs/maps_Rx5day/ + - outputs/maps_cdd/ + - outputs/maps_cwd/ + - outputs/maps_weth_freq/ +""" +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import plot_map_precipitation, plot_map, get_channel_indices, CONV_FACTOR, LOG_INTERVAL, WET_THRESHOLD + + +def process_data_for_stat(times, out_root, tp_channel, mode, stat_type, stat_param, logger): + """Process data for a given mode and compute the specified statistic.""" + def consecutive_spell(condition): + """Calculate longest consecutive spell where condition is True.""" + def _spell_length(x): + x = np.asarray(x, dtype=bool) + if len(x) == 0: + return 0 + runs = np.diff(np.concatenate(([False], x, [False])).astype(int)) + starts = np.where(runs == 1)[0] + ends = np.where(runs == -1)[0] + return int(np.max(ends - starts)) if len(starts) > 0 else 0 + return xr.apply_ufunc(_spell_length, condition, input_core_dims=[['time']], vectorize=True) + + def apply_statistic(data, stat_type, stat_param): + if stat_type == 'mean': + return data.mean(dim='time') + elif stat_type == 'quantile': + return data.quantile(stat_param, dim='time') + elif stat_type == 'Rx1hr': + return data.max(dim='time') + elif stat_type == 'Rx1day': + # Maximum daily precipitation total + daily = data.resample(time="1D").sum("time") + return daily.max(dim='time') + elif stat_type == 'Rx5day': + # Maximum 5-consecutive-day precipitation total + daily = data.resample(time="1D").sum("time") + return daily.rolling(time=5, center=False).sum().max(dim='time') + elif stat_type == 'cdd': # Consecutive Dry Days (< 1 mm) + daily = data.resample(time="1D").sum("time") + return consecutive_spell(daily < 1.0) + elif stat_type == 'cwd': # Consecutive Wet Days (≥ 1 mm) + daily = data.resample(time="1D").sum("time") + return consecutive_spell(daily >= 1.0) + elif stat_type == 'weth_freq': + return (data / 24 > WET_THRESHOLD).mean(dim='time') * 100 + else: + raise ValueError(f"Unsupported statistic type: {stat_type}") + + # Load data + data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing {mode} timestep {i+1}/{len(times)}: {ts}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) * CONV_FACTOR + if mode == 'predictions': + tp_da = xr.DataArray(data[:, tp_channel], dims=['member', 'lat', 'lon']) + data_list.append(tp_da) + else: + data_list.append(data[tp_channel]) + + # Process data based on mode + if mode == 'predictions': + pred_da = xr.concat(data_list, dim='time').assign_coords({ + 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] + }).transpose('member', 'time', 'lat', 'lon') + + result_by_member = apply_statistic(pred_da, stat_type, stat_param) + return result_by_member.mean(dim='member'), result_by_member, pred_da.shape[0] + + else: + da = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + if mode == 'baseline': + da = da / 6.0 + + return apply_statistic(da, stat_type, stat_param), None, None + + +# Statistics configuration +STATISTICS_CONFIG = { + 'mean': {'type': 'mean', 'threshold': 0.01, 'title': 'Mean'}, + 'p99': {'type': 'quantile', 'param': 0.99, 'threshold': 0.1, 'title': '99th Percentile'}, + 'p99.9': {'type': 'quantile', 'param': 0.999, 'threshold': 0.1, 'title': '99.9th Percentile'}, + 'p99.99': {'type': 'quantile', 'param': 0.9999, 'threshold': 0.1, 'title': '99.99th Percentile'}, + 'Rx1hr': {'type': 'Rx1hr', 'threshold': 0.1, 'title': 'Maximum (Rx1hr)'}, + 'Rx1day': {'type': 'Rx1day', 'threshold': 0.1, 'title': 'Maximum 1-day Amount (Rx1day)'}, + 'Rx5day': {'type': 'Rx5day', 'threshold': 0.1, 'title': 'Maximum 5-day Total (Rx5day)'}, + 'cdd': {'type': 'cdd', 'threshold': 0.1, 'title': 'Consecutive Dry Days (CDD)'}, + 'cwd': {'type': 'cwd', 'threshold': 0.1, 'title': 'Consecutive Wet Days (CWD)'}, + 'weth_freq': {'type': 'weth_freq', 'threshold': 0.01, 'title': 'Wet-Hour Frequency'} +} + +def get_all_stat_configs(): + """Get all statistic configurations to generate.""" + return [ + { + 'stat_name': name, + 'title_stat': config['title'], + 'param': config.get('param'), # Use get() to handle missing params + **config + } + for name, config in STATISTICS_CONFIG.items() + ] + + +def plot_stat_map(data, filename, stat_config, label): + """Plot a single statistic map with appropriate styling.""" + if stat_config['type'] == 'weth_freq': + plot_map(data, filename, title=f'{label}: {stat_config["title_stat"]} (%)', + label='Wet-Hour Frequency [%]', vmin=0, vmax=100, cmap='Blues') + elif stat_config['type'] in ['cdd', 'cwd']: + plot_map(data, filename, title=f'{label}: {stat_config["title_stat"]}', + label='Days', vmin=0, vmax=None, cmap='viridis') + else: + plot_map_precipitation(data, filename, title=f'{label}: {stat_config["title_stat"]} Precipitation', + threshold=stat_config['threshold'], rfac=1.0) + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting precipitation statistics generation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Processing {len(times)} timesteps") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + + # Output root and channel indices + out_root = Path(cfg.generation.io.output_path or './outputs') + indices = get_channel_indices(dataset) + tp_out, tp_in = indices['output']['tp'], indices['input'].get('tp', indices['output']['tp']) + + # Mode configuration + modes = { + 'target': (tp_out, 'COSMO-2 Analysis'), + 'baseline': (tp_in, 'ERA5'), + 'predictions': (tp_out, 'CorrDiff Ensemble Mean') + } + + all_stat_configs = get_all_stat_configs() + logger.info(f"Generating {len(all_stat_configs)} statistics for {len(modes)} modes") + + # Process each statistic + for stat_config in all_stat_configs: + logger.info(f"Processing {stat_config['title_stat']}...") + + # Process all modes for this statistic + results = {} + for mode, (tp_channel, label) in modes.items(): + result_mean, result_by_member, n_members = process_data_for_stat( + times, out_root, tp_channel, mode, stat_config['type'], stat_config['param'], logger + ) + results[mode] = (result_mean, result_by_member, n_members, label) + + # Create maps + map_output_dir = out_root / f"maps_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + + for mode, (result_mean, result_by_member, n_members, label) in results.items(): + # Main ensemble mean map + plot_stat_map(result_mean.values, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label) + + # Individual ensemble member maps for predictions + if mode == 'predictions' and result_by_member is not None: + for member_idx in range(n_members): + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + member_label = f'CorrDiff Member {member_idx+1}' + plot_stat_map(result_by_member[member_idx].values, member_filename, stat_config, member_label) + + logger.info("All precipitation statistics maps generated successfully") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index 52232c4..db9b190 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -42,16 +42,20 @@ export OMP_NUM_THREADS=72 # echo "Local processes: $LOCAL_PROCS" # echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" -# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml + srun --environment=./ci/edf/modulus_env.toml bash -c " pip install -e . --no-dependencies + + # Diurnal cycle python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/diurnal_cycle_precip_p99.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml -" + python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml # TODO: Transfer to relevant script. -srun --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . --no-dependencies + # Histograms python src/hirad/eval/hist.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/probability_of_exceedance.py --config-name=generate_era_cosmo.yaml + + # Maps + python src/hirad/eval/map_precip_stats.py --config-name=generate_era_cosmo.yaml + # python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/maps.sh b/src/hirad/maps.sh deleted file mode 100644 index 770074f..0000000 --- a/src/hirad/maps.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="plot" - -### HARDWARE ### -#SBATCH --partition=normal -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=2 -#SBATCH --gpus-per-node=2 -#SBATCH --cpus-per-task=72 -#SBATCH --time=01:00:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=./logs/plot_maps.log - -### ENVIRONMENT #### -#SBATCH -A a161 - -# Choose method to initialize dist in pythorch -export DISTRIBUTED_INITIALIZATION_METHOD=SLURM - -MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" -echo "Master node : $MASTER_ADDR" -# Get IP for hostname. -MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" -echo "Master address : $MASTER_ADDR" -export MASTER_ADDR -export MASTER_PORT=29500 -echo "Master port: $MASTER_PORT" - -# Get number of physical cores using Python -# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# # Use SLURM_NTASKS (number of processes to be launched by torchrun) -# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# # Compute threads per process -# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -# export OMP_NUM_THREADS=$OMP_THREADS -export OMP_NUM_THREADS=72 -# echo "Physical cores: $PHYSICAL_CORES" -# echo "Local processes: $LOCAL_PROCS" -# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" - -# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun --environment=./ci/edf/modulus_env.toml bash -c " - pip install -e . --no-dependencies - python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/map_precip_99pctl.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/map_precip_mean.py --config-name=generate_era_cosmo.yaml -" \ No newline at end of file From 622a36ccb958fd3a0da7bb2b14d10131c2e3f806 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 7 Aug 2025 14:32:17 +0200 Subject: [PATCH 132/189] point to store --- src/hirad/eval/plotting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index edbcbc7..13849f1 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -22,6 +22,8 @@ WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h LOG_INTERVAL = 24 # Log progress every N timesteps +LAND_SEA_MASK_PATH = '/capstor/store/mch/msopr/hirad-gen/eval/lsm.npy' + def get_channel_indices(dataset, channels=None): """ Get channel indices for input and output channels from dataset. @@ -50,7 +52,7 @@ def get_channel_indices(dataset, channels=None): return {'input': filtered_in, 'output': filtered_out} -def load_land_sea_mask(path='/iopsstor/scratch/cscs/davidle/HiRAD-Gen/lsm.npy'): +def load_land_sea_mask(path=LAND_SEA_MASK_PATH): """Load and retrun a land-sea mask as xarray DataArray.""" lsm_data = np.load(path).reshape(352, 544) return xr.DataArray( From 42250be8a262cb078136c53414fd8e56342d97cc Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Thu, 7 Aug 2025 16:48:04 +0200 Subject: [PATCH 133/189] conserve memory and cleanup --- src/hirad/eval/map_precip_stats.py | 308 ++++++++++++++--------------- src/hirad/eval/plotting.py | 3 +- 2 files changed, 151 insertions(+), 160 deletions(-) diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index 62c574d..26740c1 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -1,16 +1,3 @@ -""" -Generates: - - outputs/maps_mean/ - - outputs/maps_p99/ - - outputs/maps_p99.9/ - - outputs/maps_p99.99/ - - outputs/maps_Rx1hr/ - - outputs/maps_Rx1day/ - - outputs/maps_Rx5day/ - - outputs/maps_cdd/ - - outputs/maps_cwd/ - - outputs/maps_weth_freq/ -""" import logging from datetime import datetime from pathlib import Path @@ -24,126 +11,81 @@ from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager from hirad.utils.function_utils import get_time_from_range -from hirad.eval.plotting import plot_map_precipitation, plot_map, get_channel_indices, CONV_FACTOR, LOG_INTERVAL, WET_THRESHOLD - - -def process_data_for_stat(times, out_root, tp_channel, mode, stat_type, stat_param, logger): - """Process data for a given mode and compute the specified statistic.""" - def consecutive_spell(condition): - """Calculate longest consecutive spell where condition is True.""" - def _spell_length(x): - x = np.asarray(x, dtype=bool) - if len(x) == 0: - return 0 - runs = np.diff(np.concatenate(([False], x, [False])).astype(int)) - starts = np.where(runs == 1)[0] - ends = np.where(runs == -1)[0] - return int(np.max(ends - starts)) if len(starts) > 0 else 0 - return xr.apply_ufunc(_spell_length, condition, input_core_dims=[['time']], vectorize=True) - - def apply_statistic(data, stat_type, stat_param): - if stat_type == 'mean': - return data.mean(dim='time') - elif stat_type == 'quantile': - return data.quantile(stat_param, dim='time') - elif stat_type == 'Rx1hr': - return data.max(dim='time') - elif stat_type == 'Rx1day': - # Maximum daily precipitation total - daily = data.resample(time="1D").sum("time") - return daily.max(dim='time') - elif stat_type == 'Rx5day': - # Maximum 5-consecutive-day precipitation total - daily = data.resample(time="1D").sum("time") - return daily.rolling(time=5, center=False).sum().max(dim='time') - elif stat_type == 'cdd': # Consecutive Dry Days (< 1 mm) - daily = data.resample(time="1D").sum("time") - return consecutive_spell(daily < 1.0) - elif stat_type == 'cwd': # Consecutive Wet Days (≥ 1 mm) - daily = data.resample(time="1D").sum("time") - return consecutive_spell(daily >= 1.0) - elif stat_type == 'weth_freq': - return (data / 24 > WET_THRESHOLD).mean(dim='time') * 100 - else: - raise ValueError(f"Unsupported statistic type: {stat_type}") - - # Load data - data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing {mode} timestep {i+1}/{len(times)}: {ts}") - - data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) * CONV_FACTOR - if mode == 'predictions': - tp_da = xr.DataArray(data[:, tp_channel], dims=['member', 'lat', 'lon']) - data_list.append(tp_da) - else: - data_list.append(data[tp_channel]) - - # Process data based on mode - if mode == 'predictions': - pred_da = xr.concat(data_list, dim='time').assign_coords({ - 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times] - }).transpose('member', 'time', 'lat', 'lon') - - result_by_member = apply_statistic(pred_da, stat_type, stat_param) - return result_by_member.mean(dim='member'), result_by_member, pred_da.shape[0] - - else: - da = xr.DataArray( - np.stack(data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - if mode == 'baseline': - da = da / 6.0 - - return apply_statistic(da, stat_type, stat_param), None, None - - -# Statistics configuration -STATISTICS_CONFIG = { - 'mean': {'type': 'mean', 'threshold': 0.01, 'title': 'Mean'}, - 'p99': {'type': 'quantile', 'param': 0.99, 'threshold': 0.1, 'title': '99th Percentile'}, - 'p99.9': {'type': 'quantile', 'param': 0.999, 'threshold': 0.1, 'title': '99.9th Percentile'}, - 'p99.99': {'type': 'quantile', 'param': 0.9999, 'threshold': 0.1, 'title': '99.99th Percentile'}, - 'Rx1hr': {'type': 'Rx1hr', 'threshold': 0.1, 'title': 'Maximum (Rx1hr)'}, - 'Rx1day': {'type': 'Rx1day', 'threshold': 0.1, 'title': 'Maximum 1-day Amount (Rx1day)'}, - 'Rx5day': {'type': 'Rx5day', 'threshold': 0.1, 'title': 'Maximum 5-day Total (Rx5day)'}, - 'cdd': {'type': 'cdd', 'threshold': 0.1, 'title': 'Consecutive Dry Days (CDD)'}, - 'cwd': {'type': 'cwd', 'threshold': 0.1, 'title': 'Consecutive Wet Days (CWD)'}, - 'weth_freq': {'type': 'weth_freq', 'threshold': 0.01, 'title': 'Wet-Hour Frequency'} -} - -def get_all_stat_configs(): - """Get all statistic configurations to generate.""" - return [ - { - 'stat_name': name, - 'title_stat': config['title'], - 'param': config.get('param'), # Use get() to handle missing params - **config - } - for name, config in STATISTICS_CONFIG.items() - ] +from hirad.eval.plotting import ( + plot_map_precipitation, plot_map, get_channel_indices, + CONV_FACTOR, LOG_INTERVAL, WET_THRESHOLD +) + + +def consecutive_spell(condition): + """Return longest consecutive spell where condition is True (per gridpoint).""" + def _spell_length(x): + x = np.asarray(x, dtype=bool) + if len(x) == 0: + return 0 + runs = np.diff(np.concatenate(([False], x, [False])).astype(int)) + starts = np.where(runs == 1)[0] + ends = np.where(runs == -1)[0] + return int(np.max(ends - starts)) if len(starts) > 0 else 0 + return xr.apply_ufunc(_spell_length, condition, input_core_dims=[['time']], vectorize=True) + + +def apply_statistic(data, stat_type, stat_param): + """Apply a statistic to the data along the time dimension.""" + if stat_type == 'mean': + return data.mean(dim='time') + if stat_type == 'quantile': + return data.quantile(stat_param, dim='time') + if stat_type == 'Rx1hr': + return data.max(dim='time') + if stat_type == 'Rx1day': + daily = data.resample(time="1D").sum("time") + return daily.max(dim='time') + if stat_type == 'Rx5day': + daily = data.resample(time="1D").sum("time") + return daily.rolling(time=5, center=False).sum().max(dim='time') + if stat_type == 'cdd': + daily = data.resample(time="1D").sum("time") + return consecutive_spell(daily < 1.0) + if stat_type == 'cwd': + daily = data.resample(time="1D").sum("time") + return consecutive_spell(daily >= 1.0) + if stat_type == 'weth_freq': + return (data / 24 > WET_THRESHOLD).mean(dim='time') * 100 + raise ValueError(f"Unsupported statistic type: {stat_type}") def plot_stat_map(data, filename, stat_config, label): """Plot a single statistic map with appropriate styling.""" if stat_config['type'] == 'weth_freq': - plot_map(data, filename, title=f'{label}: {stat_config["title_stat"]} (%)', - label='Wet-Hour Frequency [%]', vmin=0, vmax=100, cmap='Blues') - elif stat_config['type'] in ['cdd', 'cwd']: - plot_map(data, filename, title=f'{label}: {stat_config["title_stat"]}', - label='Days', vmin=0, vmax=None, cmap='viridis') + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]} (%)', + label='Wet-Hour Frequency [%]', vmin=0, vmax=10, cmap='PuBu', extend='max' + ) + elif stat_config['type'] == 'cdd': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Days', vmin=0, vmax=60, cmap='viridis', extend='max' + ) + elif stat_config['type'] == 'cwd': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Days', vmin=0, vmax=20, cmap='viridis', extend='max' + ) else: - plot_map_precipitation(data, filename, title=f'{label}: {stat_config["title_stat"]} Precipitation', - threshold=stat_config['threshold'], rfac=1.0) + plot_map_precipitation( + data, filename, + title=f'{label}: {stat_config["title_stat"]} Precipitation', + threshold=stat_config['threshold'], rfac=1.0 + ) @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): - # Setup logging + # Setup and config DistributedManager.initialize() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -152,54 +94,102 @@ def main(cfg: DictConfig): times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") logger.info(f"Processing {len(times)} timesteps") - # Initialize dataset ds_cfg = OmegaConf.to_container(cfg.dataset) dataset, _ = get_dataset_and_sampler_inference( ds_cfg, times, cfg.generation.get('has_lead_time', False) ) - - # Output root and channel indices out_root = Path(cfg.generation.io.output_path or './outputs') indices = get_channel_indices(dataset) - tp_out, tp_in = indices['output']['tp'], indices['input'].get('tp', indices['output']['tp']) + tp_out = indices['output']['tp'] + tp_in = indices['input'].get('tp', tp_out) + + # Statistic configuration + STATISTICS_CONFIG = { + 'mean': {'type': 'mean', 'threshold': 0.01, 'title': 'Mean'}, + 'p99': {'type': 'quantile', 'param': 0.99, 'threshold': 0.1, 'title': '99th Percentile'}, + 'p99.9': {'type': 'quantile', 'param': 0.999, 'threshold': 0.1, 'title': '99.9th Percentile'}, + 'p99.99': {'type': 'quantile', 'param': 0.9999, 'threshold': 0.1, 'title': '99.99th Percentile'}, + 'Rx1hr': {'type': 'Rx1hr', 'threshold': 0.1, 'title': 'Maximum (Rx1hr)'}, + 'Rx1day': {'type': 'Rx1day', 'threshold': 0.1, 'title': 'Maximum 1-day Amount (Rx1day)'}, + 'Rx5day': {'type': 'Rx5day', 'threshold': 0.1, 'title': 'Maximum 5-day Total (Rx5day)'}, + 'cdd': {'type': 'cdd', 'threshold': 0.1, 'title': 'Consecutive Dry Days (CDD)'}, + 'cwd': {'type': 'cwd', 'threshold': 0.1, 'title': 'Consecutive Wet Days (CWD)'}, + 'weth_freq': {'type': 'weth_freq', 'threshold': 0.01, 'title': 'Wet-Hour Frequency'} + } + stat_configs = [ + { + 'stat_name': name, + 'title_stat': config['title'], + 'param': config.get('param'), + **config + } + for name, config in STATISTICS_CONFIG.items() + ] - # Mode configuration - modes = { + # Target and baseline modes + basic_modes = { 'target': (tp_out, 'COSMO-2 Analysis'), - 'baseline': (tp_in, 'ERA5'), - 'predictions': (tp_out, 'CorrDiff Ensemble Mean') + 'baseline': (tp_in, 'ERA5') } + logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} basic modes + predictions") + + for mode, (tp_channel, label) in basic_modes.items(): + logger.info(f"Processing mode: {mode}") + # Load all timesteps for this mode + data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) * CONV_FACTOR + data_list.append(data[tp_channel]) + mode_data = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + if mode == 'baseline': + mode_data = mode_data / 6.0 + # Compute and plot all statistics for this mode + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for {mode}...") + result = apply_statistic(mode_data, stat_config['type'], stat_config['param']) + map_output_dir = out_root / f"maps_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_stat_map(result.values, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label) + + # Predictions mode: process each member separately to save memory + logger.info("Processing predictions mode...") + data = torch.load(out_root/times[0]/f"{times[0]}-predictions", weights_only=False) + n_members = data.shape[0] + logger.info(f"Found {n_members} ensemble members") - all_stat_configs = get_all_stat_configs() - logger.info(f"Generating {len(all_stat_configs)} statistics for {len(modes)} modes") - - # Process each statistic - for stat_config in all_stat_configs: - logger.info(f"Processing {stat_config['title_stat']}...") - - # Process all modes for this statistic - results = {} - for mode, (tp_channel, label) in modes.items(): - result_mean, result_by_member, n_members = process_data_for_stat( - times, out_root, tp_channel, mode, stat_config['type'], stat_config['param'], logger - ) - results[mode] = (result_mean, result_by_member, n_members, label) - - # Create maps - map_output_dir = out_root / f"maps_{stat_config['stat_name']}" - map_output_dir.mkdir(parents=True, exist_ok=True) + for member_idx in range(n_members): + logger.info(f"Processing prediction member {member_idx+1}/{n_members}") + # Load all timesteps for this member + data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading prediction member {member_idx} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * CONV_FACTOR + data_list.append(pred_data[member_idx, tp_out]) + member_data = xr.DataArray( + np.stack(data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) - for mode, (result_mean, result_by_member, n_members, label) in results.items(): - # Main ensemble mean map - plot_stat_map(result_mean.values, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label) + # Compute and plot all statistics for this member + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") + member_result = apply_statistic(member_data, stat_config['type'], stat_config['param']) - # Individual ensemble member maps for predictions - if mode == 'predictions' and result_by_member is not None: - for member_idx in range(n_members): - member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') - member_label = f'CorrDiff Member {member_idx+1}' - plot_stat_map(result_by_member[member_idx].values, member_filename, stat_config, member_label) - + # Create map + map_output_dir = out_root / f"maps_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + member_label = f'CorrDiff Member {member_idx+1}' + plot_stat_map(member_result.values, member_filename, stat_config, member_label) + logger.info("All precipitation statistics maps generated successfully") diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 13849f1..558b415 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -120,6 +120,7 @@ def plot_map(values: np.array, ) if ticks is not None: cbar.set_ticks(ticks) + cbar.set_ticklabels([f'{tick:g}' for tick in ticks]) plt.tight_layout() fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") @@ -136,7 +137,7 @@ def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=100.0 'forestgreen', 'limegreen', 'lawngreen', 'yellow', 'gold', 'darkorange', 'red', 'darkviolet', 'violet', 'thistle'] - bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 30, 50, 70, 100, 150, 200] + bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000] cmap = ListedColormap(colors) norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) From 27296546de71b4ad5db71fe76762767d1a040bb2 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 25 Aug 2025 15:48:58 +0200 Subject: [PATCH 134/189] remove dependencies from pyproject --- pyproject.toml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1477899..7ffe3f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,19 +10,10 @@ authors = [ { name="Petar Stamenkovic", email="petar.stamenkovic@meteoswiss.ch" } ] readme = "README.md" -requires-python = ">=3.12" +#requires-python = ">=3.12" license = {file = "LICENSE"} dependencies = [ - "cartopy>=0.24.1", - "cftime>=1.6.4", - "hydra-core>=1.3.2", - "matplotlib>=3.10.1", - "omegaconf>=2.3.0", - "tensorboard>=2.19.0", - "termcolor>=3.1.0", - "torchinfo>=1.8.0", - "treelib>=1.7.1" ] [tool.setuptools] From 8b53777eddda39f3d36a3955d9337f36ea9f7c11 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 25 Aug 2025 15:49:23 +0200 Subject: [PATCH 135/189] change conversion factor to get mm instead of cm --- src/hirad/eval/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 558b415..e863053 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -17,7 +17,7 @@ RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) # Constants for data processing -CONV_FACTOR_HOURLY = 100 # Convert precip of ERA5 from meters to mm/h +CONV_FACTOR_HOURLY = 1000 # Convert precip of ERA5 from meters to mm/h CONV_FACTOR = CONV_FACTOR_HOURLY * 24 # Convert precip of ERA5 from from meters to mm/day WET_THRESHOLD = 0.1 # Threshold for wet-hour in mm/h LOG_INTERVAL = 24 # Log progress every N timesteps @@ -126,7 +126,7 @@ def plot_map(values: np.array, fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") plt.close(fig) -def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=100.0): +def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=1000.0): """Plot precipitation data with specific colormap and thresholds.""" # Scale and mask values below threshold values = rfac * values # m/h --> mm/h From 9fe0b69b25219d8c02527e71de18204e0b08b5dc Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 25 Aug 2025 15:49:38 +0200 Subject: [PATCH 136/189] add __init__ package for input_data --- src/hirad/input_data/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/hirad/input_data/__init__.py diff --git a/src/hirad/input_data/__init__.py b/src/hirad/input_data/__init__.py new file mode 100644 index 0000000..e69de29 From 630086a539a890170977f5d2689c21f6ebe085c0 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 25 Aug 2025 15:50:19 +0200 Subject: [PATCH 137/189] Clean up copernicus processing --- src/hirad/input_data/read_tp.py | 206 ++++++++++++++------------------ 1 file changed, 91 insertions(+), 115 deletions(-) diff --git a/src/hirad/input_data/read_tp.py b/src/hirad/input_data/read_tp.py index a13fed3..37767a0 100644 --- a/src/hirad/input_data/read_tp.py +++ b/src/hirad/input_data/read_tp.py @@ -1,5 +1,6 @@ import logging import netCDF4 +import xarray from anemoi.datasets import open_dataset import numpy as np import yaml @@ -9,6 +10,10 @@ import cartopy.feature as cfeature from matplotlib.colors import BoundaryNorm, ListedColormap +from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t +from hirad.eval.metrics import compute_mae, absolute_error + + import interpolate_basic import sys @@ -24,18 +29,20 @@ COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr" COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr" COSMO_CONFIG_FILE="src/input_data/cosmo.yaml" -CDF_FILENAME = "8e49f064d738154bed136666ff72ae1c.nc" +#CDF_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" +CDF_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.zip" +GRIB_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) LON = np.arange(-6.82, 4.80 + 0.02, 0.02) RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) - -def extract_values(netcdf_data): +def extract_netcdf_values(netcdf_data): netcdf_lat = netcdf_data['latitude'][:] netcdf_lon = netcdf_data['longitude'][:] netcdf_tp = netcdf_data['tp'][:,:] + logging.info(f'num nonzeros in tp is {np.count_nonzero(netcdf_tp)}') values = np.zeros((netcdf_tp.shape[0], netcdf_tp.shape[1]*netcdf_tp.shape[2])) latitudes = np.zeros(values.shape[1]) longitudes = np.zeros(values.shape[1]) @@ -44,111 +51,94 @@ def extract_values(netcdf_data): if i % 10 == 0: print(i) for j in range(len(netcdf_lon)): - grid_index = i * len(netcdf_lon) + j - values[:,grid_index] = netcdf_tp[:,i,j] - latitudes[grid_index] = netcdf_lat[i] - longitudes[grid_index] = netcdf_lon[j] - return values, latitudes, longitudes - -def plot_map(values: np.array, - filename: str, - label='', - title='', - vmin=None, - vmax=None, - cmap=None, - extend='neither', - norm=None, - ticks=None): - """Plot observed or interpolated data in a scatter plot.""" - logging.info(f'Creating map: {filename}') - - latitudes = LAT[RELAX_ZONE : RELAX_ZONE + 352] - longitudes = LON[RELAX_ZONE : RELAX_ZONE + 544] - lon2d, lat2d = np.meshgrid(longitudes, latitudes) - - fig, ax = plt.subplots( - figsize=(8, 6), - subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0, - pole_latitude= 43.0)} - ) - values = values.reshape((len(latitudes), len(longitudes))) - contour = ax.pcolormesh( - lon2d, lat2d, values, - cmap=cmap, shading="auto", - norm=norm if norm else None, - vmin=None if norm else vmin, - vmax=None if norm else vmax, - ) - ax.coastlines() - ax.add_feature(cfeature.BORDERS, linewidth=1) - ax.gridlines(visible=False) - ax.set_xticks([]) - ax.set_yticks([]) - - plt.title(title) - cbar = plt.colorbar( - contour, - label=label, - orientation="horizontal", - extend=extend, - shrink=0.75, - pad=0.02 - ) - if ticks is not None: - cbar.set_ticks(ticks) - cbar.set_ticklabels([f'{tick:g}' for tick in ticks]) - - plt.tight_layout() - fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight") - plt.close(fig) - -def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=1000.0): - """Plot precipitation data with specific colormap and thresholds.""" - # Scale and mask values below threshold - values = rfac * values # m/h --> mm/h - values = np.ma.masked_where(values <= threshold, values) - - # Predefined colors and bounds specific for precipitation - colors = ['none', 'powderblue', 'dodgerblue', 'mediumblue', - 'forestgreen', 'limegreen', 'lawngreen', - 'yellow', 'gold', 'darkorange', 'red', - 'darkviolet', 'violet', 'thistle'] - bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 30, 50, 70, 100, 150, 200] - bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000] - - cmap = ListedColormap(colors) - norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False) - - plot_map( - values, filename, - cmap=cmap, - norm=norm, - ticks=bounds, - title=title, - label='mm/h', - extend='max' - ) + grid_index = i * len(netcdf_lon) + j + values[:,grid_index] = netcdf_tp[:,i,j] + latitudes[grid_index] = netcdf_lat[i] + longitudes[grid_index] = netcdf_lon[j] + reshape_values = np.reshape(netcdf_tp, (netcdf_tp.shape[0], netcdf_tp.shape[1]*netcdf_tp.shape[2])) + logging.info(f'Array equal? {np.array_equal(values, reshape_values)}') + return reshape_values, latitudes, longitudes +def reshape_to_cosmo(vals): + return vals.reshape((len(LAT)-RELAX_ZONE*2, len(LON)-RELAX_ZONE*2)) -print(interpolate_basic.regrid) - -file_id = netCDF4.Dataset(CDF_FILENAME) -#anemoi1 = open_dataset(ANEMOI_1H_FILENAME) -#anemoi6 = open_dataset(ANEMOI_6H_FILENAME) -#with open(COSMO_CONFIG_FILE) as cosmo_file: -# cosmo_config = yaml.safe_load(cosmo_file) -#cosmo = open_dataset(cosmo_config) +root = logging.getLogger() +root.setLevel(logging.INFO) + +logging.info('loading data') +grib_data = xarray.load_dataset(GRIB_FILENAME) +netcdf_data = netCDF4.Dataset(CDF_FILENAME) cosmo1 = open_dataset(COSMO_1H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') cosmo6 = open_dataset(COSMO_6H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') - +era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') +era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') +logging.info('loading data complete') output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes)) -print(output_grid.shape) -print(cosmo1[0,0,0,:].shape) -plot_map_precipitation(values=cosmo1[0,:], filename="cosmo1.png") -plot_map_precipitation(values=cosmo6[0,:], filename="cosmo6.png") +logging.info('processing netcdf data') +netcdf_values, netcdf_latitudes, netcdf_longitudes = extract_netcdf_values(netcdf_data=netcdf_data) +netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + +make_plots = False + +prev_netcdf_regrid = [] + +netcdf_error = np.zeros(cosmo1.dates.shape) +era_norm_error = np.zeros(cosmo1.dates.shape) +netcdf_early_error = np.zeros(cosmo1.dates.shape) +netcdf_late_error = np.zeros(cosmo1.dates.shape) + +for t in range(len(cosmo1.dates)): + date = cosmo1.dates[t] + era_date = era1.dates[t] + if date != era_date: + logging.error('dates do not match: cosmo date: {date}, era date: {era_date}') + if date != netcdf_data['valid_time'][t]: + logging.error(f'dates do not match: cosmo date: {date}, netcdf: {netcdf_data["valid_time"][t]}') + + + # plot cosmo + if make_plots: + plot_map_precipitation(values=reshape_to_cosmo(cosmo1[t,:]), filename=f'plots/tp/{date}-cosmo1') + if t % 6 == 0: + if cosmo6.dates[t//6] != date: + logging.error(f'dates do not match: cosmo1: {date}, cosmo6: {cosmo6.dates[t//6]}') + plot_map_precipitation(values=reshape_to_cosmo(cosmo6[t//6,:]), filename=f'plots/tp/{date}-cosmo6') + + # plot netcdf + netcdf_vals = netcdf_values[t,:].reshape((1,1,netcdf_values.shape[1])) + netcdf_regrid=interpolate_basic.regrid(netcdf_vals, netcdf_grid, output_grid) + if make_plots: + plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf') + + # plot era + era_grid = np.column_stack((era1.longitudes, era1.latitudes)) + era1_regrid = interpolate_basic.regrid(era1[t,:], era_grid, output_grid) + if make_plots: + plot_map_precipitation(reshape_to_cosmo(era1_regrid/6), f'plots/tp/{date}-era1-norm') + + #if t % 6 == 0: + # if era6.dates[t//6] != date: + # logging.error(f'dates do not match: era1: {date}, era6: {era6.dates[t//6]}') + # era6_regrid = interpolate_basic.regrid(era6[t//6,:], era_grid, output_grid) + # plot_map_precipitation(reshape_to_cosmo(era6_regrid), f'plots/tp/{date}-era6') + + era_norm_error[t] = np.mean(absolute_error(era1_regrid/6, cosmo1[t,:])) + netcdf_error[t] = np.mean(absolute_error(netcdf_regrid, cosmo1[t,:])) + logging.info(f'era norm error: {era_norm_error[t]} netcdf err: {netcdf_error[t]}') + if t>0: + netcdf_early_error[t] = np.mean(absolute_error(prev_netcdf_regrid, cosmo1[t,:])) + netcdf_late_error[t-1] = np.mean(absolute_error(netcdf_regrid, cosmo1[t-1,:])) + logging.info(f'netcdf early err: {netcdf_early_error[t]}, netcdf late err: {netcdf_late_error[t-1]}') + prev_netcdf_regrid = netcdf_regrid + +maes = {} +maes['era normalized'] = era_norm_error +maes['copernicus'] = netcdf_error +maes['copernicus-early'] = netcdf_early_error +maes['copernicus-late'] = netcdf_late_error +plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png') #fig = plt.figure() #fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) @@ -161,23 +151,9 @@ def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=1000. #fig.savefig('cosmo6.png') -values, latitudes, longitudes = extract_values(netcdf_data=file_id) -input_grid=np.column_stack((longitudes, latitudes)) -vals = values[0,:].reshape((1,1,values.shape[1])) -regrid=interpolate_basic.regrid(vals, input_grid, output_grid) -plot_map_precipitation(regrid, 'netcdf.png') - - -era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -era_grid = np.column_stack((era1.longitudes, era1.latitudes)) -era1_regrid = interpolate_basic.regrid(era1[0,:], era_grid, output_grid) -plot_map_precipitation(era1_regrid, "era1.png") -era6_regrid = interpolate_basic.regrid(era6[0,:], era_grid, output_grid) -plot_map_precipitation(era1_regrid, "era6.png") From a7a541f3af32ba36ba111d74dc17ddcbfdd03c94 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 26 Aug 2025 08:54:59 +0200 Subject: [PATCH 138/189] Generalize copernicus script to work with grib/netcdf --- src/hirad/input_data/read_tp.py | 37 +++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/hirad/input_data/read_tp.py b/src/hirad/input_data/read_tp.py index 37767a0..9953086 100644 --- a/src/hirad/input_data/read_tp.py +++ b/src/hirad/input_data/read_tp.py @@ -29,8 +29,7 @@ COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr" COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr" COSMO_CONFIG_FILE="src/input_data/cosmo.yaml" -#CDF_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" -CDF_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.zip" +CDF_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" GRIB_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" @@ -38,6 +37,29 @@ LON = np.arange(-6.82, 4.80 + 0.02, 0.02) RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) +def extract_grib_values(grib_data): + grib_lat = grib_data['latitude'][:] + grib_lon = grib_data['longitude'][:] + grib_t2m = grib_data['t2m'][:] + +def extract_lat_lon(data): + lat = data['latitude'][:] + lon = data['longitude'][:] + output_lat = np.zeros(len(lat)* len(lon)) + output_lon = np.zeros(len(lat) * len(lon)) + for i in range(len(lat)): + if i % 10 == 0: + print(i) + for j in range(len(lon)): + grid_index = i * len(lon) + j + output_lat[grid_index] = lat[i] + output_lon[grid_index] = lon[j] + return output_lat, output_lon + +def extract_values(data, variable): + values = data[variable][:] + return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) + def extract_netcdf_values(netcdf_data): netcdf_lat = netcdf_data['latitude'][:] netcdf_lon = netcdf_data['longitude'][:] @@ -77,10 +99,12 @@ def reshape_to_cosmo(vals): output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes)) logging.info('processing netcdf data') -netcdf_values, netcdf_latitudes, netcdf_longitudes = extract_netcdf_values(netcdf_data=netcdf_data) +netcdf_latitudes, netcdf_longitudes = extract_lat_lon(netcdf_data) +netcdf_values = extract_values(netcdf_data, 'tp') +#netcdf_values, netcdf_latitudes, netcdf_longitudes = extract_netcdf_values(netcdf_data=netcdf_data) netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) -make_plots = False +make_plots = True prev_netcdf_regrid = [] @@ -89,7 +113,8 @@ def reshape_to_cosmo(vals): netcdf_early_error = np.zeros(cosmo1.dates.shape) netcdf_late_error = np.zeros(cosmo1.dates.shape) -for t in range(len(cosmo1.dates)): +for t in range(4): +#for t in range(len(cosmo1.dates)): date = cosmo1.dates[t] era_date = era1.dates[t] if date != era_date: @@ -110,7 +135,7 @@ def reshape_to_cosmo(vals): netcdf_vals = netcdf_values[t,:].reshape((1,1,netcdf_values.shape[1])) netcdf_regrid=interpolate_basic.regrid(netcdf_vals, netcdf_grid, output_grid) if make_plots: - plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf') + plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf-refactor') # plot era era_grid = np.column_stack((era1.longitudes, era1.latitudes)) From 4a963ee91936e092110fe82e109b30a43d6cbd32 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 26 Aug 2025 08:59:41 +0200 Subject: [PATCH 139/189] add copy anemoi script for copying from catalogue --- copyanemoi.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 copyanemoi.sh diff --git a/copyanemoi.sh b/copyanemoi.sh new file mode 100644 index 0000000..361a9fb --- /dev/null +++ b/copyanemoi.sh @@ -0,0 +1,11 @@ +#!/bin/bash -l +# +#SBATCH --time=23:59:00 +#SBATCH --ntasks=1 +#SBATCH --partition=xfer + +echo -e "$SLURM_JOB_NAME started on $(date):\n $command $1 $2" +cp -rvn $1 $2 + +echo -e "$SLURM_JOB_NAME finished on $(date)\n" + From fcef8fc920979fef4b3b4b48176211c2038bdba7 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 26 Aug 2025 09:00:55 +0200 Subject: [PATCH 140/189] move anemoi script to input_data directory --- copyanemoi.sh => src/hirad/input_data/copyanemoi.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename copyanemoi.sh => src/hirad/input_data/copyanemoi.sh (100%) diff --git a/copyanemoi.sh b/src/hirad/input_data/copyanemoi.sh similarity index 100% rename from copyanemoi.sh rename to src/hirad/input_data/copyanemoi.sh From 41208b2c8d5a862c6a2eb26bf9937a2c46825923 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 26 Aug 2025 16:55:57 +0200 Subject: [PATCH 141/189] Set training parameters in configs; add input/output channel names selection to dataset config and update era5cosmo dataset class; replace image saving function with results saving during training; improve channel transformation logic. --- src/hirad/conf/dataset/era_cosmo.yaml | 5 +- .../conf/dataset/era_cosmo_inference.yaml | 4 +- .../conf/generation/era_cosmo_training.yaml | 4 +- src/hirad/conf/plot_maps.yaml | 21 ++++++ .../conf/training/era_cosmo_diffusion.yaml | 12 ++-- .../conf/training/era_cosmo_regression.yaml | 10 +-- .../conf/training_era_cosmo_diffusion.yaml | 4 +- .../conf/training_era_cosmo_regression.yaml | 4 +- src/hirad/datasets/era5_cosmo.py | 44 ++++++++----- src/hirad/distributed/manager.py | 30 ++++----- src/hirad/inference/generate.py | 5 -- src/hirad/training/train.py | 4 +- src/hirad/utils/inference_utils.py | 65 +++++++++++++------ 13 files changed, 133 insertions(+), 79 deletions(-) create mode 100644 src/hirad/conf/plot_maps.yaml diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index 5d32f4e..37b3df2 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,3 +1,6 @@ type: era5_cosmo dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/train -validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation \ No newline at end of file +# dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-all-channels +validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp] +output_channel_names: [2t, 10u, 10v, tp] \ No newline at end of file diff --git a/src/hirad/conf/dataset/era_cosmo_inference.yaml b/src/hirad/conf/dataset/era_cosmo_inference.yaml index f819b18..bca045d 100644 --- a/src/hirad/conf/dataset/era_cosmo_inference.yaml +++ b/src/hirad/conf/dataset/era_cosmo_inference.yaml @@ -1,2 +1,4 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation \ No newline at end of file +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp] +output_channel_names: [2t, 10u, 10v, tp] \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo_training.yaml b/src/hirad/conf/generation/era_cosmo_training.yaml index 4374fc8..54dc840 100644 --- a/src/hirad/conf/generation/era_cosmo_training.yaml +++ b/src/hirad/conf/generation/era_cosmo_training.yaml @@ -11,8 +11,8 @@ num_ensembles: 16 # artifact. times_range: null times: - - 20200926-1800 - - 20200927-0000 + - 20200721-1900 + - 20200722-1900 perf: num_writer_workers: 10 \ No newline at end of file diff --git a/src/hirad/conf/plot_maps.yaml b/src/hirad/conf/plot_maps.yaml new file mode 100644 index 0000000..07ecb85 --- /dev/null +++ b/src/hirad/conf/plot_maps.yaml @@ -0,0 +1,21 @@ +hydra: + job: + chdir: true + name: plot_maps + run: + dir: . + +# Get defaults +defaults: + - _self_ + # Dataset + - dataset/era_cosmo_inference + +results_dir: /capstor/scratch/cscs/pstamenk/outputs/generation/test_period_advanced_stats_2 + +time_steps: null # If null, will use all time steps in the results directory + +output_dir: ${results_dir}/plots_boxcox_masked_limits # Directory to save the plots + +plot_box_precipitation: False # Whether to plot box precipitation maps, otherwise plot with box-cox transform +tp_threshold: 0.0002 # Threshold in m/h for masking precipitation values \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index 2a2fd33..f673e57 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,6 +1,6 @@ # Hyperparameters hp: - training_duration: 10000 + training_duration: 3500000 # Training duration based on the number of processed samples total_batch_size: "auto" # Total batch size @@ -34,14 +34,14 @@ io: # regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo_perf_apex_gn/checkpoints_regression # Where to load the regression checkpoint - print_progress_freq: 500 + print_progress_freq: 2000 # How often to print progress - save_checkpoint_freq: 10000 + save_checkpoint_freq: 250000 # How often to save the checkpoints, measured in number of processed samples - visualization_freq: 10000 + visualization_freq: 250000 # how often to visualize network outputs - validation_freq: 2000 + validation_freq: 50000 # how often to record the validation loss, measured in number of processed samples - validation_steps: 2 + validation_steps: 30 # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 40e9005..7d35d3d 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,6 +1,6 @@ # Hyperparameters hp: - training_duration: 10000 + training_duration: 1000000 # Training duration based on the number of processed samples total_batch_size: "auto" # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression. @@ -34,12 +34,12 @@ perf: io: print_progress_freq: 500 # How often to print progress - save_checkpoint_freq: 10000 + save_checkpoint_freq: 100000 # How often to save the checkpoints, measured in number of processed samples - visualization_freq: 10000 + visualization_freq: 50000 # how often to visualize network output - validation_freq: 2000 + validation_freq: 20000 # how often to record the validation loss, measured in number of processed samples - validation_steps: 2 + validation_steps: 30 # how many loss evaluations are used to compute the validation loss per checkpoint checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index fee7627..467c52d 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -15,13 +15,13 @@ defaults: # Model - model/era_cosmo_diffusion - - model_size/mini + - model_size/normal # Training - training/era_cosmo_diffusion # Inference visualization - - generation/era_cosmo_training + # - generation/era_cosmo_training # Logging - logging/era_cosmo_diffusion \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index ce04119..8059f33 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -15,13 +15,13 @@ defaults: # Model - model/era_cosmo_regression - - model_size/mini + - model_size/normal # Training - training/era_cosmo_regression # Inference visualization - - generation/era_cosmo_training + # - generation/era_cosmo_training # Logging - logging/era_cosmo_regression \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index f97dbc6..76d1982 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -7,7 +7,7 @@ import torch.nn.functional as F class ERA5_COSMO(DownscalingDataset): - def __init__(self, dataset_path: str): + def __init__(self, dataset_path: str, input_channel_names: List[str] = [], output_channel_names: List[str] = []): super().__init__() #TODO switch hanbdling paths to Path rather than pure strings @@ -22,51 +22,61 @@ def __init__(self, dataset_path: str): # Load cosmo info and channel names with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file: self._cosmo_info = yaml.safe_load(file) - self._cosmo_channels = [ChannelMetadata(name) for name in self._cosmo_info['select']] + if output_channel_names: + self._cosmo_indeces = [self._cosmo_info['select'].index(name) for name in output_channel_names] + else: + self._cosmo_indeces = list(range(len(self._cosmo_info['select']))) + output_channel_names = self._cosmo_info['select'] + self._cosmo_channels = [ChannelMetadata(name) if len(name.split('_'))==1 + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._cosmo_info['select'] if name in output_channel_names] # Load era5 info and channel names with open(os.path.join(self._info_path,'era.yaml'), 'r') as file: self._era_info = yaml.safe_load(file) + if input_channel_names: + self._era_indeces = [self._era_info['select'].index(name) for name in input_channel_names] + else: + self._era_indeces = list(range(len(self._era_info['select']))) + input_channel_names = self._era_info['select'] self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1 - else ChannelMetadata(name.split('_')[0],name.split('_')[1]) - for name in self._era_info['select']] + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._era_info['select'] if name in input_channel_names] # Load stats for normalizing channels of input and output cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False) - self.output_mean = cosmo_stats['mean'] - self.output_std = cosmo_stats['stdev'] + self.output_mean = cosmo_stats['mean'][self._cosmo_indeces] + self.output_std = cosmo_stats['stdev'][self._cosmo_indeces] era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) - self.input_mean = era_stats['mean'] - self.input_std = era_stats['stdev'] - + self.input_mean = era_stats['mean'][self._era_indeces] + self.input_std = era_stats['stdev'][self._era_indeces] def __getitem__(self, idx): """Get cosmo and era5 interpolated to cosmo grid""" - # get era5 data point + # get data point # squeeze the ensemble dimesnsion # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) # orig_shape = [350,542] #TODO currently padding to be divisible by 16 orig_shape = self.image_shape() - era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ + era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)[self._era_indeces] + era5_data = np.flip(era5_data \ .squeeze() \ .reshape(-1,*orig_shape), 1) era5_data = self.normalize_input(era5_data) - # get cosmo data point - cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ + + cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)[self._cosmo_indeces] + cosmo_data = np.flip(cosmo_data\ .squeeze() \ .reshape(-1,*orig_shape), 1) cosmo_data = self.normalize_output(cosmo_data) - # return samples + return torch.tensor(cosmo_data),\ torch.tensor(era5_data), - # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ - # F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ - # 0 def __len__(self): return len(self._file_list) diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index 9695818..75c0fe7 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -529,7 +529,7 @@ def setup( DistributedManager._shared_state["_is_initialized"] = True manager = DistributedManager() - manager._distributed = torch.distributed.is_available() + manager._distributed = torch.distributed.is_available() and world_size > 1 if manager._distributed: # Update rank and world_size if using distributed manager._rank = rank @@ -546,23 +546,23 @@ def setup( #TODO device_id makes the init hang, couldn't figure out why if manager._distributed: # Setup distributed process group - # try: - dist.init_process_group( - backend, - rank=manager.rank, - world_size=manager.world_size, - device_id=manager.device, - ) + try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + device_id=manager.device, + ) # rank=manager.rank, # world_size=manager.world_size, # device_id=manager.device, - # except TypeError: - # # device_id only introduced in PyTorch 2.3 - # dist.init_process_group( - # backend, - # rank=manager.rank, - # world_size=manager.world_size, - # ) + except TypeError: + # device_id only introduced in PyTorch 2.3 + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) if torch.cuda.is_available(): # Set device for this process and empty cache to optimize memory usage diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 49f82f1..3790820 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -90,7 +90,6 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - #TODO fix to use channels_last which is optimal for H100 net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True @@ -172,10 +171,6 @@ def main(cfg: DictConfig) -> None: logger0.info(f"Generating images, saving results to {output_path}...") batch_size = 1 warmup_steps = min(len(times) - 1, 2) - # Generates model predictions from the input data using the specified - # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates - # through the dataset using a data loader, computes predictions, and saves them along - # with associated metadata. torch_cuda_profiler = ( torch.cuda.profiler.profile() diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index ba1b975..e865bca 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -25,7 +25,7 @@ from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.utils.patching import RandomPatching2D from hirad.utils.function_utils import get_time_from_range -from hirad.utils.inference_utils import save_images +from hirad.utils.inference_utils import save_results_as_torch from hirad.utils.env_info import get_env_info, flatten_dict from hirad.models import UNet, EDMPrecondSuperResolution from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE @@ -830,7 +830,7 @@ def main(cfg: DictConfig) -> None: os.makedirs(output_path) writer_threads.append( writer_executor.submit( - save_images, + save_results_as_torch, output_path, times[visualization_sampler[time_index]], visualization_dataset, diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 14be446..b2a1168 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -210,20 +210,17 @@ def diffusion_step( def save_results_as_torch(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): os.makedirs(output_path, exist_ok=True) - - target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) - # prediction.shape = (num_channels, X, Y) - # prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) - # prediction_ensemble.shape = (num_ensembles, num_channels, X, Y) + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),-2) - baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1) if mean_pred is not None: - mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) + torch.save(mean_pred, os.path.join(output_path, f'{time_step}-regression-prediction')) torch.save(target, os.path.join(output_path, f'{time_step}-target')) torch.save(prediction_ensemble, os.path.join(output_path, f'{time_step}-predictions')) torch.save(baseline, os.path.join(output_path, f'{time_step}-baseline')) - +@DeprecationWarning def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): os.makedirs(output_path, exist_ok=True) @@ -250,11 +247,11 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, input_channel_idx = input_channels.index(channel) if channel.name=="tp": - target[idx,::] = _prepare_precipitation(target[idx,:,:]) - prediction[:,idx,::] = _prepare_precipitation(prediction[:,idx,:,:]) - baseline[input_channel_idx,:,:] = _prepare_precipitation(baseline[input_channel_idx,::]) + target[idx,::] = transform_channel(target[idx,:,:]) + prediction[:,idx,::] = transform_channel(prediction[:,idx,:,:]) + baseline[input_channel_idx,:,:] = transform_channel(baseline[input_channel_idx,::]) if mean_pred is not None: - mean_pred[idx,::] = _prepare_precipitation(mean_pred[idx,::]) + mean_pred[idx,::] = transform_channel(mean_pred[idx,::]) if mean_pred is not None: vmin, vmax = calculate_bounds(target[idx,:,:], @@ -314,17 +311,23 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, power['mean_prediction'] = mp_power plot_power_spectra(freqs, power, channel.name, os.path.join(output_path_channel, f'{time_step}-{channel.name}-spectra.jpg')) -def _prepare_precipitation(precip_array): - precip_array = np.clip(precip_array, 0, None) - precip_array = np.where(precip_array == 0, 1e-6, precip_array) +def transform_channel(channel_array, channel_name="tp"): + # precip_array = np.clip(precip_array, 0, None) + # precip_array = np.where(precip_array == 0, 1e-6, precip_array) # epsilon = 1e-2 # precip_array = precip_array + epsilon - precip_array = np.log10(precip_array) + # precip_array = np.log10(precip_array) # log_min, log_max = precip_array.min(), precip_array.max() # precip_array = (precip_array-log_min)/(log_max-log_min) - return precip_array - - + if channel_name == "tp": + channel_array = np.clip(channel_array, 0, None) + channel_array = (np.power(channel_array,0.25)-1)/0.25 + elif channel_name == "2t": + channel_array = channel_array - 273.15 + # precip_array = np.sqrt(precip_array) + return channel_array + +@DeprecationWarning def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): """Plot observed or interpolated data in a scatter plot.""" @@ -339,6 +342,26 @@ def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array plt.close('all') def calculate_bounds(*arrays: np.ndarray) -> tuple[float]: - vmin = min(*[np.min(array).item() for array in arrays]) - vmax = max(*[np.max(array).item() for array in arrays]) + """Calculate consistent bounds across all arrays""" + valid_arrays = [arr for arr in arrays if arr is not None] + if not valid_arrays: + return 0, 1 + + # hanndle if there are masked arrays with invalid values (e.g. NaNs) + all_values = [] + for arr in valid_arrays: + if hasattr(arr, 'compressed'): # Masked array + compressed = arr.compressed() + if len(compressed) > 0: + all_values.extend(compressed) + elif hasattr(arr, 'flatten'): # Regular numpy array + all_values.extend(arr.flatten()) + else: + all_values.append(arr) + + if not all_values: + return 0, 1 + + vmin = min(all_values) + vmax = max(all_values) return vmin, vmax \ No newline at end of file From d28a8becc3bdee7bb2f64511238d5258775bcfa5 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 26 Aug 2025 18:15:21 +0200 Subject: [PATCH 142/189] Enhance evaluation scripts to include regression predictions - Modified `diurnal_cycle_precip_mean_wet-hour.py`, `diurnal_cycle_precip_p99.py`, `diurnal_cycle_temp_wind.py`, `hist.py`, `probability_of_exceedance.py`, `snapshots.py`, and `map_precip_stats.py` to handle regression predictions. Should be refactored for better error exception handling. - Updated plotting functions to include regression predictions where applicable. --- src/hirad/conf/generation/era_cosmo.yaml | 7 +- .../diurnal_cycle_precip_mean_wet-hour.py | 32 +++- src/hirad/eval/diurnal_cycle_precip_p99.py | 19 +- src/hirad/eval/diurnal_cycle_temp_wind.py | 39 ++++- src/hirad/eval/hist.py | 60 ++++--- src/hirad/eval/map_precip_stats.py | 17 +- src/hirad/eval/plot_maps.py | 163 ++++++++++++++++++ src/hirad/eval/probability_of_exceedance.py | 53 +++--- src/hirad/eval/snapshots.py | 42 ++++- 9 files changed, 346 insertions(+), 86 deletions(-) create mode 100644 src/hirad/eval/plot_maps.py diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index a50bf12..3396d43 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -1,4 +1,4 @@ -num_ensembles: 8 +num_ensembles: 16 # Number of ensembles to generate per input seed_batch_size: 4 # Size of the batched inference @@ -37,9 +37,8 @@ perf: io: # res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/diffusion_full/checkpoints_diffusion res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_era5_cosmo/checkpoints_diffusion - # res_ckpt_path: null + # res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_era5_cosmo_cmp_fix_2/checkpoints_diffusion # Checkpoint filename for the diffusion model reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression - # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model - output_path: ./outputs/evaluation + output_path: . \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py index f271588..988024a 100644 --- a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py +++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py @@ -65,8 +65,8 @@ def main(cfg: DictConfig): land_mask = load_land_sea_mask() # Prepare lists to collect DataArrays - target_precip, baseline_precip, pred_precip = [], [], [] - target_wet, baseline_wet, pred_wet = [], [], [] + target_precip, baseline_precip, pred_precip, mean_pred_precip = [], [], [], [] + target_wet, baseline_wet, pred_wet, mean_pred_wet = [], [], [], [] # Collect data for idx, ts in enumerate(times, 1): @@ -74,26 +74,38 @@ def main(cfg: DictConfig): target = torch.load(out_root/ts/f"{ts}-target", weights_only=False)[tp_out] * CONV_FACTOR baseline = torch.load(out_root/ts/f"{ts}-baseline", weights_only=False)[tp_in] * CONV_FACTOR / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False)[:, tp_out, :, :] * CONV_FACTOR + try: + mean_pred = torch.load(out_root/ts/f"{ts}-regression-prediction", weights_only=False)[tp_out] * CONV_FACTOR + except: + mean_pred = None # DataArrays for spatial means at each timestep da_target = xr.DataArray(target, dims=("lat","lon"), coords=land_mask.coords) da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=land_mask.coords) da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **land_mask.coords}) + if mean_pred is not None: + da_mean_pred = xr.DataArray(mean_pred, dims=("lat","lon"), coords=land_mask.coords) # Apply land mask after conversion to xarray da_target = da_target * land_mask da_baseline = da_baseline * land_mask da_preds = da_preds * land_mask + if mean_pred is not None: + da_mean_pred = da_mean_pred * land_mask # Spatial mean target_precip.append(da_target.mean(dim=("lat","lon")).assign_coords(time=dt)) baseline_precip.append(da_baseline.mean(dim=("lat","lon")).assign_coords(time=dt)) pred_precip.append(da_preds.mean(dim=("lat","lon")).assign_coords(time=dt)) + if mean_pred is not None: + mean_pred_precip.append(da_mean_pred.mean(dim=("lat","lon")).assign_coords(time=dt)) # Wet-hour fraction, i.e., freq(precip) > WET_THRESHOLD target_wet.append(((da_target / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) baseline_wet.append(((da_baseline / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) pred_wet.append(((da_preds / 24> WET_THRESHOLD).mean(dim=("lat","lon")).assign_coords(time=dt))) + if mean_pred is not None: + mean_pred_wet.append(((da_mean_pred / 24 > WET_THRESHOLD).mean().assign_coords(time=dt))) if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") @@ -102,26 +114,30 @@ def main(cfg: DictConfig): amount_target_mean, _ = concat_and_group_diurnal(target_precip) amount_baseline_mean, _ = concat_and_group_diurnal(baseline_precip) amount_pred_mean, amount_pred_std = concat_and_group_diurnal(pred_precip, is_member=True) + if mean_pred_precip: + amount_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_precip) wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to obtain percentages wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0) wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0) + if mean_pred_wet: + wet_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wet, scale=100.0) # Generate plots save_plot( amount_target_mean.hour, - [amount_target_mean, amount_baseline_mean, amount_pred_mean], - [None, None, amount_pred_std], - ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], + [amount_target_mean, amount_baseline_mean, amount_pred_mean, amount_mean_pred_mean] if mean_pred_precip else [amount_target_mean, amount_baseline_mean, amount_pred_mean], + [None, None, amount_pred_std, None] if mean_pred_precip else [None, None, amount_pred_std], + ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_precip else ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], 'Precipitation (mm/day)', 'Diurnal Cycle of Precip Amount', out_root / 'diurnal_cycle_precip_amount.png' ) save_plot( wet_target_mean.hour, - [wet_target_mean, wet_baseline_mean, wet_pred_mean], - [None, None, wet_pred_std], - ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], + [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if mean_pred_wet else [wet_target_mean, wet_baseline_mean, wet_pred_mean], + [None, None, wet_pred_std, None] if mean_pred_wet else [None, None, wet_pred_std], + ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wet else ['COSMO-2 Analysis','ERA5','CorrDiff ± Std(Members)'], 'Wet-Hour Fraction [%]', 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)', out_root / 'diurnal_cycle_precip_wethours.png' diff --git a/src/hirad/eval/diurnal_cycle_precip_p99.py b/src/hirad/eval/diurnal_cycle_precip_p99.py index bb64153..316a387 100644 --- a/src/hirad/eval/diurnal_cycle_precip_p99.py +++ b/src/hirad/eval/diurnal_cycle_precip_p99.py @@ -81,13 +81,17 @@ def main(cfg: DictConfig): pct99_std = {} # -- Process target and baseline -- - for mode in ['target', 'baseline']: + for mode in ['target', 'baseline', 'regression-prediction']: logger.info(f"Processing mode: {mode}") data_list = [] - for ts in times: - data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode == 'target' else tp_in] * CONV_FACTOR * land_mask - data_list.append(data) + try: + for ts in times: + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * CONV_FACTOR * land_mask + data_list.append(data) + except: + logger.error(f"Error loading data for mode {mode}. Skipping.") + continue da = xr.DataArray( np.stack(data_list, axis=0), @@ -123,6 +127,7 @@ def main(cfg: DictConfig): # Transpose to get the expected dimension order: [member, time, lat, lon] pred_da = pred_da.transpose('member', 'time', 'lat', 'lon') + logger.info('Calculating 99th percentile for predictions') # Group by hour, compute 99th percentile across time, then spatial mean hourly_p99_by_member = pred_da.groupby('time.hour').quantile(0.99, dim='time').mean(dim=['lat', 'lon']) @@ -134,6 +139,7 @@ def main(cfg: DictConfig): def cycle_fn(x): return x.values.tolist() + [x.values.tolist()[0]] + logger.info("Preparing data for plotting") hrs_c = list(range(24)) + [0 + 24] pct99_lines = [ cycle_fn(pct99_mean['target']), @@ -143,13 +149,16 @@ def cycle_fn(x): cycle_fn(pct99_std['prediction']) ) ] + if 'regression-prediction' in pct99_mean: + pct99_lines.append(cycle_fn(pct99_mean['regression-prediction'])) # Plot combined diurnal 99th-percentile cycle + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff 99th Pct ± Std', 'Regression Prediction'] if 'regression-prediction' in pct99_mean else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff 99th Pct ± Std'] fn = out_root/'diurnal_cycle_precip_99th_percentile.png' save_plot( hrs_c, pct99_lines, - ['COSMO-2 Analysis','ERA5','CorrDiff 99th Pct ± Std'], + labels, 'Precipitation (mm/day)', 'Diurnal Cycle of 99th-Percentile Precipitation', fn diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py index c0c97fe..5b4c36f 100644 --- a/src/hirad/eval/diurnal_cycle_temp_wind.py +++ b/src/hirad/eval/diurnal_cycle_temp_wind.py @@ -57,8 +57,8 @@ def load(ts, fn): land_mask = load_land_sea_mask() # Prepare lists to collect DataArrays - target_temp, baseline_temp, pred_temp = [], [], [] - target_wind, baseline_wind, pred_wind = [], [], [] + target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], [] + target_wind, baseline_wind, pred_wind, mean_pred_wind = [], [], [], [] def mean_over_land(data, dims, coords, time_coord): da = xr.DataArray(data, dims=dims, coords=coords) * land_mask @@ -72,6 +72,10 @@ def mean_over_land(data, dims, coords, time_coord): target = load(ts, f"{ts}-target") baseline = load(ts, f"{ts}-baseline") predictions = load(ts, f"{ts}-predictions") + try: + regression_pred = load(ts, f"{ts}-regression-prediction") + except: + regression_pred = None # Process temperature (convert to Celsius) target_temp.append(mean_over_land( @@ -81,6 +85,10 @@ def mean_over_land(data, dims, coords, time_coord): pred_temp.append(mean_over_land( predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"), {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) + if regression_pred is not None: + mean_pred_temp.append(mean_over_land( + regression_pred[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt)) + # Process wind speed target_wind.append(mean_over_land( @@ -90,6 +98,9 @@ def mean_over_land(data, dims, coords, time_coord): pred_wind.append(mean_over_land( np.hypot(predictions[:, u_out, :, :], predictions[:, v_out, :, :]), ("member","lat","lon"), {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt)) + if regression_pred is not None: + mean_pred_wind.append(mean_over_land( + np.hypot(regression_pred[u_out], regression_pred[v_out]), ("lat","lon"), land_mask.coords, dt)) if idx % LOG_INTERVAL == 0 or idx == len(times): logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})") @@ -98,10 +109,14 @@ def mean_over_land(data, dims, coords, time_coord): temp_target_mean, _ = concat_and_group_diurnal(target_temp) temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp) temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True) + if mean_pred_temp: + temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp) wind_target_mean, _ = concat_and_group_diurnal(target_wind) wind_baseline_mean, _ = concat_and_group_diurnal(baseline_wind) wind_pred_mean, wind_pred_std = concat_and_group_diurnal(pred_wind, is_member=True) + if mean_pred_wind: + wind_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wind) def save_plot(hour, means, stds, labels, ylabel, title, out_path): hrs = np.concatenate([hour.values, [24]]) @@ -124,22 +139,30 @@ def save_plot(hour, means, stds, labels, ylabel, title, out_path): plt.savefig(out_path) plt.close() + data = [temp_target_mean, temp_baseline_mean, temp_pred_mean, temp_mean_pred_mean] if mean_pred_temp else [temp_target_mean, temp_baseline_mean, temp_pred_mean] + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_temp else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'] + stds = [None, None, temp_pred_std, None] if mean_pred_temp else [None, None, temp_pred_std] + # Generate plots save_plot( temp_target_mean.hour, - [temp_target_mean, temp_baseline_mean, temp_pred_mean], - [None, None, temp_pred_std], - ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'], + data, + stds, + labels, '2m Temperature [°C]', 'Diurnal Cycle of 2m Temperature', out_root / 'diurnal_cycle_2t.png' ) + data = [wind_target_mean, wind_baseline_mean, wind_pred_mean, wind_mean_pred_mean] if mean_pred_wind else [wind_target_mean, wind_baseline_mean, wind_pred_mean] + labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wind else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'] + stds = [None, None, wind_pred_std, None] if mean_pred_wind else [None, None, wind_pred_std] + save_plot( wind_target_mean.hour, - [wind_target_mean, wind_baseline_mean, wind_pred_mean], - [None, None, wind_pred_std], - ['COSMO-2 Analysis', 'ERA5', 'CorrDiff ± Std(Members)'], + data, + stds, + labels, 'Windspeed [m/s]', 'Diurnal Cycle of Windspeed', out_root / 'diurnal_cycle_windspeed.png' diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index bdd241b..183dee8 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -63,14 +63,14 @@ def save_distribution_plot(hist_data_dict, bin_edges, labels, colors, title, yla # Define line styles for percentiles percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} - colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green'} + colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'} legend_added = set() # Plot all percentile lines for dataset_name, data in percentiles_data.items(): color = colors[dataset_name] - if dataset_name in ['target', 'baseline']: + if dataset_name in ['target', 'baseline', 'regression-prediction']: # Single dataset for percentile, value in data.items(): linestyle = percentile_styles[percentile] @@ -138,30 +138,33 @@ def main(cfg: DictConfig): all_land_values = {} # -- Process target and baseline -- - for mode in ['target', 'baseline']: + for mode in ['target', 'baseline', 'regression-prediction']: logger.info(f"Processing mode: {mode}") hist_counts = np.zeros(len(bins) - 1) total_samples = 0 all_values = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing timestep {i+1}/{len(times)}") - - data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode == 'target' else tp_in] * CONV_FACTOR_HOURLY * land_mask - - # Apply scaling factor for baseline - if mode == 'baseline': - data = data / 6.0 - - land_values = data.values[~np.isnan(data.values)] - all_values.extend(land_values) - - counts, _ = np.histogram(land_values, bins=bins) - hist_counts += counts - total_samples += len(land_values) - + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target', 'regression-prediction'] else tp_in] * CONV_FACTOR_HOURLY * land_mask + + # Apply scaling factor for baseline + if mode == 'baseline': + data = data / 6.0 + + land_values = data.values[~np.isnan(data.values)] + all_values.extend(land_values) + + counts, _ = np.histogram(land_values, bins=bins) + hist_counts += counts + total_samples += len(land_values) + except: + logger.warning(f"{mode} not available, skipping") + continue # Normalize to probability density bin_widths = np.diff(bins) hist_data[mode] = hist_counts / (total_samples * bin_widths) @@ -212,12 +215,13 @@ def main(cfg: DictConfig): percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} # Target and baseline percentiles - for mode in ['target', 'baseline']: - data_array = xr.DataArray(all_land_values[mode]) - percentiles_data[mode] = { - key: data_array.quantile(p).item() - for key, p in percentiles.items() - } + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in all_land_values: + data_array = xr.DataArray(all_land_values[mode]) + percentiles_data[mode] = { + key: data_array.quantile(p).item() + for key, p in percentiles.items() + } # Ensemble member percentiles percentiles_data['predictions'] = {} @@ -229,8 +233,8 @@ def main(cfg: DictConfig): } # Create distribution plots - labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] - colors = ['blue', 'orange', 'green'] + labels = ['COSMO-2 Analysis', 'ERA5', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in hist_data else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in hist_data else ['blue', 'orange', 'green'] fn = out_root / 'precipitation_distribution_over_land.png' save_distribution_plot( diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index 26740c1..47e4125 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -129,7 +129,8 @@ def main(cfg: DictConfig): # Target and baseline modes basic_modes = { 'target': (tp_out, 'COSMO-2 Analysis'), - 'baseline': (tp_in, 'ERA5') + 'baseline': (tp_in, 'ERA5'), + 'regression-prediction': (tp_out, 'Regression Prediction') } logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} basic modes + predictions") @@ -137,11 +138,15 @@ def main(cfg: DictConfig): logger.info(f"Processing mode: {mode}") # Load all timesteps for this mode data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") - data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) * CONV_FACTOR - data_list.append(data[tp_channel]) + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) * CONV_FACTOR + data_list.append(data[tp_channel]) + except: + logger.warning(f"{mode} not available, skipping") + continue mode_data = xr.DataArray( np.stack(data_list, axis=0), dims=['time', 'lat', 'lon'], diff --git a/src/hirad/eval/plot_maps.py b/src/hirad/eval/plot_maps.py new file mode 100644 index 0000000..a2230a6 --- /dev/null +++ b/src/hirad/eval/plot_maps.py @@ -0,0 +1,163 @@ +import torch +import yaml +import numpy as np +import os +from pathlib import Path +import argparse +from hirad.eval import plotting +from hirad.utils.inference_utils import calculate_bounds, transform_channel +import hydra +from omegaconf import DictConfig, OmegaConf + +channel_plot_args = { + "2t": {"label": "°C"}, + "10u": {"label": "m/s"}, + "10v": {"label": "m/s"}, + "tp": {"label": "boxcox(mm/h)"}, +} + + +def get_available_time_steps(results_dir): + """Find all available time step directories""" + results_path = Path(results_dir) + time_step_dirs = [d.name for d in results_path.iterdir() if d.is_dir() and not d.name.startswith('plots') and not d.name.startswith('maps')] + return sorted(time_step_dirs) + +def plot_time_step(results_dir, output_dir, time_step, input_channels, output_channels, cfg): + """Plot all channels for a single time step""" + print(f"Processing time step: {time_step}") + + # Set up paths for this time step + ts_results_dir = Path(results_dir) / time_step + + # Load tensors + try: + target = torch.load(ts_results_dir / f"{time_step}-target", weights_only=False) + baseline = torch.load(ts_results_dir / f"{time_step}-baseline", weights_only=False) + predictions = torch.load(ts_results_dir / f"{time_step}-predictions", weights_only=False) + + try: + mean_pred = torch.load(ts_results_dir / f"{time_step}-regression-prediction", weights_only=False) + except FileNotFoundError: + mean_pred = None + print(f" Warning: No mean prediction found for {time_step}") + + except FileNotFoundError as e: + print(f" Error: Missing required file for {time_step}: {e}") + return + + # Create output directory for this time step + ts_output_dir = Path(output_dir) / time_step + ts_output_dir.mkdir(parents=True, exist_ok=True) + + # Plot each channel + for idx, channel in enumerate(output_channels): + print(f" Plotting channel: {channel}") + + # Get input channel index (handle case where input/output channels differ) + try: + input_idx = input_channels.index(channel) + except ValueError: + print(f" Warning: Channel {channel} not in input channels, skipping baseline") + continue + + # Transform data + if channel != "tp" or not cfg.get("plot_box_precipitation", False): + tgt = transform_channel(target[idx], channel) + base = transform_channel(baseline[input_idx], channel) + preds = transform_channel(predictions[:, idx], channel) + mean = transform_channel(mean_pred[idx], channel) if mean_pred is not None else None + if channel == "tp": + threshold = transform_channel(np.array([cfg.get("tp_threshold", 0.002)]), "tp")[0] # Transform threshold too + tgt = np.ma.masked_where(tgt <= threshold, tgt) + base = np.ma.masked_where(base <= threshold, base) + preds = np.ma.masked_where(preds <= threshold, preds) + if mean is not None: + mean = np.ma.masked_where(mean <= threshold, mean) + else: + # For precipitation, use raw values if plotting box precipitation + tgt = target[idx] + base = baseline[input_idx] + preds = predictions[:, idx] + mean = mean_pred[idx] if mean_pred is not None else None + + # Calculate consistent bounds (skip for precipitation) + if channel != "tp" or not cfg.get("plot_box_precipitation", False): + arrays = [tgt, base] + [preds[i] for i in range(preds.shape[0])] + if mean is not None: + arrays.append(mean) + vmin, vmax = calculate_bounds(*arrays) + else: + vmin, vmax = None, None + + base_channel_dir = ts_output_dir / channel + base_channel_dir.mkdir(parents=True, exist_ok=True) + + # Plot target + fname = ts_output_dir / channel / f"target" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation(tgt, str(fname), title=f"Target - {channel}") + else: + plotting.plot_map(tgt, str(fname), vmin=vmin, vmax=vmax, + title=f"Target - {channel}", **channel_plot_args.get(channel, {})) + + # Plot baseline + fname = ts_output_dir / channel / "baseline" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation(base, str(fname), title=f"Baseline - {channel}") + else: + plotting.plot_map(base, str(fname), vmin=vmin, vmax=vmax, + title=f"Baseline - {channel}", **channel_plot_args.get(channel, {})) + + # Plot mean prediction if available + if mean is not None: + fname = ts_output_dir / channel / "mean-prediction" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation(mean, str(fname), title=f"Mean Prediction - {channel}") + else: + plotting.plot_map(mean, str(fname), vmin=vmin, vmax=vmax, + title=f"Mean Prediction - {channel}", **channel_plot_args.get(channel, {})) + + # Plot ensemble members + for member_idx in range(preds.shape[0]): + fname = ts_output_dir / channel / f"prediction_{member_idx:02d}" + if channel == "tp" and cfg.get("plot_box_precipitation", False): + plotting.plot_map_precipitation( + preds[member_idx], str(fname), + title=f"Prediction {member_idx} - {channel}" + ) + else: + plotting.plot_map( + preds[member_idx], str(fname), + vmin=vmin, vmax=vmax, + title=f"Prediction {member_idx} - {channel}", + **channel_plot_args.get(channel, {}) + ) + +@hydra.main(version_base=None, config_path="../conf", config_name="plotting") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + + input_channels = cfg.dataset.input_channel_names + output_channels = cfg.dataset.output_channel_names + + # Set up directories + results_dir = Path(cfg.results_dir) + output_dir = Path(cfg.output_dir) if "output_dir" in cfg and cfg.output_dir else results_dir / "plots" + output_dir.mkdir(exist_ok=True) + + # Determine time steps to process + if cfg.time_steps: + time_steps = cfg.time_steps + else: + time_steps = get_available_time_steps(results_dir) + print(f"Found {len(time_steps)} time steps: {time_steps}") + + # Process each time step + for time_step in time_steps: + plot_time_step(results_dir, output_dir, time_step, input_channels, output_channels, cfg) + + print(f"Plotting complete. Results saved to: {output_dir}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py index febda1f..e301016 100644 --- a/src/hirad/eval/probability_of_exceedance.py +++ b/src/hirad/eval/probability_of_exceedance.py @@ -60,14 +60,14 @@ def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title # Define line styles for percentiles percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} - colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green'} + colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'} legend_added = set() # Plot all percentile lines for dataset_name, data in percentiles_data.items(): color = colors_perc[dataset_name] - if dataset_name in ['target', 'baseline']: + if dataset_name in ['target', 'baseline', 'regression-prediction']: # Single dataset for percentile, value in data.items(): linestyle = percentile_styles[percentile] @@ -135,24 +135,28 @@ def main(cfg: DictConfig): all_land_values = {} # -- Process target and baseline -- - for mode in ['target', 'baseline']: + for mode in ['target', 'baseline', 'regression-prediction']: logger.info(f"Processing mode: {mode}") all_values = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Processing timestep {i+1}/{len(times)}") - - data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode == 'target' else tp_in] * CONV_FACTOR_HOURLY * land_mask - - # Apply scaling factor for baseline - if mode == 'baseline': - data = data / 6.0 - - land_values = data.values[~np.isnan(data.values)] - all_values.extend(land_values) - + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * CONV_FACTOR_HOURLY * land_mask + + # Apply scaling factor for baseline + if mode == 'baseline': + data = data / 6.0 + + land_values = data.values[~np.isnan(data.values)] + all_values.extend(land_values) + except: + logger.warning(f"{mode} data not found, skipping") + continue + # Compute exceedance probabilities all_values = np.array(all_values) exceedance_probs = [] @@ -204,12 +208,13 @@ def main(cfg: DictConfig): percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} # Target and baseline percentiles - for mode in ['target', 'baseline']: - data_array = xr.DataArray(all_land_values[mode]) - percentiles_data[mode] = { - key: data_array.quantile(p).item() - for key, p in percentiles.items() - } + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in all_land_values: + data_array = xr.DataArray(all_land_values[mode]) + percentiles_data[mode] = { + key: data_array.quantile(p).item() + for key, p in percentiles.items() + } # Ensemble member percentiles percentiles_data['predictions'] = {} @@ -221,8 +226,8 @@ def main(cfg: DictConfig): } # Create exceedance plots - labels = ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] - colors = ['blue', 'orange', 'green'] + labels = ['COSMO-2 Analysis', 'ERA5', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in exceedance_data else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in exceedance_data else ['blue', 'orange', 'green'] fn = out_root / 'precipitation_exceedance_over_land.png' save_exceedance_plot( diff --git a/src/hirad/eval/snapshots.py b/src/hirad/eval/snapshots.py index e68f956..08d6e6b 100644 --- a/src/hirad/eval/snapshots.py +++ b/src/hirad/eval/snapshots.py @@ -139,6 +139,10 @@ def main(cfg: DictConfig) -> None: prediction = files.load(curr_time, f'{curr_time}-predictions') baseline = files.load(curr_time, f'{curr_time}-baseline') target = files.load(curr_time, f'{curr_time}-target') + try: + mean_pred = files.load(curr_time, f'{curr_time}-regression-prediction') + except: + mean_pred = None input_channels = dataset.input_channels() output_channels = dataset.output_channels() @@ -150,7 +154,8 @@ def main(cfg: DictConfig) -> None: vmin, vmax = calculate_bounds( target[idx,:,:], prediction[:,idx,:,:], - baseline[input_channel_idx,:,:] + baseline[input_channel_idx,:,:], + mean_pred[idx,:,:] if mean_pred is not None else None ) metadata = ChannelMeta.get(channel, vmin=vmin, vmax=vmax) @@ -174,11 +179,18 @@ def main(cfg: DictConfig) -> None: "prediction", prediction[0, idx, :, :], metadata, files, channel, curr_time, plot_func=plot_map_precipitation, title=plot_title ) + if mean_pred is not None: + save_field( + "mean-prediction", mean_pred[idx, :, :], metadata, files, channel, curr_time, + plot_func=plot_map_precipitation, title=plot_title + ) continue - # Plot target and baseline + # Plot target and baseline and regression prediction if available save_field("target", target[idx, :, :], metadata, files, channel, curr_time, title=plot_title) save_field("baseline", baseline[input_channel_idx, :, :], metadata, files, channel, curr_time, title=plot_title) + if mean_pred is not None: + save_field("mean-prediction", mean_pred[idx, :, :], metadata, files, channel, curr_time, title=plot_title) # Baseline MAE and ME _, baseline_mae = compute_mae(baseline[input_channel_idx, :, :], target[idx, :, :]) @@ -186,6 +198,13 @@ def main(cfg: DictConfig) -> None: save_field("baseline", baseline_mae.reshape(baseline[input_channel_idx, :, :].shape), metadata, files, channel, curr_time, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) save_field("baseline", baseline_me, metadata, files, channel, curr_time, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + # Regression prediction MAE and ME + if mean_pred is not None: + _, mean_mae = compute_mae(mean_pred[idx, :, :], target[idx, :, :]) + mean_me = (mean_pred[idx, :, :] - target[idx, :, :]) + save_field("mean-prediction", mean_mae.reshape(mean_pred[idx, :, :].shape), metadata, files, channel, curr_time, kind="mae", cmap=metadata.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis', vmin=0, vmax=metadata.err_vmax, title=plot_title) + save_field("mean-prediction", mean_me, metadata, files, channel, curr_time, kind="me", cmap=metadata.me_cmap, vmin=metadata.err_vmin, vmax=metadata.err_vmax, title=plot_title) + # Ensemble predictions for member_idx in range(prediction.shape[0]): member = prediction[member_idx, idx, :, :] @@ -203,13 +222,16 @@ def main(cfg: DictConfig) -> None: input_idx_10u = output_to_input_channel_map[idx_10u] input_idx_10v = output_to_input_channel_map[idx_10v] - # Compute windspeed and direction for target, baseline, prediction + # Compute windspeed and direction for target, baseline, prediction and mean prediction target_wind_speed = np.hypot(target[idx_10u, :, :], target[idx_10v, :, :]) target_wind_dir = wind_direction(target[idx_10u, :, :], target[idx_10v, :, :]) baseline_wind_speed = np.hypot(baseline[input_idx_10u, :, :], baseline[input_idx_10v, :, :]) baseline_wind_dir = wind_direction(baseline[input_idx_10u, :, :], baseline[input_idx_10v, :, :]) prediction_wind_speed = np.hypot(prediction[:, idx_10u, :, :], prediction[:, idx_10v, :, :]) prediction_wind_dir = wind_direction(prediction[:, idx_10u, :, :], prediction[:, idx_10v, :, :]) + if mean_pred is not None: + mean_wind_speed = np.hypot(mean_pred[idx_10u, :, :], mean_pred[idx_10v, :, :]) + mean_wind_dir = wind_direction(mean_pred[idx_10u, :, :], mean_pred[idx_10v, :, :]) plot_title_speed = f"{format_time_str(curr_time)}: FF10m" plot_title_dir = f"{format_time_str(curr_time)}: DD10m" @@ -237,6 +259,13 @@ def main(cfg: DictConfig) -> None: custom_path=files.wind_file("FF10m", curr_time, "FF10m-prediction", member_idx), plot_func=plot_map, title=plot_title_speed ) + if mean_pred is not None: + save_field( + "FF10m-mean-prediction", mean_wind_speed, wind_meta, files, None, curr_time, + cmap="viridis", vmin=0, vmax=10, extend='max', + custom_path=files.wind_file("FF10m", curr_time, "FF10m-mean-prediction"), + plot_func=plot_map, title=plot_title_speed + ) # Save wind direction plots save_field( @@ -258,6 +287,13 @@ def main(cfg: DictConfig) -> None: custom_path=files.wind_file("DD10m", curr_time, "DD10m-prediction", member_idx), plot_func=plot_map, title=plot_title_dir ) + if mean_pred is not None: + save_field( + "DD10m-mean-prediction", mean_wind_dir, dir_meta, files, None, curr_time, + cmap="twilight", vmin=0, vmax=360, + custom_path=files.wind_file("DD10m", curr_time, "DD10m-mean-prediction"), + plot_func=plot_map, title=plot_title_dir + ) logger.info("Image loading and plotting completed.") From 5a5412e0fd315b3c8926f9a117237efaddcad471 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 27 Aug 2025 08:21:05 +0200 Subject: [PATCH 143/189] one-off script to regrid copernicus data --- src/hirad/input_data/copernicus-tp.sh | 5 + src/hirad/input_data/regrid_copernicus_tp.py | 173 +++++++++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 src/hirad/input_data/copernicus-tp.sh create mode 100644 src/hirad/input_data/regrid_copernicus_tp.py diff --git a/src/hirad/input_data/copernicus-tp.sh b/src/hirad/input_data/copernicus-tp.sh new file mode 100644 index 0000000..943812f --- /dev/null +++ b/src/hirad/input_data/copernicus-tp.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +pip install -e . +pip install anemoi.datasets +python src/hirad/input_data/read_tp.py diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py new file mode 100644 index 0000000..a098c68 --- /dev/null +++ b/src/hirad/input_data/regrid_copernicus_tp.py @@ -0,0 +1,173 @@ +import logging +import netCDF4 +import xarray +import numpy as np +import torch +import datetime +from scipy.interpolate import griddata + +from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t +from hirad.eval.metrics import absolute_error + +import interpolate_basic + +import sys +from pathlib import Path + +import os +print (os.getcwd()) + +sys.path.insert(0, Path(__file__).parent.as_posix()) + +ANEMOI_1H_FILENAME = "/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr" +ANEMOI_6H_FILENAME = "/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr" +COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr" +COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr" +COSMO_CONFIG_FILE="src/input_data/cosmo.yaml" +CDF_FILENAME_BALFRIN = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" +CDF_FILENAME_CLARIDEN_TP = "/capstor/scratch/cscs/mmcgloho/datasets/copernicus/tp-2019-2020/data_stream-oper_stepType-accum.nc" +#CDF_FILENAME_CLARIDEN_INSTANT = "/capstor/store/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020-netcdf/data_stream-oper_stepType-instant.nc" +GRIB_FILENAME_BALFRIN = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" +GRIB_FILENAME_CLARIDEN = "/capstor/store/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" +COSMO_GRID_FILENAME = " /capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/info/cosmo-lat-lon" +INPUT_DATA_FILEPATH = "mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/" +BASE_FILEPATH = "/capstor/store/" +OUTPUT_DATA_FILEPATH = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp" +TP_INDEX = 12 + +LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) +LON = np.arange(-6.82, 4.80 + 0.02, 0.02) +RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) + +def extract_grib_values(grib_data): + grib_lat = grib_data['latitude'][:] + grib_lon = grib_data['longitude'][:] + grib_t2m = grib_data['t2m'][:] + +def extract_lat_lon(data): + logging.info('extracting lat/lon') + lat = data['latitude'][:] + lon = data['longitude'][:] + output_lat = np.zeros(len(lat)* len(lon)) + output_lon = np.zeros(len(lat) * len(lon)) + for i in range(len(lat)): + if i % 10 == 0: + print(i) + for j in range(len(lon)): + grid_index = i * len(lon) + j + output_lat[grid_index] = lat[i] + output_lon[grid_index] = lon[j] + return output_lat, output_lon + +def extract_values(data, variable): + values = data[variable][:] + return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) + +def reshape_to_cosmo(vals): + return vals.reshape((len(LAT)-RELAX_ZONE*2, len(LON)-RELAX_ZONE*2)) + +def calc_errors(cosmo1, era1): + make_plots = True + + prev_netcdf_regrid = [] + + netcdf_error = np.zeros(cosmo1.dates.shape) + era_norm_error = np.zeros(cosmo1.dates.shape) + netcdf_early_error = np.zeros(cosmo1.dates.shape) + netcdf_late_error = np.zeros(cosmo1.dates.shape) + + output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes)) + + for t in range(4): + #for t in range(len(cosmo1.dates)): + date = cosmo1.dates[t] + era_date = era1.dates[t] + if date != era_date: + logging.error('dates do not match: cosmo date: {date}, era date: {era_date}') + if date != netcdf_data['valid_time'][t]: + logging.error(f'dates do not match: cosmo date: {date}, netcdf: {netcdf_data["valid_time"][t]}') + + + # plot cosmo + if make_plots: + plot_map_precipitation(values=reshape_to_cosmo(cosmo1[t,:]), filename=f'plots/tp/{date}-cosmo1') + + # plot netcdf + netcdf_vals = netcdf_values[t,:].reshape((1,1,netcdf_values.shape[1])) + netcdf_regrid=interpolate_basic.regrid(netcdf_vals, netcdf_grid, output_grid) + if make_plots: + plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf-refactor') + + # plot era + era_grid = np.column_stack((era1.longitudes, era1.latitudes)) + era1_regrid = interpolate_basic.regrid(era1[t,:], era_grid, output_grid) + if make_plots: + plot_map_precipitation(reshape_to_cosmo(era1_regrid/6), f'plots/tp/{date}-era1-norm') + + #if t % 6 == 0: + # if era6.dates[t//6] != date: + # logging.error(f'dates do not match: era1: {date}, era6: {era6.dates[t//6]}') + # era6_regrid = interpolate_basic.regrid(era6[t//6,:], era_grid, output_grid) + # plot_map_precipitation(reshape_to_cosmo(era6_regrid), f'plots/tp/{date}-era6') + + era_norm_error[t] = np.mean(absolute_error(era1_regrid/6, cosmo1[t,:])) + netcdf_error[t] = np.mean(absolute_error(netcdf_regrid, cosmo1[t,:])) + logging.info(f'era norm error: {era_norm_error[t]} netcdf err: {netcdf_error[t]}') + if t>0: + netcdf_early_error[t] = np.mean(absolute_error(prev_netcdf_regrid, cosmo1[t,:])) + netcdf_late_error[t-1] = np.mean(absolute_error(netcdf_regrid, cosmo1[t-1,:])) + logging.info(f'netcdf early err: {netcdf_early_error[t]}, netcdf late err: {netcdf_late_error[t-1]}') + prev_netcdf_regrid = netcdf_regrid + + maes = {} + maes['era normalized'] = era_norm_error + maes['copernicus'] = netcdf_error + maes['copernicus-early'] = netcdf_early_error + maes['copernicus-late'] = netcdf_late_error + plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png') + + +root = logging.getLogger() +root.setLevel(logging.INFO) + +logging.info('loading data') +netcdf_data = netCDF4.Dataset(CDF_FILENAME_CLARIDEN_TP) + +#cosmo1 = open_dataset(COSMO_1H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') +#cosmo6 = open_dataset(COSMO_6H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') +#era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') +#era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') +logging.info('loading data complete') + +logging.info(os.listdir(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info'))) + +cosmo_grid = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info/cosmo-lat-lon'), weights_only=False) + + +logging.info('processing netcdf data') +netcdf_latitudes, netcdf_longitudes = extract_lat_lon(netcdf_data) +netcdf_tp_values = extract_values(netcdf_data, 'tp') +netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + +#for t in range(10): +for t in range(netcdf_tp_values.shape[0]): + netcdf_date = netcdf_data['valid_time'][t] + date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') + t1 = datetime.datetime.now() + era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era-interpolated", date_filename) + if os.path.exists(era_filename) and netcdf_date > 1560229200: + era_data = torch.load(era_filename, weights_only=False) + t2 = datetime.datetime.now() + logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') + interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear') + t3 = datetime.datetime.now() + era_data[TP_INDEX,0,:] = interpolated_tp + torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH, date_filename)) + t4 = datetime.datetime.now() + + + + + + + From f1d71017672b95b3aa8097d65fb4a91cce140cbe Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 27 Aug 2025 08:22:47 +0200 Subject: [PATCH 144/189] Add input_data dependencies to container --- ci/docker/Dockerfile.corrdiff | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci/docker/Dockerfile.corrdiff b/ci/docker/Dockerfile.corrdiff index 7908197..3187c84 100644 --- a/ci/docker/Dockerfile.corrdiff +++ b/ci/docker/Dockerfile.corrdiff @@ -6,6 +6,8 @@ RUN pip install --upgrade pip # Install the rest of dependencies. RUN pip install \ + anemoi.datasets \ Cartopy==0.22.0 \ xskillscore \ + scoringrules \ mlflow From 3642980b368c2178aef492207ea8e329a4d4b049 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 27 Aug 2025 16:33:30 +0200 Subject: [PATCH 145/189] add grid variable --- src/hirad/input_data/download_copernicus_tp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hirad/input_data/download_copernicus_tp.py b/src/hirad/input_data/download_copernicus_tp.py index 55031ee..6457eed 100644 --- a/src/hirad/input_data/download_copernicus_tp.py +++ b/src/hirad/input_data/download_copernicus_tp.py @@ -34,7 +34,8 @@ "21:00", "22:00", "23:00" ], "data_format": "netcdf", - "download_format": "unarchived" + "download_format": "unarchived", + "grid": "N320" } client = cdsapi.Client() From fe621ac9e4d4a7bfa61eef9441d26ef488ab347d Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 1 Sep 2025 10:38:20 +0200 Subject: [PATCH 146/189] more on processing copernicus data --- .../input_data/download_copernicus_tp.py | 11 +- src/hirad/input_data/regrid_copernicus_tp.py | 110 ++++++++++++------ 2 files changed, 84 insertions(+), 37 deletions(-) diff --git a/src/hirad/input_data/download_copernicus_tp.py b/src/hirad/input_data/download_copernicus_tp.py index 6457eed..6eee50c 100644 --- a/src/hirad/input_data/download_copernicus_tp.py +++ b/src/hirad/input_data/download_copernicus_tp.py @@ -5,10 +5,14 @@ "product_type": ["reanalysis"], "variable": ["total_precipitation"], "year": [ - "2016" + "2015", "2016", "2017", + "2018", "2019", "2020", ], "month": [ - "01", "02" + "01", "02", "03", + "04", "05", "06", + "07", "08", "09", + "10", "11", "12", ], "day": [ "01", "02", "03", @@ -35,7 +39,8 @@ ], "data_format": "netcdf", "download_format": "unarchived", - "grid": "N320" + "grid": "N320", + "area": [60, 0, 40, 20] } client = cdsapi.Client() diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py index a098c68..93a142f 100644 --- a/src/hirad/input_data/regrid_copernicus_tp.py +++ b/src/hirad/input_data/regrid_copernicus_tp.py @@ -19,20 +19,17 @@ sys.path.insert(0, Path(__file__).parent.as_posix()) -ANEMOI_1H_FILENAME = "/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr" -ANEMOI_6H_FILENAME = "/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr" -COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr" -COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr" -COSMO_CONFIG_FILE="src/input_data/cosmo.yaml" + CDF_FILENAME_BALFRIN = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" -CDF_FILENAME_CLARIDEN_TP = "/capstor/scratch/cscs/mmcgloho/datasets/copernicus/tp-2019-2020/data_stream-oper_stepType-accum.nc" -#CDF_FILENAME_CLARIDEN_INSTANT = "/capstor/store/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020-netcdf/data_stream-oper_stepType-instant.nc" -GRIB_FILENAME_BALFRIN = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" -GRIB_FILENAME_CLARIDEN = "/capstor/store/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" -COSMO_GRID_FILENAME = " /capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/info/cosmo-lat-lon" -INPUT_DATA_FILEPATH = "mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/" +#CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc" +#CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018-n320.nc" +CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc" + + BASE_FILEPATH = "/capstor/store/" -OUTPUT_DATA_FILEPATH = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp" +INPUT_DATA_FILEPATH = "mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/" +OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp" +OUTPUT_DATA_FILEPATH_ERA = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-with-copernicus-tp" TP_INDEX = 12 LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) @@ -44,10 +41,10 @@ def extract_grib_values(grib_data): grib_lon = grib_data['longitude'][:] grib_t2m = grib_data['t2m'][:] -def extract_lat_lon(data): +def extract_lat_lon_025(data): logging.info('extracting lat/lon') - lat = data['latitude'][:] - lon = data['longitude'][:] + lat = data['latitudes'][:] + lon = data['longitudes'][:] output_lat = np.zeros(len(lat)* len(lon)) output_lon = np.zeros(len(lat) * len(lon)) for i in range(len(lat)): @@ -59,8 +56,27 @@ def extract_lat_lon(data): output_lon[grid_index] = lon[j] return output_lat, output_lon -def extract_values(data, variable): +def extract_lat_lon_n320(data): + lat = data['latitudes'][:] + lon = data['longitudes'][:] + logging.info('extracting lat/lon') + logging.info(f'lat lon shapes {lat.shape} {lon.shape}') + +def extract_values(data, variable, area=None): values = data[variable][:] + print(values.shape) + if area: + lat = data['latitude'][:] + lon = data['longitude'][:] + # https://stackoverflow.com/questions/29135885/netcdf4-extract-for-subset-of-lat-lon + latli = np.argmin( np.abs(lat - area[2])) + latui = np.argmin( np.abs(lat - area[0])) + lonli = np.argmin( np.abs(lon - area[1])) + lonui = np.argmin( np.abs(lon - area[3])) + lat = data['latitude'][latli:latui] + lon = data['longitude'][lonli:lonui] + + values = data[variable][latli:latui,lonli:lonui] return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) def reshape_to_cosmo(vals): @@ -126,6 +142,37 @@ def calc_errors(cosmo1, era1): maes['copernicus-late'] = netcdf_late_error plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png') +def process_era_interpolated(netcdf_data, netcdf_tp_values): + #for t in range(10): + for t in range(netcdf_tp_values.shape[0]): + netcdf_date = netcdf_data['valid_time'][t] + date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') + t1 = datetime.datetime.now() + era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era-interpolated", date_filename) + if os.path.exists(era_filename): + era_data = torch.load(era_filename, weights_only=False) + t2 = datetime.datetime.now() + logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') + interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear') + t3 = datetime.datetime.now() + era_data[TP_INDEX,0,:] = interpolated_tp + torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED, date_filename)) + t4 = datetime.datetime.now() + +def process_era(netcdf_data, netcdf_tp_values): + for t in range(netcdf_tp_values.shape[0]): + netcdf_date = netcdf_data['valid_time'][t] + date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') + era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era", date_filename) + if os.path.exists(era_filename): + era_data = torch.load(era_filename, weights_only=False) + t2 = datetime.datetime.now() + logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') + interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], era_grid, method='linear') + t3 = datetime.datetime.now() + era_data[TP_INDEX,0,:] = interpolated_tp + torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA, date_filename)) + t4 = datetime.datetime.now() root = logging.getLogger() root.setLevel(logging.INFO) @@ -139,31 +186,26 @@ def calc_errors(cosmo1, era1): #era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') logging.info('loading data complete') -logging.info(os.listdir(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info'))) - cosmo_grid = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info/cosmo-lat-lon'), weights_only=False) +era_grid = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info/era-lat-lon'), weights_only=False) logging.info('processing netcdf data') -netcdf_latitudes, netcdf_longitudes = extract_lat_lon(netcdf_data) + +netcdf_latitudes, netcdf_longitudes = extract_lat_lon_n320(netcdf_data) netcdf_tp_values = extract_values(netcdf_data, 'tp') + +#netcdf_latitudes, netcdf_longitudes = extract_lat_lon(netcdf_data, area=[60, 0, 40, 20]) +#netcdf_tp_values = extract_values(netcdf_data, 'tp', area=[60, 0, 40, 20]) netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) +logging.info(f'netcdf shape {netcdf_grid.shape}') +logging.info(f'{netcdf_grid[1:10,:]}') +logging.info(f'era shape {era_grid.shape}') +logging.info(f'{era_grid[1:10,:]}') + +process_era_interpolated(netcdf_data, netcdf_tp_values) +process_era(netcdf_data, netcdf_tp_values) -#for t in range(10): -for t in range(netcdf_tp_values.shape[0]): - netcdf_date = netcdf_data['valid_time'][t] - date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') - t1 = datetime.datetime.now() - era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era-interpolated", date_filename) - if os.path.exists(era_filename) and netcdf_date > 1560229200: - era_data = torch.load(era_filename, weights_only=False) - t2 = datetime.datetime.now() - logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') - interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear') - t3 = datetime.datetime.now() - era_data[TP_INDEX,0,:] = interpolated_tp - torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH, date_filename)) - t4 = datetime.datetime.now() From 4aecb65bb54f7d03a4a3ce5bd478b5cea136bd90 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 1 Sep 2025 10:38:33 +0200 Subject: [PATCH 147/189] removing old file --- src/hirad/input_data/read_tp.py | 184 -------------------------------- 1 file changed, 184 deletions(-) delete mode 100644 src/hirad/input_data/read_tp.py diff --git a/src/hirad/input_data/read_tp.py b/src/hirad/input_data/read_tp.py deleted file mode 100644 index 9953086..0000000 --- a/src/hirad/input_data/read_tp.py +++ /dev/null @@ -1,184 +0,0 @@ -import logging -import netCDF4 -import xarray -from anemoi.datasets import open_dataset -import numpy as np -import yaml - -import matplotlib.pyplot as plt -import cartopy.crs as ccrs -import cartopy.feature as cfeature -from matplotlib.colors import BoundaryNorm, ListedColormap - -from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t -from hirad.eval.metrics import compute_mae, absolute_error - - -import interpolate_basic - -import sys -from pathlib import Path - -import os -print (os.getcwd()) - -sys.path.insert(0, Path(__file__).parent.as_posix()) - -ANEMOI_1H_FILENAME = "/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr" -ANEMOI_6H_FILENAME = "/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr" -COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr" -COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr" -COSMO_CONFIG_FILE="src/input_data/cosmo.yaml" -CDF_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc" -GRIB_FILENAME = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/surface-janfeb2020.grib" - - -LAT = np.arange(-4.42, 3.36 + 0.02, 0.02) -LON = np.arange(-6.82, 4.80 + 0.02, 0.02) -RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone) - -def extract_grib_values(grib_data): - grib_lat = grib_data['latitude'][:] - grib_lon = grib_data['longitude'][:] - grib_t2m = grib_data['t2m'][:] - -def extract_lat_lon(data): - lat = data['latitude'][:] - lon = data['longitude'][:] - output_lat = np.zeros(len(lat)* len(lon)) - output_lon = np.zeros(len(lat) * len(lon)) - for i in range(len(lat)): - if i % 10 == 0: - print(i) - for j in range(len(lon)): - grid_index = i * len(lon) + j - output_lat[grid_index] = lat[i] - output_lon[grid_index] = lon[j] - return output_lat, output_lon - -def extract_values(data, variable): - values = data[variable][:] - return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) - -def extract_netcdf_values(netcdf_data): - netcdf_lat = netcdf_data['latitude'][:] - netcdf_lon = netcdf_data['longitude'][:] - netcdf_tp = netcdf_data['tp'][:,:] - logging.info(f'num nonzeros in tp is {np.count_nonzero(netcdf_tp)}') - values = np.zeros((netcdf_tp.shape[0], netcdf_tp.shape[1]*netcdf_tp.shape[2])) - latitudes = np.zeros(values.shape[1]) - longitudes = np.zeros(values.shape[1]) - # You could probably get this by reshaping, but I can't be bothered. - for i in range(len(netcdf_lat)): - if i % 10 == 0: - print(i) - for j in range(len(netcdf_lon)): - grid_index = i * len(netcdf_lon) + j - values[:,grid_index] = netcdf_tp[:,i,j] - latitudes[grid_index] = netcdf_lat[i] - longitudes[grid_index] = netcdf_lon[j] - reshape_values = np.reshape(netcdf_tp, (netcdf_tp.shape[0], netcdf_tp.shape[1]*netcdf_tp.shape[2])) - logging.info(f'Array equal? {np.array_equal(values, reshape_values)}') - return reshape_values, latitudes, longitudes - -def reshape_to_cosmo(vals): - return vals.reshape((len(LAT)-RELAX_ZONE*2, len(LON)-RELAX_ZONE*2)) - -root = logging.getLogger() -root.setLevel(logging.INFO) - -logging.info('loading data') -grib_data = xarray.load_dataset(GRIB_FILENAME) -netcdf_data = netCDF4.Dataset(CDF_FILENAME) -cosmo1 = open_dataset(COSMO_1H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') -cosmo6 = open_dataset(COSMO_6H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') -era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -logging.info('loading data complete') - -output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes)) - -logging.info('processing netcdf data') -netcdf_latitudes, netcdf_longitudes = extract_lat_lon(netcdf_data) -netcdf_values = extract_values(netcdf_data, 'tp') -#netcdf_values, netcdf_latitudes, netcdf_longitudes = extract_netcdf_values(netcdf_data=netcdf_data) -netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) - -make_plots = True - -prev_netcdf_regrid = [] - -netcdf_error = np.zeros(cosmo1.dates.shape) -era_norm_error = np.zeros(cosmo1.dates.shape) -netcdf_early_error = np.zeros(cosmo1.dates.shape) -netcdf_late_error = np.zeros(cosmo1.dates.shape) - -for t in range(4): -#for t in range(len(cosmo1.dates)): - date = cosmo1.dates[t] - era_date = era1.dates[t] - if date != era_date: - logging.error('dates do not match: cosmo date: {date}, era date: {era_date}') - if date != netcdf_data['valid_time'][t]: - logging.error(f'dates do not match: cosmo date: {date}, netcdf: {netcdf_data["valid_time"][t]}') - - - # plot cosmo - if make_plots: - plot_map_precipitation(values=reshape_to_cosmo(cosmo1[t,:]), filename=f'plots/tp/{date}-cosmo1') - if t % 6 == 0: - if cosmo6.dates[t//6] != date: - logging.error(f'dates do not match: cosmo1: {date}, cosmo6: {cosmo6.dates[t//6]}') - plot_map_precipitation(values=reshape_to_cosmo(cosmo6[t//6,:]), filename=f'plots/tp/{date}-cosmo6') - - # plot netcdf - netcdf_vals = netcdf_values[t,:].reshape((1,1,netcdf_values.shape[1])) - netcdf_regrid=interpolate_basic.regrid(netcdf_vals, netcdf_grid, output_grid) - if make_plots: - plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf-refactor') - - # plot era - era_grid = np.column_stack((era1.longitudes, era1.latitudes)) - era1_regrid = interpolate_basic.regrid(era1[t,:], era_grid, output_grid) - if make_plots: - plot_map_precipitation(reshape_to_cosmo(era1_regrid/6), f'plots/tp/{date}-era1-norm') - - #if t % 6 == 0: - # if era6.dates[t//6] != date: - # logging.error(f'dates do not match: era1: {date}, era6: {era6.dates[t//6]}') - # era6_regrid = interpolate_basic.regrid(era6[t//6,:], era_grid, output_grid) - # plot_map_precipitation(reshape_to_cosmo(era6_regrid), f'plots/tp/{date}-era6') - - era_norm_error[t] = np.mean(absolute_error(era1_regrid/6, cosmo1[t,:])) - netcdf_error[t] = np.mean(absolute_error(netcdf_regrid, cosmo1[t,:])) - logging.info(f'era norm error: {era_norm_error[t]} netcdf err: {netcdf_error[t]}') - if t>0: - netcdf_early_error[t] = np.mean(absolute_error(prev_netcdf_regrid, cosmo1[t,:])) - netcdf_late_error[t-1] = np.mean(absolute_error(netcdf_regrid, cosmo1[t-1,:])) - logging.info(f'netcdf early err: {netcdf_early_error[t]}, netcdf late err: {netcdf_late_error[t-1]}') - prev_netcdf_regrid = netcdf_regrid - -maes = {} -maes['era normalized'] = era_norm_error -maes['copernicus'] = netcdf_error -maes['copernicus-early'] = netcdf_early_error -maes['copernicus-late'] = netcdf_late_error -plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png') - -#fig = plt.figure() -#fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -#interpolate_basic.plot_projection(ax, longitudes=cosmo1.longitudes, latitudes=cosmo1.latitudes, values=cosmo1[0,:]) -#fig.savefig('cosmo1.png') - -#fig = plt.figure() -#fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -#interpolate_basic.plot_projection(ax, longitudes=cosmo1.longitudes, latitudes=cosmo1.latitudes, values=cosmo6[0,:]) -#fig.savefig('cosmo6.png') - - - - - - - - From 5d12c4ca859d1d2a45bb9be7509692a0bdaeb806 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 1 Sep 2025 12:25:31 +0200 Subject: [PATCH 148/189] Log config into training script logs --- src/hirad/training/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index e865bca..c1ca8b2 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -91,6 +91,7 @@ def main(cfg: DictConfig) -> None: fp16 = fp_optimizations == "fp16" enable_amp = fp_optimizations.startswith("amp") amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger0.info(f"Config is: {cfg}") logger0.info(f"Saving the outputs in {os.getcwd()}") checkpoint_dir = os.path.join( cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" From 0a2880118ae6b0681e720a87e5002386e76fa5e0 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 9 Sep 2025 13:51:59 +0200 Subject: [PATCH 149/189] Make InfiniteSampler start from where it left off when resuming training and not restart from beginning. --- src/hirad/datasets/dataset.py | 7 +++++-- src/hirad/training/train.py | 6 ++++-- src/hirad/utils/function_utils.py | 6 +++++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 1d36bc2..924951e 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -37,6 +37,7 @@ def init_train_valid_datasets_from_config( batch_size: int = 1, seed: int = 0, train_test_split: bool = True, + sampler_start_idx: int = 0, ) -> Tuple[ DownscalingDataset, Iterable, @@ -52,6 +53,7 @@ def init_train_valid_datasets_from_config( - batch_size (int): The number of samples in each batch of data. Defaults to 1. - seed (int): The random seed for dataset shuffling. Defaults to 0. - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True. + - sampler_start_idx (int): The initial index of the sampler to use for resuming training. Defaults to 0. Returns: - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True. @@ -61,7 +63,7 @@ def init_train_valid_datasets_from_config( if 'validation_path' in config: del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( - config, dataloader_cfg, batch_size=batch_size, seed=seed + config, dataloader_cfg, batch_size=batch_size, seed=seed, sampler_start_idx=sampler_start_idx, ) if train_test_split: valid_dataset_cfg = copy.deepcopy(dataset_cfg) @@ -81,6 +83,7 @@ def init_dataset_from_config( dataloader_cfg: Union[dict, None] = None, batch_size: int = 1, seed: int = 0, + sampler_start_idx: int = 0, ) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) dataset_type = dataset_cfg.pop("type", "era5_cosmo") @@ -97,7 +100,7 @@ def init_dataset_from_config( dist = DistributedManager() dataset_sampler = InfiniteSampler( - dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed + dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed, start_idx=sampler_start_idx, ) dataset_iterator = iter( diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index e865bca..b6a8ca9 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -117,8 +117,9 @@ def main(cfg: DictConfig) -> None: cfg.training.hp.batch_size_per_gpu * dist.world_size ) + cur_nimg = load_checkpoint(path=checkpoint_dir) - set_seed(dist.rank) + set_seed(dist.rank + cur_nimg) configure_cuda_for_consistent_precision() # Instantiate the dataset @@ -138,8 +139,10 @@ def main(cfg: DictConfig) -> None: batch_size=cfg.training.hp.batch_size_per_gpu, seed=0, train_test_split=train_test_split, + sampler_start_idx=cur_nimg, ) logger0.info(f"Training on dataset with size {len(dataset)}") + logger0.info(f"Validating on dataset with size {len(validation_dataset)}") # Parse image configuration & update model args dataset_channels = len(dataset.input_channels()) @@ -172,7 +175,6 @@ def main(cfg: DictConfig) -> None: ) # Parse the patch shape - #TODO figure out patched diffusion and how to use it if ( cfg.model.name == "patched_diffusion" or cfg.model.name == "lt_aware_patched_diffusion" diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py index 2da05db..6534782 100644 --- a/src/hirad/utils/function_utils.py +++ b/src/hirad/utils/function_utils.py @@ -133,6 +133,8 @@ class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover window_size : float, default=0.5 Fraction of dataset to use as window for shuffling. Must be between 0 and 1. A larger window means more thorough shuffling but slower iteration. + start_idx : int, default=0 + The initial index to use for the sampler. This is used for resuming training. """ def __init__( @@ -143,6 +145,7 @@ def __init__( shuffle: bool = True, seed: int = 0, window_size: float = 0.5, + start_idx: int = 0, ): if not len(dataset) > 0: raise ValueError("Dataset must contain at least one item") @@ -159,6 +162,7 @@ def __init__( self.shuffle = shuffle self.seed = seed self.window_size = window_size + self.start_idx = start_idx def __iter__(self) -> Iterator[int]: order = np.arange(len(self.dataset)) @@ -169,7 +173,7 @@ def __iter__(self) -> Iterator[int]: rnd.shuffle(order) window = int(np.rint(order.size * self.window_size)) - idx = 0 + idx = self.start_idx while True: i = idx % order.size if idx % self.num_replicas == self.rank: From 19cc78b2498d063d01e0a44e3202fa1bfe5d504e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 10 Sep 2025 09:07:53 +0200 Subject: [PATCH 150/189] Move model compilation after checkpoint loading and fix checkpoint to strip off optimization layer. --- src/hirad/training/train.py | 18 ++++++------ src/hirad/utils/checkpoint.py | 53 +++++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index b6a8ca9..0922717 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -284,8 +284,8 @@ def main(cfg: DictConfig) -> None: # Enable distributed data parallel if applicable if dist.world_size > 1: - if use_torch_compile: - model = torch.compile(model) + # if use_torch_compile: + # model = torch.compile(model) model = DistributedDataParallel( model, device_ids=[dist.local_rank], @@ -334,13 +334,6 @@ def main(cfg: DictConfig) -> None: else: regression_net = None - # Compile the model and regression net if applicable - if use_torch_compile: - if dist.world_size==1: - model = torch.compile(model) - if regression_net: - regression_net = torch.compile(regression_net) - # Compute the number of required gradient accumulation rounds # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size @@ -436,6 +429,13 @@ def main(cfg: DictConfig) -> None: except: cur_nimg = 0 + # Compile the model and regression net if applicable + if use_torch_compile: + if dist.world_size==1: + model = torch.compile(model) + if regression_net: + regression_net = torch.compile(regression_net) + # init the generator for inference to visualize checkpoint results if visualize_checkpoints: generator = Generator( diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 070f8c3..28e6547 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -174,10 +174,15 @@ def save_checkpoint( Path(path).mkdir(parents=True, exist_ok=True) # == Saving model checkpoint == - if model: + if model is not None: if hasattr(model, "module"): # Strip out DDP layer model = model.module + + # Strip out optimization wrapper if exists + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + model = model._orig_mod + # Base name of model is meta.name unless pytorch model name = model.__class__.__name__ # Get full file path / name @@ -223,7 +228,7 @@ def save_checkpoint( def load_checkpoint( path: str, - model: torch.nn.Module, + model: torch.nn.Module = None, optimizer: Union[optimizer, None] = None, scheduler: Union[scheduler, None] = None, scaler: Union[scaler, None] = None, @@ -268,27 +273,33 @@ def load_checkpoint( ) return 0 - # == Loading model checkpoint == - if hasattr(model, "module"): - # Strip out DDP layer - model = model.module - # Base name of model is meta.name unless pytorch model - name = model.__class__.__name__ - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, - ) - if not Path(file_name).exists(): - checkpoint_logging.warning( - f"Could not find valid model file {file_name}, skipping load" + if model is not None: + # == Loading model checkpoint == + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Strip out optimization wrapper if exists + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + model = model._orig_mod + checkpoint_logging.warning( + f"Model {model.__class__.__name__} is already compiled, consider loading first and then compiling." + ) + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, ) - else: - # Load state dictionary - model.load_state_dict(torch.load(file_name, map_location=device)) + if not Path(file_name).exists(): + checkpoint_logging.warning( + f"Could not find valid model file {file_name}, skipping load" + ) + else: + # Load state dictionary + model.load_state_dict(torch.load(file_name, map_location=device)) - checkpoint_logging.success( - f"Loaded model state dictionary {file_name} to device {device}" - ) + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) # == Loading training checkpoint == checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") From 1d4865fdcd343cf2c45ac10de0fb2f1e74488124 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 10 Sep 2025 15:08:51 +0200 Subject: [PATCH 151/189] Refactor image_batching and image_fuse to handle input tensor dtype conversion and remove _cast_type function --- src/hirad/utils/patching.py | 60 ++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/src/hirad/utils/patching.py b/src/hirad/utils/patching.py index 6f4bc4d..bd537af 100644 --- a/src/hirad/utils/patching.py +++ b/src/hirad/utils/patching.py @@ -591,14 +591,26 @@ def image_batching( ) # (padding_left,padding_right,padding_top,padding_bottom) input_padded = image_padding(input) patch_num = patch_num_x * patch_num_y + + # Cast to float for unfold + if input.dtype == torch.int32: + input_padded = input_padded.view(torch.float32) + elif input.dtype == torch.int64: + input_padded = input_padded.view(torch.float64) + x_unfold = torch.nn.functional.unfold( - input=input_padded.view(_cast_type(input_padded)), # Cast to float + input=input_padded, kernel_size=(patch_shape_y, patch_shape_x), stride=( patch_shape_y - overlap_pix - boundary_pix, patch_shape_x - overlap_pix - boundary_pix, ), - ).to(input_padded.dtype) + ) + + # Cast back to original dtype + if input.dtype in [torch.int32, torch.int64]: + x_unfold = x_unfold.view(input.dtype) + x_unfold = rearrange( x_unfold, "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", @@ -608,16 +620,7 @@ def image_batching( nb_p_w=patch_num_x, ) if input_interp is not None: - input_interp_repeated = rearrange( - torch.repeat_interleave( - input=input_interp, - repeats=patch_num, - dim=0, - output_size=x_unfold.shape[0], - ), - "(b p) c h w -> (p b) c h w", - p=patch_num, - ) + input_interp_repeated = input_interp.repeat(patch_num, 1, 1, 1) return torch.cat((x_unfold, input_interp_repeated), dim=1) else: return x_unfold @@ -722,6 +725,13 @@ def image_fuse( nb_p_h=patch_num_y, nb_p_w=patch_num_x, ) + + # Cast to float for fold + if input.dtype == torch.int32: + x = x.view(torch.float32) + elif input.dtype == torch.int64: + x = x.view(torch.float64) + # Stitch patches together (by summing over overlapping patches) x_folded = torch.nn.functional.fold( input=x, @@ -733,6 +743,10 @@ def image_fuse( ), ) + # Cast back to original dtype + if input.dtype in [torch.int32, torch.int64]: + x_folded = x_folded.view(input.dtype) + # Remove padding x_no_padding = x_folded[ ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x @@ -743,25 +757,3 @@ def image_fuse( # Normalize by overlap count return x_no_padding / overlap_count_no_padding - - -def _cast_type(input: Tensor) -> torch.dtype: - """Return float type based on input tensor type. - - Parameters - ---------- - input : Tensor - Input tensor to determine float type from - - Returns - ------- - torch.dtype - Float type corresponding to input tensor type for int32/64, - otherwise returns original dtype - """ - if input.dtype == torch.int32: - return torch.float32 - elif input.dtype == torch.int64: - return torch.float64 - else: - return input.dtype From 4b574e8a506b5f564e244e893158b2150c82fc07 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 10 Sep 2025 15:23:21 +0200 Subject: [PATCH 152/189] Validate max_patch_per_gpu against batch_size_per_gpu --- src/hirad/training/train.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 0922717..84941b4 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -345,14 +345,17 @@ def main(cfg: DictConfig) -> None: batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - patch_num = getattr(cfg.training.hp, "patch_num", 1) - max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) - # calculate patch per iter - if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + patch_num = getattr(cfg.training.hp, "patch_num", 1) + if hasattr(cfg.training.hp, "max_patch_per_gpu"): + max_patch_per_gpu = cfg.training.hp.max_patch_per_gpu + if max_patch_per_gpu // batch_size_per_gpu < 1: + raise ValueError( + f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})." + ) max_patch_num_per_iter = min( patch_num, (max_patch_per_gpu // batch_size_per_gpu) - ) # Ensure at least 1 patch per iter + ) patch_iterations = ( patch_num + max_patch_num_per_iter - 1 ) // max_patch_num_per_iter @@ -360,7 +363,7 @@ def main(cfg: DictConfig) -> None: min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) for i in range(patch_iterations) ] - print( + logger0.info( f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" ) else: From 40d1792a1cee1e8ef3253ce802b699648452142a Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 16 Sep 2025 12:21:55 +0200 Subject: [PATCH 153/189] Add method to load static variables --- interpolate.sh | 2 +- src/hirad/input_data/interpolate_basic.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/interpolate.sh b/interpolate.sh index 607ad66..73e947d 100755 --- a/interpolate.sh +++ b/interpolate.sh @@ -3,4 +3,4 @@ #SBATCH --partition=postproc #SBATCH --time=23:59:00 -python src/input_data/interpolate_basic.py src/input_data/era-all.yaml src/input_data/cosmo-all.yaml /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/ +python src/hirad/input_data/interpolate_basic.py src/hirad/input_data/era-all.yaml src/hirad/input_data/cosmo-all.yaml /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/ diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py index 7ec5801..e5953a1 100644 --- a/src/hirad/input_data/interpolate_basic.py +++ b/src/hirad/input_data/interpolate_basic.py @@ -14,6 +14,7 @@ from scipy.interpolate import griddata import torch import multiprocessing +import xarray # Margin to use for ERA dataset (to avoid nans from interpolation at boundary) ERA_MARGIN_DEGREES = 1.0 @@ -47,7 +48,6 @@ def _read_input(era_config_file: str, cosmo_config_file: str, bound_to_cosmo_are return (era, cosmo) - def regrid(era_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.ndarray): # shape (channel, ensemble, grid) interpolated_data = np.empty([era_for_time.shape[0], 1, output_grid.shape[0]]) @@ -234,6 +234,12 @@ def plot_tp(path_6h: str, path_1h: str): plt.savefig(filename) plt.close('all') +def load_static(infile_era: str, infile_cosmo: str, output_directory: str): + _, cosmo = _read_input(infile_era, infile_cosmo, bound_to_cosmo_area=True) + + torch.save(cosmo[0,:,:,:], os.path.join(output_directory, 'cosmo-static')) + shutil.copy(infile_cosmo, os.path.join(output_directory, "cosmo-static.yaml")) + def main(): # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs. if len(sys.argv) < 4: @@ -242,13 +248,15 @@ def main(): infile_cosmo = sys.argv[2] output_directory = sys.argv[3] - logging.basicConfig( filename=os.path.join(output_directory, 'interpolate_basic.log'), format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') - interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=os.path.join(output_directory, "plots/")) + + #load_static(infile_era, infile_cosmo, output_directory) + #interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=os.path.join(output_directory, "plots/")) + interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=None) if __name__ == "__main__": main() From 4d6b9cbf189da3ddd5d5a86c402d575962ee708e Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 16 Sep 2025 12:22:04 +0200 Subject: [PATCH 154/189] Add static variable config --- src/hirad/input_data/cosmo-static.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 src/hirad/input_data/cosmo-static.yaml diff --git a/src/hirad/input_data/cosmo-static.yaml b/src/hirad/input_data/cosmo-static.yaml new file mode 100644 index 0000000..9bece97 --- /dev/null +++ b/src/hirad/input_data/cosmo-static.yaml @@ -0,0 +1,8 @@ +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr' +select: ['hsurf', 'lsm',] +# Static cosmo channels +trim_edge: 19 # Removes boundary +start: 2016-01-01 +# start: 2015-11-29 +end: 2016-01-01 +# end: 2020-12-31 \ No newline at end of file From 3ed02fa342f2a574106496f2243394bf5c004ae8 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 18 Sep 2025 12:41:45 +0200 Subject: [PATCH 155/189] Add a script to check input data for missing/corrupt/nan data --- src/hirad/input_data/regrid_copernicus_tp.py | 149 ++++++++++++------- src/hirad/input_data/regrid_copernicus_tp.sh | 9 ++ src/hirad/input_data/test_input_data.py | 88 +++++++++++ 3 files changed, 195 insertions(+), 51 deletions(-) create mode 100755 src/hirad/input_data/regrid_copernicus_tp.sh create mode 100644 src/hirad/input_data/test_input_data.py diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py index 93a142f..7e3f14f 100644 --- a/src/hirad/input_data/regrid_copernicus_tp.py +++ b/src/hirad/input_data/regrid_copernicus_tp.py @@ -43,8 +43,8 @@ def extract_grib_values(grib_data): def extract_lat_lon_025(data): logging.info('extracting lat/lon') - lat = data['latitudes'][:] - lon = data['longitudes'][:] + lat = data['latitude'][:] + lon = data['longitude'][:] output_lat = np.zeros(len(lat)* len(lon)) output_lon = np.zeros(len(lat) * len(lon)) for i in range(len(lat)): @@ -75,8 +75,7 @@ def extract_values(data, variable, area=None): lonui = np.argmin( np.abs(lon - area[3])) lat = data['latitude'][latli:latui] lon = data['longitude'][lonli:lonui] - - values = data[variable][latli:latui,lonli:lonui] + values = data[variable][latli:latui,lonli:lonui] return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) def reshape_to_cosmo(vals): @@ -142,22 +141,48 @@ def calc_errors(cosmo1, era1): maes['copernicus-late'] = netcdf_late_error plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png') -def process_era_interpolated(netcdf_data, netcdf_tp_values): - #for t in range(10): +def process_era_interpolated(netcdf_data, netcdf_tp_values, input_data_filepath, output_interpolated_filepath, netcdf_grid, cosmo_grid): + make_plots = False + #for t in range(100): for t in range(netcdf_tp_values.shape[0]): netcdf_date = netcdf_data['valid_time'][t] date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M') t1 = datetime.datetime.now() - era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era-interpolated", date_filename) + era_filename = os.path.join(input_data_filepath, "era-interpolated", date_filename) + output_filename = os.path.join(output_interpolated_filepath, date_filename) if os.path.exists(era_filename): - era_data = torch.load(era_filename, weights_only=False) - t2 = datetime.datetime.now() - logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') - interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear') - t3 = datetime.datetime.now() - era_data[TP_INDEX,0,:] = interpolated_tp - torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED, date_filename)) - t4 = datetime.datetime.now() + requires_processing = False + if date_filename == '20200615-2100': + requires_processing = True + #if os.path.exists(output_filename): + # test the output to make sure it is not corrupted. + #requires_processing = False + #try: + # torch.load(output_filename, weights_only=False) + #except: + # requires_processing = True + if requires_processing: + era_data = torch.load(era_filename, weights_only=False) + t2 = datetime.datetime.now() + #if t % 100 == 0: + logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})') + interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear') + t3 = datetime.datetime.now() + if make_plots and max > 0.0002: + nans = np.count_nonzero(np.isnan(interpolated_tp)) + nonzeros = np.count_nonzero(interpolated_tp) + max = np.max(interpolated_tp) + logging.info(f'nonzeros: {nonzeros} nans: {nans} max: {max}') + if max > 0.0002: + cosmo_filename = os.path.join(input_data_filepath, "cosmo", date_filename) + cosmo_data = torch.load(cosmo_filename, weights_only=False) + plot_map_precipitation(reshape_to_cosmo(interpolated_tp), f'plots/tp/{date_filename}-netcdf-regrid') + # 3 is tp index in cosmo data + plot_map_precipitation(reshape_to_cosmo(cosmo_data[3,:]), f'plots/tp/{date_filename}-cosmo') + plot_map_precipitation(reshape_to_cosmo(era_data[TP_INDEX,0,:]/6), f'plots/tp/{date_filename}-era-norm') + era_data[TP_INDEX,0,:] = interpolated_tp + torch.save(era_data, output_filename) + t4 = datetime.datetime.now() def process_era(netcdf_data, netcdf_tp_values): for t in range(netcdf_tp_values.shape[0]): @@ -174,42 +199,64 @@ def process_era(netcdf_data, netcdf_tp_values): torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA, date_filename)) t4 = datetime.datetime.now() -root = logging.getLogger() -root.setLevel(logging.INFO) - -logging.info('loading data') -netcdf_data = netCDF4.Dataset(CDF_FILENAME_CLARIDEN_TP) - -#cosmo1 = open_dataset(COSMO_1H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') -#cosmo6 = open_dataset(COSMO_6H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29') -#era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -#era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29') -logging.info('loading data complete') - -cosmo_grid = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info/cosmo-lat-lon'), weights_only=False) -era_grid = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info/era-lat-lon'), weights_only=False) - - -logging.info('processing netcdf data') - -netcdf_latitudes, netcdf_longitudes = extract_lat_lon_n320(netcdf_data) -netcdf_tp_values = extract_values(netcdf_data, 'tp') - -#netcdf_latitudes, netcdf_longitudes = extract_lat_lon(netcdf_data, area=[60, 0, 40, 20]) -#netcdf_tp_values = extract_values(netcdf_data, 'tp', area=[60, 0, 40, 20]) -netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) -logging.info(f'netcdf shape {netcdf_grid.shape}') -logging.info(f'{netcdf_grid[1:10,:]}') -logging.info(f'era shape {era_grid.shape}') -logging.info(f'{era_grid[1:10,:]}') - -process_era_interpolated(netcdf_data, netcdf_tp_values) -process_era(netcdf_data, netcdf_tp_values) - - - - - +def make_stats(): + #cosmo_files = os.listdir(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'cosmo')) + #era_files = os.listdir(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'cosmo')) + + stats = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info', 'era-stats'), weights_only=False) + print(stats) + set1 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc") + set2 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018.nc") + set3 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc") + set1_tp = extract_values(set1, 'tp') + set2_tp = extract_values(set2, 'tp') + set3_tp = extract_values(set3, 'tp') + all_tp = np.row_stack((set1_tp, set2_tp, set3_tp)) + print(all_tp.shape) + all_tp = all_tp.reshape(all_tp.shape[0] * all_tp.shape[1], 1) + mean = np.mean(all_tp) + max = np.max(all_tp) + min = np.min(all_tp) + stdev = np.std(all_tp) + stats['mean'][TP_INDEX] = mean + stats['maximum'][TP_INDEX] = max + stats['minimum'][TP_INDEX] = min + stats['stdev'][TP_INDEX] = stdev + print(stats) + torch.save(stats, os.path.join(OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED, 'era-stats')) + + +#process_era(netcdf_data, netcdf_tp_values) + + +def main(): + root = logging.getLogger() + root.setLevel(logging.INFO) + + logging.info('loading data') + netcdf_file = sys.argv[1] + input_data_filepath = sys.argv[2] + output_interpolated_filepath = sys.argv[3] + + netcdf_data = netCDF4.Dataset(netcdf_file) + logging.info(netcdf_data) + + logging.info('processing netcdf data') + netcdf_latitudes, netcdf_longitudes = extract_lat_lon_025(netcdf_data) + netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + cosmo_grid = torch.load(os.path.join(input_data_filepath, 'info/cosmo-lat-lon'), weights_only=False) + cosmo_grid = np.column_stack((cosmo_grid[:,1], cosmo_grid[:,0])) + logging.info(f'netcdf grid shape {netcdf_grid.shape}') + logging.info(f'{netcdf_grid[1:10,:]}') + logging.info(f'cosmo grid shape {cosmo_grid.shape}') + logging.info(f'{cosmo_grid[1:10,:]}') + + netcdf_tp_values = extract_values(netcdf_data, 'tp') + + process_era_interpolated(netcdf_data, netcdf_tp_values, input_data_filepath, output_interpolated_filepath, netcdf_grid, cosmo_grid) + +if __name__ == "__main__": + main() diff --git a/src/hirad/input_data/regrid_copernicus_tp.sh b/src/hirad/input_data/regrid_copernicus_tp.sh new file mode 100755 index 0000000..f7fc550 --- /dev/null +++ b/src/hirad/input_data/regrid_copernicus_tp.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +pip install -e . +pip install anemoi.datasets + +python src/hirad/input_data/regrid_copernicus_tp.py \ + /capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc \ + /capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/ \ + /capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp/ \ No newline at end of file diff --git a/src/hirad/input_data/test_input_data.py b/src/hirad/input_data/test_input_data.py new file mode 100644 index 0000000..fa13491 --- /dev/null +++ b/src/hirad/input_data/test_input_data.py @@ -0,0 +1,88 @@ +import logging +import os +import sys + +import datetime +import torch +import numpy as np + + +from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t + +def load_all_data(filepath: str): + files = os.listdir(filepath) + example = torch.load(os.path.join(filepath, files[0]), weights_only=False) + dims = (len(files),) + example.shape + data = np.zeros(dims) + for f in range(100): + #for f in range(len(files)): + if f % 100 == 0: + logging.info(f) + curr = torch.load(os.path.join(filepath, files[f]), weights_only=False) + data[f,:] = curr + return data + +def count_nans(data: np.array): + nans = np.count_nonzero(np.isnan(data)) + return nans + +def make_stats(filepath: str): + data = load_all_data(filepath) + stats = {} + num_channels = data.shape[1] + stats['mean'] = np.zeros(num_channels) + stats['stdev'] = np.zeros(num_channels) + stats['minimum'] = np.zeros(num_channels) + stats['maximum'] = np.zeros(num_channels) + for k in range(num_channels): + logging.info(f'channel {k}') + stats['mean'][k] = np.mean(data[:,k,:,:]) + stats['minimum'][k] = np.min(data[:,k,:,:]) + stats['maximum'][k] = np.max(data[:,k,:,:]) + stats['stdev'][k] = np.std(data[:,k,:,:]) + return stats + +def main(): + root = logging.getLogger() + root.setLevel(logging.INFO) + logging.info('starting main') + input_directory = sys.argv[1] + + missing_data = [] + corrupt_data = [] + nan_data = [] + check_for_nans = False + + + + files = os.listdir(input_directory) + files.sort() + start_date = datetime.datetime.strptime(files[0],'%Y%m%d-%H%M') + next_date = datetime.datetime.strptime(files[1],'%Y%m%d-%H%M') + delta = next_date - start_date + prev_date = start_date - delta + + for f in files: + curr_date = datetime.datetime.strptime(f,'%Y%m%d-%H%M') + if curr_date - prev_date != delta: + logging.info(f'missing data: {prev_date} and {curr_date} not {delta} apart') + expected_date = prev_date + delta + while (expected_date < curr_date): + missing_data.append(datetime.datetime.strftime(expected_date, '%Y%m%d-%H%M')) + expected_date = expected_date + delta + try: + data = torch.load(os.path.join(input_directory, f), weights_only=False) + except: + logging.info(f'corrupt data: {curr_date}') + corrupt_data.append(curr_date) + if check_for_nans: + if count_nans(data): + logging.info(f'data nans: {curr_date}') + nan_data.append(curr_date) + prev_date = curr_date + logging.info(f'missing data: {missing_data}') + logging.info(f'corrupt data: {corrupt_data}') + logging.info(f'nan data: {nan_data}') + +if __name__ == "__main__": + main() From 9d858ebff380b3ead65f014517ae85f06b6e2eac Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 19 Sep 2025 17:25:39 +0200 Subject: [PATCH 156/189] Enhance ERA5_COSMO dataset class with static channels option and Box-Cox transformation support for total precipitation; update histogram and exceedance scripts to remove baseline scaling factor. --- src/hirad/datasets/era5_cosmo.py | 87 +++++++++++++++++-- .../diurnal_cycle_precip_mean_wet-hour.py | 2 +- src/hirad/eval/diurnal_cycle_precip_p99.py | 4 +- src/hirad/eval/hist.py | 6 +- src/hirad/eval/map_precip_stats.py | 4 +- src/hirad/eval/probability_of_exceedance.py | 6 +- src/hirad/eval/snapshots.py | 8 +- 7 files changed, 98 insertions(+), 19 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 76d1982..6957877 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -15,9 +15,35 @@ def __init__(self, dataset_path: str, input_channel_names: List[str] = [], outpu self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') self._info_path = os.path.join(dataset_path, 'info') + self._static_path = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/static'# os.path.join(dataset_path, 'static') + self._zarr_path = os.path.join(dataset_path, 'dataset.zarr') # load file list (each file is one date-time state) - self._file_list = os.listdir(self._cosmo_path) + self._file_list = sorted(os.listdir(self._cosmo_path)) + + # open zarr store + # self._zarr_store = zarr.open(self._zarr_path, mode='r') + # self.era5 = self._zarr_store['era5'] + # self.cosmo = self._zarr_store['cosmo'] + + # Load static info and channel names + if static_channel_names: + with open(os.path.join(self._static_path, 'cosmo-static.yaml'), 'r') as file: + self._static_info = yaml.safe_load(file) + self._static_indeces = [self._static_info['select'].index(name) for name in static_channel_names] + self._static_channels = [ChannelMetadata(name) if len(name.split('_'))==1 + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._static_info['select'] if name in static_channel_names] + static_data = torch.load(os.path.join(self._static_path,'cosmo-static'), weights_only=False)[self._static_indeces] + orig_shape = self.image_shape() + self.static_data = np.flip(static_data \ + .squeeze() \ + .reshape(-1,*orig_shape), + 1) + self.static_mean = self.static_data.mean(axis=(1,2)) + self.static_std = self.static_data.std(axis=(1,2)) + else: + self.static_data = None # Load cosmo info and channel names with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file: @@ -52,7 +78,37 @@ def __init__(self, dataset_path: str, input_channel_names: List[str] = [], outpu era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) self.input_mean = era_stats['mean'][self._era_indeces] self.input_std = era_stats['stdev'][self._era_indeces] + if self.static_data is not None: + self.input_mean = np.concatenate((self.input_mean, self.static_mean), axis=0) + self.input_std = np.concatenate((self.input_std, self.static_std), axis=0) + + # FEATURE: load the mean and std values for transformed channels and update the normalization statistics + self.input_transforms = {} + self.input_inverse_transforms = {} + self.output_transforms = {} + self.output_inverse_transforms = {} + for transform_descriptor in transform_channels: + channel, transformation = transform_descriptor.split('-') + input_channel_idx = self._era_info['select'].index(channel) if channel in self._era_info['select'] else None + output_channel_idx = self._cosmo_info['select'].index(channel) if channel in self._cosmo_info['select'] else None + if transformation.startswith('box_cox'): + lmbda_str = transformation.split('_')[-1] + lmbda = float(transformation.split('_')[-1])/(10**(len(lmbda_str)-1)) + print(f"Applying Box-Cox transformation with lambda={lmbda} to channel {channel} (input idx: {input_channel_idx}, output idx: {output_channel_idx})") + if input_channel_idx is not None: + self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda) + self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda) + self.input_mean[input_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"era5-{transform_descriptor}-mean"), weights_only=False) + self.input_std[input_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"era5-{transform_descriptor}-std"), weights_only=False) + if output_channel_idx is not None: + self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda) + self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda) + self.output_mean[output_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"cosmo-{transform_descriptor}-mean"), weights_only=False) + self.output_std[output_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"cosmo-{transform_descriptor}-std"), weights_only=False) + else: + raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.") + def __getitem__(self, idx): """Get cosmo and era5 interpolated to cosmo grid""" # get data point @@ -66,6 +122,7 @@ def __getitem__(self, idx): .squeeze() \ .reshape(-1,*orig_shape), 1) + era5_data = np.concatenate((era5_data, self.static_data), axis=0) if self.static_data is not None else era5_data era5_data = self.normalize_input(era5_data) cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)[self._cosmo_indeces] @@ -96,7 +153,7 @@ def latitude(self) -> np.ndarray: def input_channels(self) -> List[ChannelMetadata]: """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" - return self._era_channels + return self._era_channels + self._static_channels if self.static_data is not None else self._era_channels def output_channels(self) -> List[ChannelMetadata]: @@ -118,23 +175,43 @@ def image_shape(self) -> Tuple[int, int]: def normalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from physical units to normalized data.""" + for channel_idx, transform in self.input_transforms.items(): + x[channel_idx,::] = transform(x[channel_idx,::]) return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \ / self.input_std.reshape((self.input_std.shape[0],1,1)) def denormalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from normalized data to physical units.""" - return x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + x = x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + self.input_mean.reshape((self.input_mean.shape[0],1,1)) + for channel_idx, inverse_transform in self.input_inverse_transforms.items(): + x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::]) + return x def normalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from physical units to normalized data.""" + for channel_idx, transform in self.output_transforms.items(): + x[channel_idx,::] = transform(x[channel_idx,::]) return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \ / self.output_std.reshape((self.output_std.shape[0],1,1)) def denormalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from normalized data to physical units.""" - return x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ - + self.output_mean.reshape((self.output_mean.shape[0],1,1)) \ No newline at end of file + x = x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ + + self.output_mean.reshape((self.output_mean.shape[0],1,1)) + for channel_idx, inverse_transform in self.output_inverse_transforms.items(): + x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::]) + return x + + def box_cox_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray: + """Apply Box-Cox transformation to the data.""" + channel_array = np.clip(channel_array, 0, None) + return (np.power(channel_array, lmbda) - 1) / lmbda + + def box_cox_inverse_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray: + """Apply inverse Box-Cox transformation to the data.""" + channel_array = np.clip(channel_array, -1/lmbda, None) + return np.power((lmbda * channel_array) + 1, 1 / lmbda) \ No newline at end of file diff --git a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py index 988024a..03e7d70 100644 --- a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py +++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py @@ -72,7 +72,7 @@ def main(cfg: DictConfig): for idx, ts in enumerate(times, 1): dt = datetimes[idx-1] target = torch.load(out_root/ts/f"{ts}-target", weights_only=False)[tp_out] * CONV_FACTOR - baseline = torch.load(out_root/ts/f"{ts}-baseline", weights_only=False)[tp_in] * CONV_FACTOR / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset + baseline = torch.load(out_root/ts/f"{ts}-baseline", weights_only=False)[tp_in] * CONV_FACTOR # / 6. # 6 because 1h -> accumulation period is 6h in hourly ERA5 dataset preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False)[:, tp_out, :, :] * CONV_FACTOR try: mean_pred = torch.load(out_root/ts/f"{ts}-regression-prediction", weights_only=False)[tp_out] * CONV_FACTOR diff --git a/src/hirad/eval/diurnal_cycle_precip_p99.py b/src/hirad/eval/diurnal_cycle_precip_p99.py index 316a387..9c54eb7 100644 --- a/src/hirad/eval/diurnal_cycle_precip_p99.py +++ b/src/hirad/eval/diurnal_cycle_precip_p99.py @@ -103,8 +103,8 @@ def main(cfg: DictConfig): hourly_p99 = da.groupby('time.hour').quantile(0.99, dim='time') # Apply scaling factor for baseline - if mode == 'baseline': - hourly_p99 = hourly_p99 / 6.0 + # if mode == 'baseline': + # hourly_p99 = hourly_p99 / 6.0 pct99_mean[mode] = hourly_p99.mean(dim=['lat', 'lon']) diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py index 183dee8..78d7395 100644 --- a/src/hirad/eval/hist.py +++ b/src/hirad/eval/hist.py @@ -131,7 +131,7 @@ def main(cfg: DictConfig): land_mask = load_land_sea_mask() # Define histogram bins - bins = np.logspace(-1, 1, 50) # Log-spaced bins for precipitation + bins = np.logspace(-1, 3.3, 200) # Log-spaced bins for precipitation # Storage for histogram data and land values hist_data = {} @@ -153,8 +153,8 @@ def main(cfg: DictConfig): data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target', 'regression-prediction'] else tp_in] * CONV_FACTOR_HOURLY * land_mask # Apply scaling factor for baseline - if mode == 'baseline': - data = data / 6.0 + # if mode == 'baseline': + # data = data / 6.0 land_values = data.values[~np.isnan(data.values)] all_values.extend(land_values) diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index 47e4125..d17b8f9 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -152,8 +152,8 @@ def main(cfg: DictConfig): dims=['time', 'lat', 'lon'], coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} ) - if mode == 'baseline': - mode_data = mode_data / 6.0 + # if mode == 'baseline': + # mode_data = mode_data / 6.0 # Compute and plot all statistics for this mode for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for {mode}...") diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py index e301016..f56b551 100644 --- a/src/hirad/eval/probability_of_exceedance.py +++ b/src/hirad/eval/probability_of_exceedance.py @@ -128,7 +128,7 @@ def main(cfg: DictConfig): land_mask = load_land_sea_mask() # Define thresholds for exceedance calculation - thresholds = np.logspace(-2, 2, 200) # From 0.01 to 100 mm/h + thresholds = np.logspace(-2, 2.1, 200) # From 0.01 to 100 mm/h # Storage for exceedance data and land values exceedance_data = {} @@ -148,8 +148,8 @@ def main(cfg: DictConfig): data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * CONV_FACTOR_HOURLY * land_mask # Apply scaling factor for baseline - if mode == 'baseline': - data = data / 6.0 + # if mode == 'baseline': + # data = data / 6.0 land_values = data.values[~np.isnan(data.values)] all_values.extend(land_values) diff --git a/src/hirad/eval/snapshots.py b/src/hirad/eval/snapshots.py index 08d6e6b..f6c7892 100644 --- a/src/hirad/eval/snapshots.py +++ b/src/hirad/eval/snapshots.py @@ -29,7 +29,7 @@ class ChannelMeta: vmin: float = None vmax: float = None extend: str = "both" - precip_kwargs: dict = field(default_factory=lambda: {"threshold": 0.1, "rfac": 100.0}) + precip_kwargs: dict = field(default_factory=lambda: {"threshold": 0.1, "rfac": 1000.0}) @classmethod def get(cls, ch_or_name: "ChannelMeta | str | None", *, vmin=None, vmax=None) -> "ChannelMeta": @@ -40,7 +40,7 @@ def get(cls, ch_or_name: "ChannelMeta | str | None", *, vmin=None, vmax=None) -> return base CHANNELS = { - "tp": ChannelMeta(name="tp", cmap=None, unit="mm/h", extend="max", precip_kwargs={"threshold": 0.1, "rfac": 100.0}), + "tp": ChannelMeta(name="tp", cmap=None, unit="mm/h", extend="max", precip_kwargs={"threshold": 0.1, "rfac": 1000.0}), "2t": ChannelMeta(name="2t", cmap="RdYlBu_r", me_cmap="RdBu", unit="K", err_vmin=-4.5, err_vmax=4.5), "10u": ChannelMeta(name="10u", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10), "10v": ChannelMeta(name="10v", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10), @@ -124,7 +124,9 @@ def main(cfg: DictConfig) -> None: logger = logging.getLogger("plot_maps") if cfg.generation.times_range: - times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") + else: + times = cfg.generation.times dataset_cfg = OmegaConf.to_container(cfg.dataset) has_lead_time = cfg.generation.get("has_lead_time", False) From caaa424e4685dcd49bf1cbdf171200fce0c4ba0f Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 19 Sep 2025 17:30:29 +0200 Subject: [PATCH 157/189] Add lead time label param in RegressionLoss for consistency and refactor noise parameter computation in ResidualLoss; enhance logging in training script for dataset details. --- src/hirad/losses/loss.py | 74 +++++++++++++++++++++++++++++-------- src/hirad/models/unet.py | 44 ++++++++++++++++++++-- src/hirad/training/train.py | 4 +- 3 files changed, 102 insertions(+), 20 deletions(-) diff --git a/src/hirad/losses/loss.py b/src/hirad/losses/loss.py index fb65960..2c2123b 100644 --- a/src/hirad/losses/loss.py +++ b/src/hirad/losses/loss.py @@ -380,6 +380,7 @@ def __call__( augment_pipe: Optional[ Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] ] = None, + lead_time_label: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Calculate and return the regression loss for @@ -439,7 +440,23 @@ def __call__( y_lr = y_tot[:, img_clean.shape[1] :, :, :] zero_input = torch.zeros_like(y, device=img_clean.device) - D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels) + + if lead_time_label is not None: + D_yn = net( + zero_input, + y_lr, + force_fp32=False, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + zero_input, + y_lr, + force_fp32=False, + augment_labels=augment_labels, + ) + loss = weight * ((D_yn - y) ** 2) return loss @@ -518,6 +535,32 @@ def __init__( self.hr_mean_conditioning = hr_mean_conditioning self.y_mean = None + def get_noise_params(self, y: torch.Tensor) -> torch.Tensor: + """ + Compute the noise parameters to apply denoising score matching. + + Parameters + ---------- + y : torch.Tensor + Latent state of shape :math:`(B, *)`. Only used to determine the shape of + the noise and create tensors on the same device. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + - Noise ``n`` of shape :math:`(B, *)` to be added to the latent state. + - Noise level ``sigma`` of shape :math:`(B, 1, 1, 1)`. + - Weight ``weight`` of shape :math:`(B, 1, 1, 1)` to multiply the loss. + """ + # Sample noise level + rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=y.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + # Loss weight + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + # Sample noise + n = torch.randn_like(y) * sigma + return n, sigma, weight + def __call__( self, net: torch.nn.Module, @@ -712,35 +755,34 @@ def __call__( y = y_patched y_lr = y_lr_patched - # Noise - rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - # Input + noise - latent = y + torch.randn_like(y) * sigma + # Add noise to the latent state + n, sigma, weight = self.get_noise_params(y) if lead_time_label is not None: D_yn = net( - latent, + y + n, y_lr, sigma, embedding_selector=None, - global_index=patching.global_index(batch_size, img_clean.device) - if patching is not None - else None, + global_index=( + patching.global_index(batch_size, img_clean.device) + if patching is not None + else None + ), lead_time_label=lead_time_label, augment_labels=augment_labels, ) else: D_yn = net( - latent, + y + n, y_lr, sigma, embedding_selector=None, - global_index=patching.global_index(batch_size, img_clean.device) - if patching is not None - else None, + global_index=( + patching.global_index(batch_size, img_clean.device) + if patching is not None + else None + ), augment_labels=augment_labels, ) loss = weight * ((D_yn - y) ** 2) diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index e0a447a..f1edc6b 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -69,9 +69,9 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up See Also -------- For information on model types and their usage: - :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models - :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings - :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + :class:`~models.song_unet.SongUNet`: Basic U-Net for diffusion models + :class:`~models.song_unet.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~models.song_unet.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings Please refer to the documentation of these classes for details on how to call and use these models directly. @@ -152,6 +152,44 @@ def __init__( **model_kwargs, ) + @property + def use_fp16(self): + """ + bool: Whether the model uses float16 precision. + + Returns + ------- + bool + True if the model is in float16 mode, False otherwise. + """ + return self._use_fp16 + + @use_fp16.setter + def use_fp16(self, value: bool): + """ + Set whether the model should use float16 precision. + + Parameters + ---------- + value : bool + If True, moves the model to torch.float16. If False, moves to torch.float32. + + Raises + ------ + ValueError + If `value` is not a boolean. + """ + # NOTE: allow 0/1 values for older checkpoints + if not (isinstance(value, bool) or value in [0, 1]): + raise ValueError( + f"`use_fp16` must be a boolean, but got {type(value).__name__}." + ) + self._use_fp16 = value + if value: + self.to(torch.float16) + else: + self.to(torch.float32) + def forward( self, x: torch.Tensor, diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 84941b4..b9092f2 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -151,7 +151,9 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) if cfg.model.hr_mean_conditioning: img_in_channels += img_out_channels - + logger0.info(f"Training on dataset with grid size {img_shape[0]}x{img_shape[1]}, {img_in_channels} input channels and {img_out_channels} output channels.") + logger0.info(f"Input channels: {dataset.input_channels()}") + logger0.info(f"Output channels: {dataset.output_channels()}") if cfg.model.name == "lt_aware_ce_regression": prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader From 7ecc3d98bb9ba2088ae01600294f07129a3e00fc Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 19 Sep 2025 17:33:35 +0200 Subject: [PATCH 158/189] fix bug in dataset initialization --- src/hirad/datasets/era5_cosmo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 6957877..a4fa82f 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -7,7 +7,7 @@ import torch.nn.functional as F class ERA5_COSMO(DownscalingDataset): - def __init__(self, dataset_path: str, input_channel_names: List[str] = [], output_channel_names: List[str] = []): + def __init__(self, dataset_path: str, input_channel_names: List[str] = [], output_channel_names: List[str] = [], static_channel_names: List[str] = [], transform_channels: List[str] = []): super().__init__() #TODO switch hanbdling paths to Path rather than pure strings From b9a843e6882cc95ac0ba0b026cbed32b84ad207a Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 24 Sep 2025 19:53:23 +0200 Subject: [PATCH 159/189] Add a less heavyweight Dockerfile that doesn't use nvidia stuff, for input data processing. --- ci/docker/Dockerfile.python | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 ci/docker/Dockerfile.python diff --git a/ci/docker/Dockerfile.python b/ci/docker/Dockerfile.python new file mode 100644 index 0000000..80ea1eb --- /dev/null +++ b/ci/docker/Dockerfile.python @@ -0,0 +1,29 @@ +# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup + +FROM ubuntu:22.04 as builder +#FROM nvcr.io/nvidia/pytorch:25.01-py3 + + +# setup +RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade \ + pip + #ninja + #wheel + #packaging + #setuptools + +# install the rest of dependencies +# TODO: Factor pydeps into a separate file(s) +# TODO: Add versions for things +RUN pip install \ + anemoi-datasets \ + cartopy \ + matplotlib \ + numpy \ + pandas \ + scipy \ + torch + + + From 3a1ae7c4d9090f503e6a711a537a4b226237c95f Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 24 Sep 2025 19:54:47 +0200 Subject: [PATCH 160/189] Modifications to input data tests to optionally load the torch files --- src/hirad/input_data/test_input_data.py | 30 ++++++++++++++----------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/hirad/input_data/test_input_data.py b/src/hirad/input_data/test_input_data.py index fa13491..d8d0f6e 100644 --- a/src/hirad/input_data/test_input_data.py +++ b/src/hirad/input_data/test_input_data.py @@ -45,13 +45,15 @@ def make_stats(filepath: str): def main(): root = logging.getLogger() root.setLevel(logging.INFO) - logging.info('starting main') input_directory = sys.argv[1] + + logging.info(f'checking input directory {input_directory}') missing_data = [] corrupt_data = [] nan_data = [] check_for_nans = False + check_for_corrupt = False @@ -70,19 +72,21 @@ def main(): while (expected_date < curr_date): missing_data.append(datetime.datetime.strftime(expected_date, '%Y%m%d-%H%M')) expected_date = expected_date + delta - try: - data = torch.load(os.path.join(input_directory, f), weights_only=False) - except: - logging.info(f'corrupt data: {curr_date}') - corrupt_data.append(curr_date) - if check_for_nans: - if count_nans(data): - logging.info(f'data nans: {curr_date}') - nan_data.append(curr_date) + if check_for_corrupt: + try: + data = torch.load(os.path.join(input_directory, f), weights_only=False) + except: + logging.info(f'corrupt data: {curr_date}') + corrupt_data.append(curr_date) + if check_for_nans or curr_date == start_date: + if count_nans(data): + logging.info(f'data nans: {curr_date}') + nan_data.append(curr_date) prev_date = curr_date - logging.info(f'missing data: {missing_data}') - logging.info(f'corrupt data: {corrupt_data}') - logging.info(f'nan data: {nan_data}') + logging.info(f'missing data size {len(missing_data)}: {missing_data}') + logging.info(f'corrupt data size {len(corrupt_data)}: {corrupt_data}') + if check_for_nans: + logging.info(f'nan data: {nan_data}') if __name__ == "__main__": main() From e77c61ae188e66ba057886758e07cefc82f65ebd Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 24 Sep 2025 19:55:09 +0200 Subject: [PATCH 161/189] Scripts used to reprocess data for various runs --- .../input_data/reprocess_change_tp_accum.py | 58 +++++++++++++++++++ src/hirad/input_data/reprocess_exclude_tp.py | 43 ++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 src/hirad/input_data/reprocess_change_tp_accum.py create mode 100644 src/hirad/input_data/reprocess_exclude_tp.py diff --git a/src/hirad/input_data/reprocess_change_tp_accum.py b/src/hirad/input_data/reprocess_change_tp_accum.py new file mode 100644 index 0000000..410c351 --- /dev/null +++ b/src/hirad/input_data/reprocess_change_tp_accum.py @@ -0,0 +1,58 @@ +import logging +import os +import sys + +import torch +import numpy as np + +# Reprocess ERA-interpolated data to exclude the tp variable. + +# 6H data is all channels, but with 6h accumulation +DATA_SOURCE_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/era-interpolated" +STATS_FILEPATH_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/info" +# 1h data is the updated +DATA_SOURCE_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp/" +STATS_FILEPATH_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/info" +OUTPUT_DIR = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/train/era-interpolated" +OUTPUT_STATS_FILEPATH = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/train/info/" +TP_INDEX_6H = 34 # in era-all.yaml +TP_INDEX_1H = 12 # in era.yaml + +def process(input_directory_6h: str, input_directory_1h: str, output_directory: str): + input_6h_filepath = os.path.join(input_directory_6h) + files = os.listdir(input_6h_filepath) + files.sort() + for f in range(len(files)): + if f % 100 == 0: + logging.info(f) + input_1h_file = os.path.join(input_directory_1h, files[f]) + input_6h_file = os.path.join(input_directory_6h, files[f]) + outfile = os.path.join(output_directory, files[f]) + in_data_6h = torch.load(input_6h_file, weights_only=False) + in_data_1h = torch.load(input_1h_file, weights_only=False) + in_data_6h[TP_INDEX_6H,:] = in_data_1h[TP_INDEX_1H,:] + torch.save(in_data_6h, outfile) + +def edit_info(info_6h_filepath: str, info_1h_filepath: str, output_filepath: str): + stats_6h = torch.load(os.path.join(info_6h_filepath, 'era-stats'), weights_only=False) + stats_1h = torch.load(os.path.join(info_1h_filepath, 'era-stats'), weights_only=False) + logging.info(f'6h stats: {stats_6h}') + logging.info(f'1h stats: {stats_1h}') + for k in stats_6h.keys(): + logging.info(k, stats_6h[k]) + tmp = stats_6h[k] + tmp[TP_INDEX_6H] = stats_1h[k][TP_INDEX_1H] + stats_6h[k] = tmp + logging.info(stats_6h) + torch.save(stats_6h, os.path.join(output_filepath, 'era-stats')) + +def main(): + logging.basicConfig( + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + process(DATA_SOURCE_6H, DATA_SOURCE_1H, OUTPUT_DIR) + edit_info(STATS_FILEPATH_6H, STATS_FILEPATH_1H, OUTPUT_STATS_FILEPATH) + +if __name__ == "__main__": + main() diff --git a/src/hirad/input_data/reprocess_exclude_tp.py b/src/hirad/input_data/reprocess_exclude_tp.py new file mode 100644 index 0000000..190a474 --- /dev/null +++ b/src/hirad/input_data/reprocess_exclude_tp.py @@ -0,0 +1,43 @@ +import logging +import os +import sys + +import torch +import numpy as np + +# Reprocess ERA-interpolated data to exclude the tp variable. + +TP_INDEX = 12 + +def process(input_directory: str, output_directory: str): + input_filepath = os.path.join(input_directory, 'era-interpolated') + files = os.listdir(input_filepath) + files.sort() + for f in range(len(files)): + if f % 100 == 0: + logging.info(f) + outfile = os.path.join(output_directory, 'era-interpolated', files[f]) + if (not os.path.exists(outfile)) or (os.path.getsize(outfile) < 26000000): + in_data = torch.load(os.path.join(input_filepath, files[f]), weights_only=False) + out_data = in_data[0:TP_INDEX,:] + torch.save(out_data, outfile) + +def edit_info(input_filepath: str, output_filepath: str): + stats = torch.load(os.path.join(input_filepath, '/info', 'era-stats'), weights_only=False) + logging.info(stats) + for k in stats.keys(): + logging.info(k, stats[k]) + stats[k] = stats[k][0:TP_INDEX] + logging.info(stats) + torch.save(stats, os.path.join(output_filepath, "/info", "era-stats")) + +def main(): + root = logging.getLogger() + root.setLevel(logging.INFO) + input_directory = sys.argv[1] + output_directory = sys.argv[2] + process(input_directory, output_directory) + #edit_info(input_directory, output_directory) + +if __name__ == "__main__": + main() From 88dcc011d29bf49a772e3f4669a4df4cb00af843 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 2 Oct 2025 12:31:09 +0200 Subject: [PATCH 162/189] Add reprocessing script --- src/hirad/input_data/reprocess_change_tp_accum.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/hirad/input_data/reprocess_change_tp_accum.py b/src/hirad/input_data/reprocess_change_tp_accum.py index 410c351..5642300 100644 --- a/src/hirad/input_data/reprocess_change_tp_accum.py +++ b/src/hirad/input_data/reprocess_change_tp_accum.py @@ -11,16 +11,16 @@ DATA_SOURCE_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/era-interpolated" STATS_FILEPATH_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/info" # 1h data is the updated -DATA_SOURCE_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp/" -STATS_FILEPATH_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/info" -OUTPUT_DIR = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/train/era-interpolated" -OUTPUT_STATS_FILEPATH = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/train/info/" +DATA_SOURCE_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/validation/era-interpolated-with-copernicus-tp/" +STATS_FILEPATH_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/validation/info" +OUTPUT_DIR = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/validation/era-interpolated" +OUTPUT_STATS_FILEPATH = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/validation/info/" TP_INDEX_6H = 34 # in era-all.yaml TP_INDEX_1H = 12 # in era.yaml def process(input_directory_6h: str, input_directory_1h: str, output_directory: str): - input_6h_filepath = os.path.join(input_directory_6h) - files = os.listdir(input_6h_filepath) + input_1h_filepath = os.path.join(input_directory_1h) + files = os.listdir(input_1h_filepath) files.sort() for f in range(len(files)): if f % 100 == 0: From 425050f0650522474ff867dde95b0c579b287256 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 2 Oct 2025 15:55:03 +0200 Subject: [PATCH 163/189] add functionality to have hour of the day and month of the year embedding as conditioning inputs --- src/hirad/datasets/era5_cosmo.py | 109 +++++++++++++++++++++++++---- src/hirad/utils/inference_utils.py | 9 ++- 2 files changed, 99 insertions(+), 19 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index a4fa82f..1db8716 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -5,18 +5,35 @@ from typing import List, Tuple import yaml import torch.nn.functional as F +import time +# import zarr + +from hirad.utils.console import PythonLogger + +logger = PythonLogger(__name__) + +DATASET_ORIG_PATH = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full' class ERA5_COSMO(DownscalingDataset): - def __init__(self, dataset_path: str, input_channel_names: List[str] = [], output_channel_names: List[str] = [], static_channel_names: List[str] = [], transform_channels: List[str] = []): + def __init__(self, + dataset_path: str, + input_channel_names: List[str] = [], + output_channel_names: List[str] = [], + static_channel_names: List[str] = [], + transform_channels: List[str] = [], + n_month_hour_channels: int = None, + ): super().__init__() #TODO switch hanbdling paths to Path rather than pure strings + self._n_month_hour_channels = n_month_hour_channels self._dataset_path = dataset_path self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') - self._info_path = os.path.join(dataset_path, 'info') - self._static_path = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/static'# os.path.join(dataset_path, 'static') - self._zarr_path = os.path.join(dataset_path, 'dataset.zarr') + self._info_path = os.path.join(DATASET_ORIG_PATH, 'info') + # self._static_path = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/static'# os.path.join(dataset_path, 'static') + self._static_path = os.path.join(DATASET_ORIG_PATH, 'static')# os.path.join(dataset_path, 'static') + # self._zarr_path = os.path.join(dataset_path, 'dataset.zarr') # load file list (each file is one date-time state) self._file_list = sorted(os.listdir(self._cosmo_path)) @@ -99,13 +116,13 @@ def __init__(self, dataset_path: str, input_channel_names: List[str] = [], outpu if input_channel_idx is not None: self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda) self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda) - self.input_mean[input_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"era5-{transform_descriptor}-mean"), weights_only=False) - self.input_std[input_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"era5-{transform_descriptor}-std"), weights_only=False) + self.input_mean[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-mean"), weights_only=False) + self.input_std[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-std"), weights_only=False) if output_channel_idx is not None: self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda) self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda) - self.output_mean[output_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"cosmo-{transform_descriptor}-mean"), weights_only=False) - self.output_std[output_channel_idx] = torch.load(os.path.join("/users/pstamenk/HiRAD-Gen",f"cosmo-{transform_descriptor}-std"), weights_only=False) + self.output_mean[output_channel_idx] = torch.load(os.path.join(self._info_path,f"cosmo-{transform_descriptor}-mean"), weights_only=False) + self.output_std[output_channel_idx] = torch.load(os.path.join(self._info_path,f"cosmo-{transform_descriptor}-std"), weights_only=False) else: raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.") @@ -117,7 +134,11 @@ def __getitem__(self, idx): # flip so that it starts in top-left corner (by default it is bottom left) # orig_shape = [350,542] #TODO currently padding to be divisible by 16 orig_shape = self.image_shape() - era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)[self._era_indeces] + try: + era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)[self._era_indeces] + except: + logger.error(f"Error loading file {os.path.join(self._era5_path,self._file_list[idx])}") + raise era5_data = np.flip(era5_data \ .squeeze() \ .reshape(-1,*orig_shape), @@ -125,15 +146,29 @@ def __getitem__(self, idx): era5_data = np.concatenate((era5_data, self.static_data), axis=0) if self.static_data is not None else era5_data era5_data = self.normalize_input(era5_data) - cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)[self._cosmo_indeces] + try: + cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)[self._cosmo_indeces] + except: + logger.error(f"Error loading file {os.path.join(self._cosmo_path,self._file_list[idx])}") + raise cosmo_data = np.flip(cosmo_data\ .squeeze() \ .reshape(-1,*orig_shape), 1) cosmo_data = self.normalize_output(cosmo_data) + if self._n_month_hour_channels is not None and self._n_month_hour_channels>0: + # extract month and hour from filename + filename = self._file_list[idx] + date_str, hour_str = filename.split('-') + month = int(date_str[4:6]) + hour = int(hour_str[0:2]) + + time_grid = self.make_time_grids(hour, month) + era5_data = np.concatenate((era5_data, time_grid), axis=0) + return torch.tensor(cosmo_data),\ - torch.tensor(era5_data), + torch.tensor(era5_data) def __len__(self): return len(self._file_list) @@ -153,8 +188,13 @@ def latitude(self) -> np.ndarray: def input_channels(self) -> List[ChannelMetadata]: """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" - return self._era_channels + self._static_channels if self.static_data is not None else self._era_channels - + channels = self._era_channels + self._static_channels if self.static_data is not None else self._era_channels + if self._n_month_hour_channels is not None and self._n_month_hour_channels>0: + for i in range(self._n_month_hour_channels): + channels.append(ChannelMetadata("hour-enc",f"{i}")) + for i in range(self._n_month_hour_channels): + channels.append(ChannelMetadata("month-enc",f"{i}")) + return channels def output_channels(self) -> List[ChannelMetadata]: """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" @@ -183,6 +223,8 @@ def normalize_input(self, x: np.ndarray) -> np.ndarray: def denormalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from normalized data to physical units.""" + if self._n_month_hour_channels is not None and self._n_month_hour_channels>0: + x = x[:,:-2*self._n_month_hour_channels,:,:] x = x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + self.input_mean.reshape((self.input_mean.shape[0],1,1)) for channel_idx, inverse_transform in self.input_inverse_transforms.items(): @@ -214,4 +256,43 @@ def box_cox_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarr def box_cox_inverse_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray: """Apply inverse Box-Cox transformation to the data.""" channel_array = np.clip(channel_array, -1/lmbda, None) - return np.power((lmbda * channel_array) + 1, 1 / lmbda) \ No newline at end of file + return np.power((lmbda * channel_array) + 1, 1 / lmbda) + + def make_time_grids(self, hour, month): + """ + Create multi-frequency cyclic sin/cos feature grids for hour and month. + + Parameters + ---------- + hour : int + Hour of day, 0-23 + month : int + Month of year, 1-12 + + Returns + ------- + grid : np.ndarray, shape (C, H, W) + Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k frequency] + """ + H, W = self.image_shape() + hour_freqs = np.arange(1, self._n_month_hour_channels//2 + 1) + month_freqs = np.arange(1, self._n_month_hour_channels//2 + 1) + + channels = [] + + # --- hour encodings --- + for k in hour_freqs: + angle = 2 * np.pi * k * (hour % 24) / 24.0 + channels.append(np.sin(angle)) + channels.append(np.cos(angle)) + + # --- month encodings --- + for k in month_freqs: + angle = 2 * np.pi * k * ((month - 1) % 12) / 12.0 + channels.append(np.sin(angle)) + channels.append(np.cos(angle)) + + channels = np.array(channels, dtype=np.float32) + grid = np.tile(channels[:, None, None], (1, H, W)) # (C, H, W) + + return grid \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index a792c70..0cdfd98 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -208,13 +208,12 @@ def diffusion_step( def save_results_as_torch(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): - os.makedirs(output_path, exist_ok=True) - target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) - prediction_ensemble = np.flip(dataset.denormalize_output(image_pred.squeeze()),-2) - baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1) + target = np.flip(dataset.denormalize_output(image_hr)[0,::].squeeze(),1) + prediction_ensemble = np.flip(dataset.denormalize_output(image_pred).squeeze(),-2) + baseline = np.flip(dataset.denormalize_input(image_lr)[0,::].squeeze(),1) if mean_pred is not None: - mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) + mean_pred = np.flip(dataset.denormalize_output(mean_pred)[0,::].squeeze(),1) torch.save(mean_pred, os.path.join(output_path, f'{time_step}-regression-prediction')) torch.save(target, os.path.join(output_path, f'{time_step}-target')) torch.save(prediction_ensemble, os.path.join(output_path, f'{time_step}-predictions')) From 6a09fb63d539539d873a926f40ec771d4003454d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 2 Oct 2025 16:58:38 +0200 Subject: [PATCH 164/189] update config to latest features --- src/hirad/conf/dataset/era_cosmo.yaml | 9 ++++++--- src/hirad/conf/dataset/era_cosmo_inference.yaml | 8 ++++++-- src/hirad/conf/model/era_cosmo_diffusion.yaml | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index 37b3df2..fc81714 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,6 +1,9 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/train +dataset_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/train/ # dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-all-channels -validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +validation_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/validation/ input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp] -output_channel_names: [2t, 10u, 10v, tp] \ No newline at end of file +output_channel_names: [2t, 10u, 10v, tp] +static_channel_names: ['hsurf'] +transform_channels: ['tp-box_cox_025'] +n_month_hour_channels: 4 \ No newline at end of file diff --git a/src/hirad/conf/dataset/era_cosmo_inference.yaml b/src/hirad/conf/dataset/era_cosmo_inference.yaml index bca045d..71b0db1 100644 --- a/src/hirad/conf/dataset/era_cosmo_inference.yaml +++ b/src/hirad/conf/dataset/era_cosmo_inference.yaml @@ -1,4 +1,8 @@ type: era5_cosmo -dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +# dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/era5-cosmo-1h-linear-interpolation/validation +dataset_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/validation input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp] -output_channel_names: [2t, 10u, 10v, tp] \ No newline at end of file +output_channel_names: [2t, 10u, 10v, tp] +static_channel_names: ['hsurf'] +transform_channels: ['tp-box_cox_025'] +n_month_hour_channels: 4 \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml index 441239e..7a060e8 100644 --- a/src/hirad/conf/model/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -10,6 +10,6 @@ model_args: # Controls how positional information is encoded. N_grid_channels: 4 # Number of channels for positional grid embeddings - embedding_type: "zero" + embedding_type: "positional" # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, # 'zero' for none \ No newline at end of file From ba08752638f2a7252febb7806985a91ccd0378e0 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 8 Oct 2025 10:08:32 +0200 Subject: [PATCH 165/189] remove scoringrules packeage, not installed --- src/hirad/eval/metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index 8c8afde..af9d5a9 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -3,7 +3,6 @@ import numpy as np import torch import xskillscore -import scoringrules as sr from scipy.signal import periodogram import xskillscore From febbfe8edbe73e19b7762ba5741992542fe70f89 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Wed, 8 Oct 2025 10:10:48 +0200 Subject: [PATCH 166/189] put snapshot maps into own sbatch script --- src/hirad/eval_precip.sh | 1 - src/hirad/snapshots.sh | 50 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 src/hirad/snapshots.sh diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index db9b190..7574b07 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -57,5 +57,4 @@ srun --environment=./ci/edf/modulus_env.toml bash -c " # Maps python src/hirad/eval/map_precip_stats.py --config-name=generate_era_cosmo.yaml - # python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml " \ No newline at end of file diff --git a/src/hirad/snapshots.sh b/src/hirad/snapshots.sh new file mode 100644 index 0000000..16a12de --- /dev/null +++ b/src/hirad/snapshots.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +#SBATCH --job-name="eval_precip" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=05:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_precip.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 +# echo "Physical cores: $PHYSICAL_CORES" +# echo "Local processes: $LOCAL_PROCS" +# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + + python src/hirad/eval/snapshots.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file From 778f52b114428856bb3967412b9d6017e6361d25 Mon Sep 17 00:00:00 2001 From: David Leutwyler <14977216+leuty@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:25:04 +0200 Subject: [PATCH 167/189] Eval More 10m winds (#23) --- src/hirad/eval/map_precip_stats.py | 2 +- src/hirad/eval/map_wind_stats.py | 405 ++++++++++++++++++ .../eval/probability_of_exceedance_wind.py | 344 +++++++++++++++ src/hirad/eval_precip.sh | 1 - src/hirad/eval_wind.sh | 53 +++ 5 files changed, 803 insertions(+), 2 deletions(-) create mode 100644 src/hirad/eval/map_wind_stats.py create mode 100644 src/hirad/eval/probability_of_exceedance_wind.py create mode 100644 src/hirad/eval_wind.sh diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py index d17b8f9..257cfcd 100644 --- a/src/hirad/eval/map_precip_stats.py +++ b/src/hirad/eval/map_precip_stats.py @@ -61,7 +61,7 @@ def plot_stat_map(data, filename, stat_config, label): plot_map( data, filename, title=f'{label}: {stat_config["title_stat"]} (%)', - label='Wet-Hour Frequency [%]', vmin=0, vmax=10, cmap='PuBu', extend='max' + label='Wet-Hour Frequency [%]', vmin=0, vmax=30, cmap='PuBu', extend='max' ) elif stat_config['type'] == 'cdd': plot_map( diff --git a/src/hirad/eval/map_wind_stats.py b/src/hirad/eval/map_wind_stats.py new file mode 100644 index 0000000..ee6855a --- /dev/null +++ b/src/hirad/eval/map_wind_stats.py @@ -0,0 +1,405 @@ +import logging +from datetime import datetime +from pathlib import Path + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import plot_map, get_channel_indices, LOG_INTERVAL + + +def compute_wind_speed(u, v): + """Compute wind speed from U and V components.""" + return np.hypot(u, v) + + +def compute_wind_direction(u, v): + """Compute wind direction in degrees (meteorological convention: direction FROM which wind blows). + + Returns angle in degrees: 0° = North, 90° = East, 180° = South, 270° = West + """ + # atan2(u, v) gives mathematical angle, convert to meteorological + direction = np.rad2deg(np.arctan2(u, v)) + 180 + direction = np.mod(direction, 360) + return direction + + +def circular_mean_direction(directions, weights=None): + """Calculate circular mean of wind directions in degrees. + + Args: + directions: Array of directions in degrees + weights: Optional weights (e.g., wind speeds) + """ + # Convert to radians + rad = np.deg2rad(directions) + + # Calculate weighted mean of sin and cos components + if weights is not None: + sin_mean = np.average(np.sin(rad), weights=weights, axis=0) + cos_mean = np.average(np.cos(rad), weights=weights, axis=0) + else: + sin_mean = np.mean(np.sin(rad), axis=0) + cos_mean = np.mean(np.cos(rad), axis=0) + + # Calculate mean direction + mean_dir = np.arctan2(sin_mean, cos_mean) + mean_dir_deg = np.rad2deg(mean_dir) + mean_dir_deg = np.mod(mean_dir_deg, 360) + + return mean_dir_deg + + +def circular_std(directions): + """Calculate circular standard deviation of wind directions in degrees. + + Returns values from 0 (perfect alignment) to ~81.03 degrees (uniform distribution) + """ + rad = np.deg2rad(directions) + + # Calculate mean resultant length + sin_mean = np.mean(np.sin(rad), axis=0) + cos_mean = np.mean(np.cos(rad), axis=0) + R = np.hypot(sin_mean, cos_mean) + + # Circular standard deviation + # Handle R=0 case to avoid log(0) + R = np.clip(R, 1e-10, 1.0) + circ_std = np.rad2deg(np.sqrt(-2 * np.log(R))) + + return circ_std + + +def apply_wind_statistic(u_data, v_data, stat_type, stat_param=None): + """Apply a wind statistic to U and V components along the time dimension. + + Args: + u_data: xarray DataArray of U wind component + v_data: xarray DataArray of V wind component + stat_type: Type of statistic to compute + stat_param: Optional parameter for the statistic (e.g., quantile value) + + Returns: + Result as numpy array + """ + # Compute wind speed + speed = compute_wind_speed(u_data.values, v_data.values) + + if stat_type == 'mean_speed': + return np.mean(speed, axis=0) + + if stat_type == 'quantile_speed': + return np.quantile(speed, stat_param, axis=0) + + if stat_type == 'max_speed': + return np.max(speed, axis=0) + + if stat_type == 'wind_power': + # Wind power density is proportional to cube of wind speed + return np.mean(speed**3, axis=0) + + if stat_type == 'calm_freq': + # Frequency of calm conditions (< 2 m/s, Beaufort 0-1) + calm_threshold = 2.0 + return np.mean(speed < calm_threshold, axis=0) * 100 + + if stat_type == 'light_breeze_freq': + # Frequency of light breeze (> 1.6 m/s, Beaufort 2+) + light_breeze_threshold = 1.6 + return np.mean(speed > light_breeze_threshold, axis=0) * 100 + + if stat_type == 'moderate_breeze_freq': + # Frequency of moderate breeze (> 5.5 m/s, Beaufort 4+) + moderate_breeze_threshold = 5.5 + return np.mean(speed > moderate_breeze_threshold, axis=0) * 100 + + if stat_type == 'strong_breeze_freq': + # Frequency of strong breeze (> 10.8 m/s, Beaufort 6+) + strong_breeze_threshold = 10.8 + return np.mean(speed > strong_breeze_threshold, axis=0) * 100 + + if stat_type == 'gale_freq': + # Frequency of fresh gale (> 17.2 m/s, Beaufort 8+) + gale_threshold = 17.2 + return np.mean(speed > gale_threshold, axis=0) * 100 + + if stat_type == 'prevailing_direction': + # Compute wind directions + direction = compute_wind_direction(u_data.values, v_data.values) + # Weight by wind speed for more meaningful prevailing direction + return circular_mean_direction(direction, weights=speed) + + if stat_type == 'direction_variability': + # Circular standard deviation of wind direction + direction = compute_wind_direction(u_data.values, v_data.values) + return circular_std(direction) + + if stat_type == 'mean_u': + return np.mean(u_data.values, axis=0) + + if stat_type == 'mean_v': + return np.mean(v_data.values, axis=0) + + raise ValueError(f"Unsupported wind statistic type: {stat_type}") + + +def plot_wind_stat_map(data, filename, stat_config, label): + """Plot a single wind statistic map with appropriate styling.""" + + if stat_config['type'] == 'mean_speed': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Speed [m/s]', vmin=0, vmax=10, cmap='inferno', extend='max' + ) + elif stat_config['type'] in ['quantile_speed', 'max_speed']: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Speed [m/s]', vmin=0, vmax=30, cmap='inferno', extend='max' + ) + elif stat_config['type'] == 'wind_power': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Power Density [m³/s³]', vmin=0, vmax=1000, cmap='plasma', extend='max' + ) + elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', 'strong_breeze_freq', 'gale_freq']: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Frequency [%]', vmin=0, vmax=80, cmap='GnBu', extend='max' + ) + elif stat_config['type'] == 'prevailing_direction': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Direction [degrees from N]', vmin=0, vmax=360, cmap='twilight', extend='neither' + ) + elif stat_config['type'] == 'direction_variability': + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Circular Std Dev [degrees]', vmin=20, vmax=140, cmap='viridis', extend='max' + ) + elif stat_config['type'] in ['mean_u', 'mean_v']: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Wind Component [m/s]', vmin=-10, vmax=10, cmap='RdBu_r', extend='both' + ) + else: + plot_map( + data, filename, + title=f'{label}: {stat_config["title_stat"]}', + label='Value', vmin=None, vmax=None, cmap='viridis', extend='neither' + ) + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup and config + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting wind statistics generation") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Processing {len(times)} timesteps") + + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + out_root = Path(cfg.generation.io.output_path or './outputs') + indices = get_channel_indices(dataset) + + # Get U and V wind component indices + u10_out = indices['output'].get('10u') + v10_out = indices['output'].get('10v') + u10_in = indices['input'].get('10u', u10_out) + v10_in = indices['input'].get('10v', v10_out) + + + # Wind statistic configurations + WIND_STATISTICS_CONFIG = { + 'mean_speed': { + 'type': 'mean_speed', + 'title': 'Mean Wind Speed' + }, + 'p90_speed': { + 'type': 'quantile_speed', + 'param': 0.90, + 'title': '90th Percentile Wind Speed' + }, + 'max_speed': { + 'type': 'max_speed', + 'title': 'Maximum Wind Speed' + }, + 'wind_power': { + 'type': 'wind_power', + 'title': 'Mean Wind Power Density' + }, + 'calm_freq': { + 'type': 'calm_freq', + 'title': 'Calm Frequency (<2 m/s, Beaufort 0-1)' + }, + 'light_breeze_freq': { + 'type': 'light_breeze_freq', + 'title': 'Light Breeze Frequency (>1.6 m/s, Beaufort 2+)' + }, + 'moderate_breeze_freq': { + 'type': 'moderate_breeze_freq', + 'title': 'Moderate Breeze Frequency (>5.5 m/s, Beaufort 4+)' + }, + 'strong_breeze_freq': { + 'type': 'strong_breeze_freq', + 'title': 'Strong Breeze Frequency (>10.8 m/s, Beaufort 6+)' + }, + 'gale_freq': { + 'type': 'gale_freq', + 'title': 'Gale Frequency (>17.2 m/s, Beaufort 8+)' + }, + 'prevailing_dir': { + 'type': 'prevailing_direction', + 'title': 'Prevailing Wind Direction' + }, + 'dir_variability': { + 'type': 'direction_variability', + 'title': 'Wind Direction Variability' + }, + 'mean_u': { + 'type': 'mean_u', + 'title': 'Mean U-Component' + }, + 'mean_v': { + 'type': 'mean_v', + 'title': 'Mean V-Component' + } + } + + stat_configs = [ + { + 'stat_name': name, + 'title_stat': config['title'], + 'param': config.get('param'), + **config + } + for name, config in WIND_STATISTICS_CONFIG.items() + ] + + # Target and baseline modes + basic_modes = { + 'target': ((u10_out, v10_out), 'COSMO-2 Analysis'), + 'baseline': ((u10_in, v10_in), 'ERA5'), + 'regression-prediction': ((u10_out, v10_out), 'Regression Prediction') + } + logger.info(f"Generating {len(stat_configs)} wind statistics for {len(basic_modes)} basic modes + predictions") + + for mode, (wind_channels, label) in basic_modes.items(): + logger.info(f"Processing mode: {mode}") + u_channel, v_channel = wind_channels + + # Load all timesteps for this mode + u_data_list = [] + v_data_list = [] + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) + u_data_list.append(data[u_channel]) + v_data_list.append(data[v_channel]) + except Exception as e: + logger.warning(f"{mode} not available, skipping: {e}") + continue + + # Create xarray DataArrays + u_mode_data = xr.DataArray( + np.stack(u_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + v_mode_data = xr.DataArray( + np.stack(v_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Compute and plot all statistics for this mode + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for {mode}...") + result = apply_wind_statistic( + u_mode_data, v_mode_data, + stat_config['type'], stat_config['param'] + ) + + map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_wind_stat_map( + result, + str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), + stat_config, + label + ) + + # Predictions mode: process each member separately to save memory + logger.info("Processing predictions mode...") + try: + data = torch.load(out_root/times[0]/f"{times[0]}-predictions", weights_only=False) + n_members = data.shape[0] + logger.info(f"Found {n_members} ensemble members") + + for member_idx in range(n_members): + logger.info(f"Processing prediction member {member_idx+1}/{n_members}") + + # Load all timesteps for this member + u_data_list = [] + v_data_list = [] + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading prediction member {member_idx} timestep {i+1}/{len(times)}: {ts}") + pred_data = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) + u_data_list.append(pred_data[member_idx, u10_out]) + v_data_list.append(pred_data[member_idx, v10_out]) + + u_member_data = xr.DataArray( + np.stack(u_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + v_member_data = xr.DataArray( + np.stack(v_data_list, axis=0), + dims=['time', 'lat', 'lon'], + coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} + ) + + # Compute and plot all statistics for this member + for stat_config in stat_configs: + logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") + member_result = apply_wind_statistic( + u_member_data, v_member_data, + stat_config['type'], stat_config['param'] + ) + + # Create map + map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + member_label = f'CorrDiff Member {member_idx+1}' + plot_wind_stat_map(member_result, member_filename, stat_config, member_label) + + except Exception as e: + logger.warning(f"Predictions not available, skipping: {e}") + + logger.info("All wind statistics maps generated successfully") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval/probability_of_exceedance_wind.py b/src/hirad/eval/probability_of_exceedance_wind.py new file mode 100644 index 0000000..6fe3ff7 --- /dev/null +++ b/src/hirad/eval/probability_of_exceedance_wind.py @@ -0,0 +1,344 @@ +"""Probability of exceedance for wind speed and components.""" +import logging +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +import xarray as xr + +from hirad.datasets import get_dataset_and_sampler_inference +from hirad.distributed import DistributedManager +from hirad.utils.function_utils import get_time_from_range +from hirad.eval.plotting import get_channel_indices, LOG_INTERVAL + + +def compute_wind_speed(u, v): + """Compute wind speed from U and V components.""" + return np.hypot(u, v) + + +def compute_exceedance_probs(values, thresholds, use_abs=False): + """Compute exceedance probabilities.""" + if use_abs: + return np.array([np.mean(np.abs(values) > t) for t in thresholds]) + else: + return np.array([np.mean(values > t) for t in thresholds]) + + +def update_exceedance_counts(counts, total, values, thresholds, use_abs=False): + """Update exceedance counts incrementally.""" + data = np.abs(values) if use_abs else values + for i, threshold in enumerate(thresholds): + counts[i] += np.sum(data > threshold) + total += len(values) + return counts, total + + +def compute_percentiles(values, percentile_dict, use_abs=False): + """Compute percentiles.""" + data = np.abs(values) if use_abs else values + data_array = xr.DataArray(data) + return {key: data_array.quantile(p).item() for key, p in percentile_dict.items()} + + +def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None): + """Save probability of exceedance plot.""" + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + plt.figure(figsize=(10, 6)) + + # Plot exceedance curves + for (key, exceedance_data), label, color in zip(exceedance_data_dict.items(), labels, colors): + if isinstance(exceedance_data, tuple): # Handle ensemble data + # Plot individual members with transparency + for i, member_exceedance in enumerate(exceedance_data): + alpha = 0.5 if i > 0 else 0.7 + label_member = label if i == 0 else None + plt.plot(thresholds, member_exceedance, alpha=alpha, color=color, + label=label_member, linewidth=1) + else: + # Plot single dataset + plt.plot(thresholds, exceedance_data, alpha=0.7, color=color, + label=label, linewidth=2) + + plt.xscale('log') + plt.xlim(thresholds[1], thresholds[-1]) + plt.yscale('log') + plt.xlabel(ylabel) + plt.ylabel('Probability of Exceedance') + plt.ylim(1e-8, 1) + plt.title(title) + plt.grid(True, alpha=0.3) + + # Add percentile lines if provided + if percentiles_data: + # Calculate y-range for percentile lines (lowest 10% of log scale) + y_bottom, y_top = plt.ylim() + log_bottom, log_top = np.log10(y_bottom), np.log10(y_top) + vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom)) + vline_ymin = y_bottom + + # Define line styles for percentiles + percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'} + percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'} + colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'} + legend_added = set() + + # Plot all percentile lines + for dataset_name, data in percentiles_data.items(): + color = colors_perc[dataset_name] + + if dataset_name in ['target', 'baseline', 'regression-prediction']: + # Single dataset + for percentile, value in data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.8) # No label here + else: + # Ensemble members + for member_data in data.values(): + for percentile, value in member_data.items(): + linestyle = percentile_styles[percentile] + legend_added.add(percentile) # Track percentiles for black legend entries + + plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax, + linestyles=linestyle, alpha=0.6) # No label here + + # Add black legend entries for percentiles (override the colored ones) + for percentile in [99, 99.9, 99.99]: + if percentile in legend_added: + plt.plot([], [], color='black', linestyle=percentile_styles[percentile], + label=percentile_labels[percentile]) + + plt.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + + +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") +def main(cfg: DictConfig): + # Setup logging + DistributedManager.initialize() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Starting computation for probability of exceedance for wind speed") + times = get_time_from_range(cfg.generation.times_range, "%Y%m%d-%H%M") + logger.info(f"Loaded {len(times)} timesteps to process") + + # Initialize dataset + ds_cfg = OmegaConf.to_container(cfg.dataset) + dataset, _ = get_dataset_and_sampler_inference( + ds_cfg, times, cfg.generation.get('has_lead_time', False) + ) + logger.info("Dataset and sampler initialized") + + # Output root + out_root = Path(cfg.generation.io.output_path or './outputs') + + # Find channel indices for wind components + indices = get_channel_indices(dataset) + u10_out = indices['output'].get('10u') + v10_out = indices['output'].get('10v') + u10_in = indices['input'].get('10u', u10_out) + v10_in = indices['input'].get('10v', v10_out) + + if u10_out is None or v10_out is None: + logger.error("Wind components (10u, 10v) not found in dataset!") + return + + logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, input: 10u={u10_in}, 10v={v10_in}") + + # Define thresholds for exceedance calculation (same for all variables) + thresholds = np.logspace(-1, 1.5, 200) # From 0.1 to ~31.6 m/s + n_thresholds = len(thresholds) + + # Storage for exceedance counts (incremental computation) + exceedance_counts = { + 'speed': {}, 'u': {}, 'v': {} + } + totals = {'speed': {}, 'u': {}, 'v': {}} + + # Storage for percentile computation (collect samples) + percentile_samples = {'speed': {}, 'u': {}, 'v': {}} + + # -- Process target and baseline -- + for mode in ['target', 'baseline', 'regression-prediction']: + logger.info(f"Processing mode: {mode}") + + # Initialize counts + for var in ['speed', 'u', 'v']: + exceedance_counts[var][mode] = np.zeros(n_thresholds, dtype=np.int64) + totals[var][mode] = 0 + percentile_samples[var][mode] = [] + + try: + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) + + # Extract wind components + if mode in ['target', 'regression-prediction']: + u = data[u10_out] + v = data[v10_out] + else: # baseline + u = data[u10_in] + v = data[v10_in] + + wind_speed = compute_wind_speed(u, v) + + # Get valid values + valid_mask = ~np.isnan(wind_speed) + speed_vals = wind_speed[valid_mask].flatten() + u_vals = u[valid_mask].flatten() + v_vals = v[valid_mask].flatten() + + # Update exceedance counts incrementally + exceedance_counts['speed'][mode], totals['speed'][mode] = update_exceedance_counts( + exceedance_counts['speed'][mode], totals['speed'][mode], speed_vals, thresholds, use_abs=False + ) + exceedance_counts['u'][mode], totals['u'][mode] = update_exceedance_counts( + exceedance_counts['u'][mode], totals['u'][mode], u_vals, thresholds, use_abs=True + ) + exceedance_counts['v'][mode], totals['v'][mode] = update_exceedance_counts( + exceedance_counts['v'][mode], totals['v'][mode], v_vals, thresholds, use_abs=True + ) + + # Collect samples for percentiles (subsample to save memory) + sample_rate = max(1, len(speed_vals) // 10000) # Keep ~10k samples per timestep + percentile_samples['speed'][mode].extend(speed_vals[::sample_rate]) + percentile_samples['u'][mode].extend(u_vals[::sample_rate]) + percentile_samples['v'][mode].extend(v_vals[::sample_rate]) + + except Exception as e: + logger.warning(f"{mode} data not found or error occurred, skipping: {e}") + continue + + logger.info(f"Processed {totals['speed'][mode]} values for {mode}") + + # -- Process predictions: compute exceedance for each ensemble member -- + logger.info("Processing predictions") + + n_members = None + member_counts = {'speed': [], 'u': [], 'v': []} + member_totals = {'speed': [], 'u': [], 'v': []} + member_samples = {'speed': [], 'u': [], 'v': []} + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Processing timestep {i+1}/{len(times)}") + + preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) # [n_members, n_channels, lat, lon] + + if n_members is None: + n_members = preds.shape[0] + for var in ['speed', 'u', 'v']: + member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)] + member_totals[var] = [0 for _ in range(n_members)] + member_samples[var] = [[] for _ in range(n_members)] + + for member_idx in range(n_members): + u = preds[member_idx, u10_out] + v = preds[member_idx, v10_out] + wind_speed = compute_wind_speed(u, v) + + valid_mask = ~np.isnan(wind_speed) + speed_vals = wind_speed[valid_mask].flatten() + u_vals = u[valid_mask].flatten() + v_vals = v[valid_mask].flatten() + + # Update counts + member_counts['speed'][member_idx], member_totals['speed'][member_idx] = update_exceedance_counts( + member_counts['speed'][member_idx], member_totals['speed'][member_idx], speed_vals, thresholds, use_abs=False + ) + member_counts['u'][member_idx], member_totals['u'][member_idx] = update_exceedance_counts( + member_counts['u'][member_idx], member_totals['u'][member_idx], u_vals, thresholds, use_abs=True + ) + member_counts['v'][member_idx], member_totals['v'][member_idx] = update_exceedance_counts( + member_counts['v'][member_idx], member_totals['v'][member_idx], v_vals, thresholds, use_abs=True + ) + + # Collect samples for percentiles + sample_rate = max(1, len(speed_vals) // 10000) + member_samples['speed'][member_idx].extend(speed_vals[::sample_rate]) + member_samples['u'][member_idx].extend(u_vals[::sample_rate]) + member_samples['v'][member_idx].extend(v_vals[::sample_rate]) + + logger.info(f"Collected {n_members} ensemble members for predictions") + + # Convert counts to probabilities + exceedance_data = {'speed': {}, 'u': {}, 'v': {}} + + for var in ['speed', 'u', 'v']: + # Single datasets + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in exceedance_counts[var] and totals[var][mode] > 0: + exceedance_data[var][mode] = exceedance_counts[var][mode] / totals[var][mode] + + # Ensemble members + member_probs = [] + for member_idx in range(n_members): + if member_totals[var][member_idx] > 0: + member_probs.append(member_counts[var][member_idx] / member_totals[var][member_idx]) + exceedance_data[var]['predictions'] = tuple(member_probs) + + # Compute percentiles for all datasets and variables + percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999} + percentiles_data = {'speed': {}, 'u': {}, 'v': {}} + + # Single datasets (target, baseline, regression-prediction) + for var in ['speed', 'u', 'v']: + use_abs = (var in ['u', 'v']) + for mode in ['target', 'baseline', 'regression-prediction']: + if mode in percentile_samples[var] and len(percentile_samples[var][mode]) > 0: + percentiles_data[var][mode] = compute_percentiles( + np.array(percentile_samples[var][mode]), percentiles, use_abs + ) + + # Ensemble members + percentiles_data[var]['predictions'] = {} + for member_idx in range(n_members): + if len(member_samples[var][member_idx]) > 0: + percentiles_data[var]['predictions'][f'member_{member_idx}'] = compute_percentiles( + np.array(member_samples[var][member_idx]), percentiles, use_abs + ) + + # Create exceedance plots + labels = ['COSMO-2 Analysis', 'ERA5', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in exceedance_data['speed'] else ['COSMO-2 Analysis', 'ERA5', 'CorrDiff Ensemble'] + colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in exceedance_data['speed'] else ['blue', 'orange', 'green'] + + # Define plot configurations + plot_configs = [ + ('windspeed_exceedance.png', 'speed', 'Probability of Exceedance for Wind Speed', + 'All-hour Wind Speed [m/s] (Pooled Data)'), + ('wind_u_exceedance.png', 'u', 'Probability of Exceedance for abs(10u)', + 'All-hour 10u Component [m/s] (Pooled Data)'), + ('wind_v_exceedance.png', 'v', 'Probability of Exceedance for abs(10v)', + 'All-hour 10v Component [m/s] (Pooled Data)'), + ] + + for filename, var, title, ylabel in plot_configs: + fn = out_root / filename + save_exceedance_plot( + exceedance_data[var], + thresholds, + labels, + colors, + title, + ylabel, + fn, + percentiles_data[var] + ) + logger.info(f"{var.capitalize()} exceedance plot saved: {fn}") + + +if __name__ == '__main__': + main() diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh index 7574b07..ec522d9 100644 --- a/src/hirad/eval_precip.sh +++ b/src/hirad/eval_precip.sh @@ -49,7 +49,6 @@ srun --environment=./ci/edf/modulus_env.toml bash -c " # Diurnal cycle python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=generate_era_cosmo.yaml python src/hirad/eval/diurnal_cycle_precip_p99.py --config-name=generate_era_cosmo.yaml - python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml # TODO: Transfer to relevant script. # Histograms python src/hirad/eval/hist.py --config-name=generate_era_cosmo.yaml diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh new file mode 100644 index 0000000..37194a6 --- /dev/null +++ b/src/hirad/eval_wind.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +#SBATCH --job-name="eval_wind" + +### HARDWARE ### +#SBATCH --partition=normal +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=72 +#SBATCH --time=6:00:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=./logs/plots_wind.log + +### ENVIRONMENT #### +#SBATCH -A a161 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# # Use SLURM_NTASKS (number of processes to be launched by torchrun) +# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# # Compute threads per process +# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +# export OMP_NUM_THREADS=$OMP_THREADS +export OMP_NUM_THREADS=72 + +srun --environment=./ci/edf/modulus_env.toml bash -c " + pip install -e . --no-dependencies + + # Diurnal cycle + python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=generate_era_cosmo.yaml + + # Maps + python src/hirad/eval/map_wind_stats.py --config-name=generate_era_cosmo.yaml + + # Probability of exceedance + python src/hirad/eval/probability_of_exceedance_wind.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file From 9783e101729fb8e0e08933e0dcac517740c83149 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 20 Oct 2025 13:10:33 +0200 Subject: [PATCH 168/189] first draft of regridding REA-L-CH1 --- src/hirad/input_data/regrid_realch1.py | 100 +++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 src/hirad/input_data/regrid_realch1.py diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py new file mode 100644 index 0000000..b19c890 --- /dev/null +++ b/src/hirad/input_data/regrid_realch1.py @@ -0,0 +1,100 @@ + + +import datetime +import logging +import os +import shutil +import sys +import yaml +import array + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.dataset import Dataset +import netCDF4 +import numpy as np +from pandas import to_datetime +from scipy.interpolate import griddata +from meteodatalab import icon_grid +from meteodatalab.operators import regrid +import torch +import multiprocessing +import xarray as xr +from meteodatalab import ogd_api + +def anemoi_to_xarray(anemoi_data: Dataset, variable): + lon = anemoi_data.longitudes + lat = anemoi_data.latitudes + eps = [0] # deterministic + time = generate_times(anemoi_data) + var_index = anemoi_data.variables.index(variable) + metadata = getMetadataFromOGD() + + ds = xr.Dataset( + data_vars=dict( + variable=(["time", "eps", "cell"], np.array(anemoi_data.data[:,var_index,:,:])), + ), + coords=dict( + eps=eps, + time=time, + lon=("cell", lon), + lat=("cell", lat), + ), + attrs=dict(description=f'xarray from anemoi dataset for {variable}', + metadata=metadata), + ) + print(ds) + return ds + +def getMetadataFromOGD(): + lead_times = ["P0DT0H"] + req = ogd_api.Request( + collection="ogd-forecasting-icon-ch1", + variable="TOT_PREC", + ref_time="latest", + perturbed=False, + lead_time=lead_times, + ) + tot_prec = ogd_api.get_from_ogd(req) + return tot_prec.metadata + +def generate_times(anemoi_data: Dataset): + times = [] + curr_time = anemoi_data.start_date.item() + while curr_time <= anemoi_data.end_date: + times.append(curr_time) + curr_time = curr_time + anemoi_data.frequency + return times + + + +def get_coeffs_path(model: str): + return coeffs_path + # TODO some value error check file avialable ofr sth. + +def remap(): + # get UUID for 1-km native grid + #icon_grid_uuid = get_uuid('icon-ch1-eps') + + coeffs_path = f'/store_new/mch/msopr/icon_workflow_2/iconremap-weights/{model}-rotlatlon.nc' + coeffs = xr.open_dataset(coeffs_path) + + indices = coeffs["rbf_B_glbidx"].values + weights = coeffs["rbf_B_wgt"].values + geo = { + "gridType": "rotated_ll", + "longitudeOfSouthernPoleInDegrees": coeffs.north_pole_lon - 180, + "latitudeOfSouthernPoleInDegrees": -1 * coeffs.north_pole_lat, + } + dst = RegularGrid( + crs=_get_crs(geo), + nx=coeffs.nx, + ny=coeffs.ny, + xmin=coeffs.xmin, + ymin=coeffs.ymin, + xmax=coeffs.xmax, + ymax=coeffs.ymax, + ) + +realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') +myxarray = anemoi_to_xarray(realch1, "TOT_PREC") +regrid.icon2rotlatlon(myxarray) From 2d0d0f8111789fb3532ae33dbc5aa265e80c12ec Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 20 Oct 2025 13:11:05 +0200 Subject: [PATCH 169/189] remove unused function --- src/hirad/input_data/regrid_realch1.py | 30 -------------------------- 1 file changed, 30 deletions(-) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index b19c890..3aabef6 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -64,36 +64,6 @@ def generate_times(anemoi_data: Dataset): times.append(curr_time) curr_time = curr_time + anemoi_data.frequency return times - - - -def get_coeffs_path(model: str): - return coeffs_path - # TODO some value error check file avialable ofr sth. - -def remap(): - # get UUID for 1-km native grid - #icon_grid_uuid = get_uuid('icon-ch1-eps') - - coeffs_path = f'/store_new/mch/msopr/icon_workflow_2/iconremap-weights/{model}-rotlatlon.nc' - coeffs = xr.open_dataset(coeffs_path) - - indices = coeffs["rbf_B_glbidx"].values - weights = coeffs["rbf_B_wgt"].values - geo = { - "gridType": "rotated_ll", - "longitudeOfSouthernPoleInDegrees": coeffs.north_pole_lon - 180, - "latitudeOfSouthernPoleInDegrees": -1 * coeffs.north_pole_lat, - } - dst = RegularGrid( - crs=_get_crs(geo), - nx=coeffs.nx, - ny=coeffs.ny, - xmin=coeffs.xmin, - ymin=coeffs.ymin, - xmax=coeffs.xmax, - ymax=coeffs.ymax, - ) realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') myxarray = anemoi_to_xarray(realch1, "TOT_PREC") From 2aaa067e7d408101cfdd32b42bf893706771f9f4 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 20 Oct 2025 20:51:28 +0200 Subject: [PATCH 170/189] add conversion to geo coords --- src/hirad/input_data/regrid_realch1.py | 72 ++++++++++++++++++++------ 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index 3aabef6..3a9770f 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -1,25 +1,18 @@ -import datetime import logging -import os -import shutil -import sys -import yaml -import array from anemoi.datasets import open_dataset from anemoi.datasets.data.dataset import Dataset -import netCDF4 import numpy as np -from pandas import to_datetime -from scipy.interpolate import griddata -from meteodatalab import icon_grid from meteodatalab.operators import regrid -import torch -import multiprocessing import xarray as xr from meteodatalab import ogd_api +from hirad.input_data.interpolate_basic import plot_projection + +import matplotlib.pyplot as plt +import cartopy.crs as ccrs +from earthkit.geo.rotate import unrotate def anemoi_to_xarray(anemoi_data: Dataset, variable): lon = anemoi_data.longitudes @@ -65,6 +58,55 @@ def generate_times(anemoi_data: Dataset): curr_time = curr_time + anemoi_data.frequency return times -realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') -myxarray = anemoi_to_xarray(realch1, "TOT_PREC") -regrid.icon2rotlatlon(myxarray) +def get_geo_coords(regridded_data: xr.Dataset): + xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees") + xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees") + dx = regridded_data.metadata.get("iDirectionIncrementInDegrees") + ymin = regridded_data.metadata.get("latitudeOfFirstGridPointInDegrees") + ymax = regridded_data.metadata.get("latitudeOfLastGridPointInDegrees") + dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") + y = np.arange(ymin,ymax+dy,dy) + x = np.arange(xmin,xmax+dx,dx) + sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") + sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") + xcoords = np.meshgrid(x,y)[0].flatten() + ycoords = np.meshgrid(x,y)[1].flatten() + geo_coords = unrotate(ycoords, xcoords, sp_lat, sp_lon) + return geo_coords + +def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None): + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + ax.coastlines() + ax.gridlines(draw_labels=False) + plt.colorbar(p, orientation="horizontal") + +def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, projection=ccrs.PlateCarree(), cmap=None, vmin = None, vmax = None): + """Plot observed or interpolated data in a scatter plot.""" + # TODO: Refactor this somehow, it's not really generalizing well across variables. + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": projection}) + logging.info(f'plotting values to {filename}') + plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax) + #p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + #ax.coastlines() + #ax.gridlines(draw_labels=True) + #plt.colorbar(p, orientation="horizontal") + plt.savefig(filename) + plt.close('all') + + +realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr', + ) +myxarray = anemoi_to_xarray(realch1, "TOT_PREC").to_dataarray() +regridded=regrid.icon2rotlatlon(myxarray) +plot_and_save_projection(realch1.longitudes, realch1.latitudes, + realch1[0,56,0,:], "anemoi.png") +plot_and_save_projection(myxarray.lon, myxarray.lat, + myxarray[0,0,0,:], "xarray.png") +# South pole rotation of lon=10, latitude=-43 +#rotated_crs = ccrs.RotatedPole( +# pole_longitude=190, pole_latitude=43 +#) +geo_coords = get_geo_coords(regridded) +plot_and_save_projection(geo_coords[1], geo_coords[0], + regridded[0,0,0,:], "regridded.png") \ No newline at end of file From 608b988165a54bceba1ec3acec354d0afd13effb Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 21 Oct 2025 17:15:44 +0200 Subject: [PATCH 171/189] use plotting function from interpolate_baisic --- src/hirad/input_data/interpolate_basic.py | 9 ++++---- src/hirad/input_data/regrid_realch1.py | 27 +++-------------------- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py index e5953a1..604d5e4 100644 --- a/src/hirad/input_data/interpolate_basic.py +++ b/src/hirad/input_data/interpolate_basic.py @@ -15,6 +15,7 @@ import torch import multiprocessing import xarray +from earthkit.geo.rotate import unrotate # Margin to use for ERA dataset (to avoid nans from interpolation at boundary) ERA_MARGIN_DEGREES = 1.0 @@ -76,11 +77,11 @@ def _interpolate_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarr logging.info(f'plotting {datestr} to {outfile_plots_path}') for j,var in enumerate(era.variables): # plot era original - _plot_and_save_projection(era.longitudes, era.latitudes, era[i, j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era.jpg') + plot_and_save_projection(era.longitudes, era.latitudes, era[i, j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era.jpg') - _plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, interpolated_data[j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era-interpolated.jpg') + plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, interpolated_data[j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era-interpolated.jpg') for j,var in enumerate(cosmo.variables): - _plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, cosmo[i, j, 0, :], f'{outfile_plots_path}{cosmo.variables[j]}-{datestr}-cosmo.jpg') + plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, cosmo[i, j, 0, :], f'{outfile_plots_path}{cosmo.variables[j]}-{datestr}-cosmo.jpg') @@ -163,7 +164,7 @@ def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.ar ax.gridlines(draw_labels=False) plt.colorbar(p, orientation="horizontal") -def _plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): +def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): """Plot observed or interpolated data in a scatter plot.""" # TODO: Refactor this somehow, it's not really generalizing well across variables. fig = plt.figure() diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index 3a9770f..e81d12b 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -8,7 +8,7 @@ from meteodatalab.operators import regrid import xarray as xr from meteodatalab import ogd_api -from hirad.input_data.interpolate_basic import plot_projection +from hirad.input_data.interpolate_basic import plot_and_save_projection import matplotlib.pyplot as plt import cartopy.crs as ccrs @@ -67,6 +67,7 @@ def get_geo_coords(regridded_data: xr.Dataset): dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") y = np.arange(ymin,ymax+dy,dy) x = np.arange(xmin,xmax+dx,dx) + # TODO, this parameter is not producing what I want it to. sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") xcoords = np.meshgrid(x,y)[0].flatten() @@ -74,29 +75,7 @@ def get_geo_coords(regridded_data: xr.Dataset): geo_coords = unrotate(ycoords, xcoords, sp_lat, sp_lon) return geo_coords -def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None): - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - ax.coastlines() - ax.gridlines(draw_labels=False) - plt.colorbar(p, orientation="horizontal") - -def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, projection=ccrs.PlateCarree(), cmap=None, vmin = None, vmax = None): - """Plot observed or interpolated data in a scatter plot.""" - # TODO: Refactor this somehow, it's not really generalizing well across variables. - fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": projection}) - logging.info(f'plotting values to {filename}') - plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax) - #p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - #ax.coastlines() - #ax.gridlines(draw_labels=True) - #plt.colorbar(p, orientation="horizontal") - plt.savefig(filename) - plt.close('all') - - -realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr', - ) +realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') myxarray = anemoi_to_xarray(realch1, "TOT_PREC").to_dataarray() regridded=regrid.icon2rotlatlon(myxarray) plot_and_save_projection(realch1.longitudes, realch1.latitudes, From 705e652193801167ac4f9c068e12c2a9ec333e0c Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 21 Oct 2025 17:16:08 +0200 Subject: [PATCH 172/189] pip install meteodata-lab --- ci/docker/Dockerfile.corrdiff | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/docker/Dockerfile.corrdiff b/ci/docker/Dockerfile.corrdiff index 3187c84..f4be071 100644 --- a/ci/docker/Dockerfile.corrdiff +++ b/ci/docker/Dockerfile.corrdiff @@ -10,4 +10,5 @@ RUN pip install \ Cartopy==0.22.0 \ xskillscore \ scoringrules \ - mlflow + mlflow \ + meteodata-lab From 3a54d18e4910686586d971ea64615d6d1510856b Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 21 Oct 2025 17:16:58 +0200 Subject: [PATCH 173/189] config file for REA-L-CH1 --- src/hirad/input_data/realch1.yaml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 src/hirad/input_data/realch1.yaml diff --git a/src/hirad/input_data/realch1.yaml b/src/hirad/input_data/realch1.yaml new file mode 100644 index 0000000..6b0f776 --- /dev/null +++ b/src/hirad/input_data/realch1.yaml @@ -0,0 +1,26 @@ +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr' +select: ['CLCH', 'CLCL', 'CLCM', 'CLCT', + 'FI_100', 'FI_1000', 'FI_150', 'FI_200', 'FI_250', 'FI_300', 'FI_400', + 'FI_50', 'FI_500', 'FI_600', 'FI_700', 'FI_850', 'FI_925', + 'FR_LAND', 'HSURF', + 'OMEGA_100', 'OMEGA_1000', 'OMEGA_150', 'OMEGA_200', 'OMEGA_250', + 'OMEGA_300', 'OMEGA_400', 'OMEGA_50', 'OMEGA_500', 'OMEGA_600', + 'OMEGA_700', 'OMEGA_850', 'OMEGA_925', + 'PLCOV', 'PMSL', 'PS', + 'QV_100', 'QV_1000', 'QV_150', 'QV_200', 'QV_250', 'QV_300', 'QV_400', + 'QV_50', 'QV_500', 'QV_600', 'QV_700', 'QV_850', 'QV_925', + 'SKC', 'SKT', 'SOILTYP', 'SSO_GAMMA', 'SSO_SIGMA', 'SSO_STDH', + 'SSO_THETA', 'TD_2M', 'TOT_PREC', 'TOT_PREC_6H', + 'T_100', 'T_1000', 'T_150', 'T_200', 'T_250', 'T_2M', 'T_300', 'T_400', + 'T_50', 'T_500', 'T_600', 'T_700', 'T_850', 'T_925', + 'U_100', 'U_1000', 'U_10M', 'U_150', 'U_200', 'U_250', 'U_300', 'U_400', + 'U_50', 'U_500', 'U_600', 'U_700', 'U_850', 'U_925', + 'V_100', 'V_1000', 'V_10M', 'V_150', 'V_200', 'V_250', 'V_300', 'V_400', + 'V_50', 'V_500', 'V_600', 'V_700', 'V_850', 'V_925', + 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', + 'insolation', 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude'] +# ALL REALCH1 CHANNELS. +start: 2020-01-01 +# start: 2015-11-29 +end: 2020-01-01 +# end: 2020-12-31 From 7e06e91055329faccf2523a3a37dbeeb2943e778 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 21 Oct 2025 17:18:33 +0200 Subject: [PATCH 174/189] skeleton for realch1 interpolatoin task --- src/hirad/input_data/interpolate_realch1.py | 80 +++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 src/hirad/input_data/interpolate_realch1.py diff --git a/src/hirad/input_data/interpolate_realch1.py b/src/hirad/input_data/interpolate_realch1.py new file mode 100644 index 0000000..1334865 --- /dev/null +++ b/src/hirad/input_data/interpolate_realch1.py @@ -0,0 +1,80 @@ + + +import datetime +import logging +import os +import shutil +import sys +import yaml +import array + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.dataset import Dataset +import netCDF4 +import numpy as np +from pandas import to_datetime +from scipy.interpolate import griddata +from meteodatalab.operators import regrid +import torch +import multiprocessing +import xarray + +# Margin to use for ERA dataset (to avoid nans from interpolation at boundary) +ERA_MARGIN_DEGREES = 1.0 +COPERNICUS_FILES = ['/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2015-2016.nc', + '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2017-2018.nc', + '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2019-2020.nc'] + +def _read_input(era_config_file: str, realch1_config_file: str, ) -> tuple[Dataset, Dataset, array.array]: + """ + Read both ERA and REA-L-CH1 data, and return the 2m + temperature values for the time range under COSMO. + """ + # trim edge removes boundary, we will use the same + with open(realch1_config_file) as realch1_file: + realch1_config = yaml.safe_load(realch1_file) + realch1 = open_dataset(realch1_config) + with open(era_config_file) as era_file: + era_config = yaml.safe_load(era_file) + era = open_dataset(era_config) + # Subset the ERA dataset to have REAL-CH-1 area/dates. + start_date = realch1.metadata()['start_date'] + end_date = realch1.metadata()['end_date'] + # load era5 2m-temperature in the time-range of cosmo + # area = N, W, S, E + min_lat = min(realch1.latitudes) - ERA_MARGIN_DEGREES + max_lat = max(realch1.latitudes) + ERA_MARGIN_DEGREES + min_lon = min(realch1.longitudes) - ERA_MARGIN_DEGREES + max_lon = max(realch1.longitudes) + ERA_MARGIN_DEGREES + era = open_dataset(era, start=start_date, end=end_date, + area=(max_lat, min_lon, min_lat, max_lon)) + + + copernicus_netcdf = [] + for f in COPERNICUS_FILES: + netcdf_data = netCDF4.Dataset(f) + copernicus_netcdf.append(netcdf_data) + + return (era, realch1, copernicus_netcdf) + + +def regrid_all(era: Dataset, realch1: Dataset, copernicus: array.array): + # iterate through the dates + realch1 = regrid_realch1 + # convert to xarray.dataarray + regrid.icon2rotlatlon + + pass + +def regrid_realch1(): + # Use the meteodatalab functions to regrid the realch1 anemoi data (one time point) + # onto the rotated lat lon + # save the output as torch + # return the np array + pass + +def regrid_era(): + # Take the output grid from realch1-regrid (rotated lat lon). + # regrid all variables *except* tp directly from era5 data + # regrid the pt variable from the netcdf data + # save the output as torch \ No newline at end of file From 1829e93486bab53896fa0ff664d67819a3d3418e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 22 Oct 2025 09:51:43 +0200 Subject: [PATCH 175/189] Enhance generation process with randomized sampler over time. --- src/hirad/conf/generation/era_cosmo.yaml | 8 +- src/hirad/conf/model_size/normal.yaml | 2 +- src/hirad/inference/generate.py | 9 +- src/hirad/inference/generator.py | 14 +- src/hirad/utils/deterministic_sampler.py | 341 ----------------------- src/hirad/utils/stochastic_sampler.py | 277 ------------------ 6 files changed, 27 insertions(+), 624 deletions(-) delete mode 100644 src/hirad/utils/deterministic_sampler.py delete mode 100644 src/hirad/utils/stochastic_sampler.py diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 3396d43..a5302ff 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -6,7 +6,13 @@ inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y -# overlap_pixels: 0 +randomize: True + # Whether to randomize the random seeds for each generation. If false, fixed seeds + # from 0 to num_ensembles-1 will be used for each time step in times/times_range. +random_seed: 129 + # Base random seed. This is only used when randomize is True. + # random seed will be set for numpy random module to have reproducible randomized generative process. + # Number of overlapping pixels between adjacent patches # boundary_pixels: 0 # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml index 96c29fb..e746a2c 100644 --- a/src/hirad/conf/model_size/normal.yaml +++ b/src/hirad/conf/model_size/normal.yaml @@ -24,4 +24,4 @@ model_args: # Per-resolution multipliers for the number of channels. channel_mult: [1, 2, 2, 2, 2] # Resolutions at which self-attention layers are applied. - attn_resolutions: [28] \ No newline at end of file + attn_resolutions: [22] \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 3790820..f90f974 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -239,7 +239,14 @@ def elapsed_time(self, _): ) image_tar = image_tar.to(device=device).to(torch.float32) # image_out, image_reg = generate_fn(image_lr,lead_time_label) - image_out, image_reg = generator.generate(image_lr,lead_time_label) + random_seed = cfg.generation.get("random_seed", None)+index if cfg.generation.get("randomize", False) and cfg.generation.get("random_seed", None) is not None else None + # print(f"On rank {dist.rank} using base random seed: {random_seed} for time index {time_index}") + image_out, image_reg = generator.generate( + image_lr, + lead_time_label, + randomize=cfg.generation.get("randomize", False), + random_seed=random_seed + ) if dist.rank == 0: batch_size = image_out.shape[0] diff --git a/src/hirad/inference/generator.py b/src/hirad/inference/generator.py index ad051a9..d3c95c5 100644 --- a/src/hirad/inference/generator.py +++ b/src/hirad/inference/generator.py @@ -2,6 +2,7 @@ from functools import partial import nvtx import numpy as np +import random import torch from torch.distributed import gather from hirad.utils.inference_utils import regression_step, diffusion_step @@ -32,8 +33,9 @@ def __init__(self, self.get_rank_batches() self.patching = None - def get_rank_batches(self): - seeds = list(np.arange(self.ensemble_size)) + def get_rank_batches(self, seeds=None): + if seeds is None: + seeds = list(np.arange(self.ensemble_size)) num_batches = ( (len(seeds) - 1) // (self.batch_size * self.dist.world_size) + 1 ) * self.dist.world_size @@ -63,7 +65,7 @@ def initialize_patching(self, img_shape, patch_shape, boundary_pix, overlap_pix) overlap_pix=overlap_pix, ) - def generate(self, image_lr, lead_time_label=None): + def generate(self, image_lr, lead_time_label=None, randomize=False, random_seed=None): with nvtx.annotate("generate_fn", color="green"): # (1, C, H, W) img_shape = image_lr.shape[-2:] @@ -86,6 +88,12 @@ def generate(self, image_lr, lead_time_label=None): mean_hr = image_reg[0:1] else: mean_hr = None + if randomize: + # Set random seed for numpy + if random_seed is not None: + np.random.seed((random_seed) % (1 << 31)) + seeds = np.random.randint(0, 1<<31, size=self.ensemble_size) + self.get_rank_batches(seeds=seeds) with nvtx.annotate("diffusion model", color="purple"): image_res = diffusion_step( net=self.net_res, diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py deleted file mode 100644 index e502875..0000000 --- a/src/hirad/utils/deterministic_sampler.py +++ /dev/null @@ -1,341 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Literal, Optional - -import numpy as np -import nvtx -import torch - -from hirad.models import EDMPrecond - -# ruff: noqa: E731 - - -@nvtx.annotate(message="deterministic_sampler", color="red") -def deterministic_sampler( - net: torch.nn.Module, - latents: torch.Tensor, - img_lr: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - randn_like: Callable = torch.randn_like, - num_steps: int = 18, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, - rho: float = 7.0, - solver: Literal["heun", "euler"] = "heun", - discretization: Literal["vp", "ve", "iddpm", "edm"] = "edm", - schedule: Literal["vp", "ve", "linear"] = "linear", - scaling: Literal["vp", "none"] = "none", - epsilon_s: float = 1e-3, - C_1: float = 0.001, - C_2: float = 0.008, - M: int = 1000, - alpha: float = 1.0, - S_churn: int = 0, - S_min: float = 0.0, - S_max: float = float("inf"), - S_noise: float = 1.0, -) -> torch.Tensor: - """ - Generalized sampler, representing the superset of all sampling methods - discussed in the paper "Elucidating the Design Space of Diffusion-Based - Generative Models" (EDM). - - https://arxiv.org/abs/2206.00364 - - This function integrates an ODE (probability flow) or SDE over multiple - time-steps to generate samples from the diffusion model provided by the - argument 'net'. It can be used to combine multiple choices to - design a custom sampler, including multiple integration solver, - discretization method, noise schedule, and so on. - - Parameters: - ----------- - net : torch.nn.Module - The diffusion model to use in the sampling process. - latents : torch.Tensor - The latent random noise used as the initial condition for the - stochastic ODE. - img_lr : torch.Tensor - Low-resolution input image for conditioning the diffusion process. - Passed as a keywork argument to the model 'net'. - class_labels : Optional[torch.Tensor] - Labels of the classes used as input to a class-conditionned - diffusion model. Passed as a keyword argument to the model 'net'. - If provided, it must be a tensor containing integer values. - Defaults to None, in which case it is ignored. - randn_like: Callable - Random Number Generator to generate random noise that is added - during the stochastic sampling. Must have the same signature as - torch.randn_like and return torch.Tensor. Defaults to - torch.randn_like. - num_steps : Optional[int] - Number of time-steps for the stochastic ODE integration. Defaults - to 18. - sigma_min : Optional[float] - Minimum noise level for the diffusion process. 'sigma_min', - 'sigma_max', and 'rho' are used to compute the time-step - discretization, based on the choice of discretization. For the - default choice ("discretization='heun'"), the noise level schedule - is computed as: - :math:`\sigma_i = (\sigma_{max}^{1/\rho} + i / (num_steps - 1) * (\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}))^{rho}`. - For other choices of 'discretization', see details in the EDM - paper. Defaults to None, in which case defaults values depending - of the specified discretization are used. - sigma_max : Optional[float] - Maximum noise level for the diffusion process. See sigma_min for - details. Defaults to None, in which case defaults values depending - of the specified discretization are used. - rho : float, optional - Exponent used in the noise schedule. See sigma_min for details. - Only used when 'discretization' is 'heun'. Values in the range [5, - 10] produce better images. Lower values lead to truncation errors - equalized over all time steps. Defaults to 7. - solver : Literal["heun", "euler"] - The numerical method used to integrate the stochastic ODE. "euler" - is 1st order solver, which is faster but produces lower-quality - images. "heun" is 2nd order, more expensive, but produces - higher-quality images. Defaults to "heun". - discretization : Literal["vp", "ve", "iddpm", "edm"] - The method to discretize time-steps :math:`t_i` in the - diffusion process. See the EDM papper for details. Defaults to - "edm". - schedule : Literal["vp", "ve", "linear"] - The type of noise level schedule. Defaults to "linear". If - schedule='ve', then :math:`\sigma(t) = \sqrt{t}`. If - schedule='linear', then :math:`\sigma(t) = t`. If schedule='vp', - see EDM paper for details. Defaults to "linear". - scaling : Literal["vp", "none"] - The type of time-dependent signal scaling :math:`s(t)`, such that - :math:`x = s(t) \hat{x}`. See EDM paper for details on the 'vp' - scaling. Defaults to 'none', in which case :math:`s(t)=1`. - epsilon_s : float, optional - Parameter to compute both the noise level schedule and the - time-step discetization. Only used when discretization='vp' or - schedule='vp'. Ignored in other cases. Defaults to 1e-3. - C_1 : float, optional - Parameters to compute the time-step discetization. Only used when - discretization='iddpm'. Defaults to 0.001. - C_2 : float, optional - Same as for C_1. Only used when discretization='iddpm'. Defaults to - 0.008. - M : int, optional - Same as for C_1 and C_2. Only used when discretization='iddpm'. - Defaults to 1000. - alpha : float, optional - Controls (i.e. multiplies) the step size :math:`t_{i+1} - - \hat{t}_i` in the stochastic sampler, where :math:`\hat{t}_i` is - the temporarily increased noise level. Defaults to 1.0, which is - the recommended value. - S_churn : int, optional - Controls the amount of stochasticty injected in the SDE in the - stochatsic sampler. Larger values of S_churn lead to larger values - of :math:`\hat{t}_i`, which in turn lead to injecting more - stochasticity in the SDE by Defaults to 0, which means no - stochasticity is injected. - S_min : float, optional - S_min and S_max control the time-step range obver which - stochasticty is injected in the SDE. Stochasticity is injected - through `\hat{t}_i` for time-steps :math:`t_i` such that - :math:`S_{min} \leq t_i \leq S_{max}`. Defaults to 0.0. - S_max : float, optional - See S_min. Defaults to float("inf"). - S_noise : float, optional - Controls the amount of stochasticty injected in the SDE in the - stochatsic sampler. Added signal noise is proportinal to - :math:`\epsilon_i` where `\epsilon_i ~ N(0, S_{noise}^2)`. Defaults - to 1.0. - - Returns - ------- - torch.Tensor: - Generated batch of samples. Same shape as the input 'latents'. - """ - - # conditioning - x_lr = img_lr - - if solver not in ["euler", "heun"]: - raise ValueError(f"Unknown solver {solver}") - if discretization not in ["vp", "ve", "iddpm", "edm"]: - raise ValueError(f"Unknown discretization {discretization}") - if schedule not in ["vp", "ve", "linear"]: - raise ValueError(f"Unknown schedule {schedule}") - if scaling not in ["vp", "none"]: - raise ValueError(f"Unknown scaling {scaling}") - - # Helper functions for VP & VE noise level schedules. - vp_sigma = ( - lambda beta_d, beta_min: lambda t: ( - np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1 - ) - ** 0.5 - ) - vp_sigma_deriv = ( - lambda beta_d, beta_min: lambda t: 0.5 - * (beta_min + beta_d * t) - * (sigma(t) + 1 / sigma(t)) - ) - vp_sigma_inv = ( - lambda beta_d, beta_min: lambda sigma: ( - (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min - ) - / beta_d - ) - ve_sigma = lambda t: t.sqrt() - ve_sigma_deriv = lambda t: 0.5 / t.sqrt() - ve_sigma_inv = lambda sigma: sigma**2 - - # Select default noise level range based on the specified time step discretization. - if sigma_min is None: - vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) - sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ - discretization - ] - if sigma_max is None: - vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) - sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] - - # Adjust noise levels based on what's supported by the network. - sigma_min = max(sigma_min, net.sigma_min) - sigma_max = min(sigma_max, net.sigma_max) - - # Compute corresponding betas for VP. - vp_beta_d = ( - 2 - * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) - / (epsilon_s - 1) - ) - vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d - - # Define time steps in terms of noise level. - step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) - if discretization == "vp": - orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) - sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) - elif discretization == "ve": - orig_t_steps = (sigma_max**2) * ( - (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) - ) - sigma_steps = ve_sigma(orig_t_steps) - elif discretization == "iddpm": - u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) - alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 - for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 - u[j - 1] = ( - (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 - ).sqrt() - u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] - sigma_steps = u_filtered[ - ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) - .round() - .to(torch.int64) - ] - else: - sigma_steps = ( - sigma_max ** (1 / rho) - + step_indices - / (num_steps - 1) - * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) - ) ** rho - - # Define noise level schedule. - if schedule == "vp": - sigma = vp_sigma(vp_beta_d, vp_beta_min) - sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) - sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) - elif schedule == "ve": - sigma = ve_sigma - sigma_deriv = ve_sigma_deriv - sigma_inv = ve_sigma_inv - else: - sigma = lambda t: t - sigma_deriv = lambda t: 1 - sigma_inv = lambda sigma: sigma - - # Define scaling schedule. - if scaling == "vp": - s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() - s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) - else: - s = lambda t: 1 - s_deriv = lambda t: 0 - - # Compute final time steps based on the corresponding noise levels. - t_steps = sigma_inv(net.round_sigma(sigma_steps)) - t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 - - # Main sampling loop. - t_next = t_steps[0] - x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 - x_cur = x_next - - # Increase noise temporarily. - gamma = ( - min(S_churn / num_steps, np.sqrt(2) - 1) - if S_min <= sigma(t_cur) <= S_max - else 0 - ) - t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) - x_hat = s(t_hat) / s(t_cur) * x_cur + ( - sigma(t_hat) ** 2 - sigma(t_cur) ** 2 - ).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) - - # Euler step. - h = t_next - t_hat - if isinstance(net, EDMPrecond): - # Conditioning info is passed as keyword arg - denoised = net( - x_hat / s(t_hat), - sigma(t_hat), - condition=x_lr, - class_labels=class_labels, - ).to(torch.float64) - else: - denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to( - torch.float64 - ) - d_cur = ( - sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) - ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised - x_prime = x_hat + alpha * h * d_cur - t_prime = t_hat + alpha * h - - # Apply 2nd order correction. - if solver == "euler" or i == num_steps - 1: - x_next = x_hat + h * d_cur - else: - if isinstance(net, EDMPrecond): - # Conditioning info is passed as keyword arg - denoised = net( - x_prime / s(t_prime), - sigma(t_prime), - condition=x_lr, - class_labels=class_labels, - ).to(torch.float64) - else: - denoised = net( - x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels - ).to(torch.float64) - d_prime = ( - sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) - ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised - x_next = x_hat + h * ( - (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime - ) - - return x_next diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py deleted file mode 100644 index 198fde4..0000000 --- a/src/hirad/utils/stochastic_sampler.py +++ /dev/null @@ -1,277 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Callable, Optional - -import torch -from torch import Tensor - -from hirad.utils.patching import GridPatching2D - - -def stochastic_sampler( - net: torch.nn.Module, - latents: torch.Tensor, - img_lr: torch.Tensor, - class_labels: Optional[Tensor] = None, - randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - patching: Optional[GridPatching2D] = None, - mean_hr: Optional[torch.Tensor] = None, - lead_time_label: Optional[torch.Tensor] = None, - num_steps: int = 18, - sigma_min: float = 0.002, - sigma_max: float = 800, - rho: float = 7, - S_churn: float = 0, - S_min: float = 0, - S_max: float = float("inf"), - S_noise: float = 1, -) -> torch.Tensor: - """ - Proposed EDM sampler (Algorithm 2) with minor changes to enable - super-resolution and patch-based diffusion. - - Parameters - ---------- - net : torch.nn.Module - The neural network model that generates denoised images from noisy - inputs. - Expected signature: `net(x, x_lr, t_hat, class_labels, - lead_time_label=lead_time_label, embedding_selector=embedding_selector)`, - where: - x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) - x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) - t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar - class_labels (torch.Tensor, optional): Optional class labels - lead_time_label (torch.Tensor, optional): Optional lead time labels - embedding_selector (callable, optional): Function to select - positional embeddings. Used for patch-based diffusion. - Returns: - torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) - - Required attributes: - sigma_min (float): Minimum supported noise level for the model - sigma_max (float): Maximum supported noise level for the model - round_sigma (callable): Method to convert sigma values to tensor representation - latents : Tensor - The latent variables (e.g., noise) used as the initial input for the - sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). - img_lr : Tensor - Low-resolution input image for conditioning the super-resolution - process. Must have shape (batch_size, C_lr, img_lr_ shape_y, - img_lr_shape_x). - class_labels : Optional[Tensor], optional - Class labels for conditional generation, if required by the model. By - default None. - randn_like : Callable[[Tensor], Tensor] - Function to generate random noise with the same shape as the input - tensor. - By default torch.randn_like. - patching : Optional[GridPatching2D], optional - A patching utility for patch-based diffusion. Implements methods to - extract patches from an image and batch the patches along `dim=0`. - Should also implement a `fuse` method to reconstruct the original image - from a batch of patches. See - :class:`physicsnemo.utils.patching.GridPatching2D` for details. By - default None, in which case non-patched diffusion is used. - mean_hr : Optional[Tensor], optional - Optional tensor containing mean high-resolution images for - conditioning. Must have same height and width as `img_lr`, with shape - (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x) where the batch dimension - B_hr can be either 1, either equal to batch_size, or can be omitted. If - B_hr = 1 or is omitted, `mean_hr` will be expanded to match the shape - of `img_lr`. By default None. - lead_time_label : Optional[Tensor], optional - Optional lead time labels. By default None. - num_steps : int - Number of time steps for the sampler. By default 18. - sigma_min : float - Minimum noise level. By default 0.002. - sigma_max : float - Maximum noise level. By default 800. - rho : float - Exponent used in the time step discretization. By default 7. - S_churn : float - Churn parameter controlling the level of noise added in each step. By - default 0. - S_min : float - Minimum time step for applying churn. By default 0. - S_max : float - Maximum time step for applying churn. By default float("inf"). - S_noise : float - Noise scaling factor applied during the churn step. By default 1. - - Returns - ------- - Tensor - The final denoised image produced by the sampler. Same shape as - `latents`: (batch_size, C_out, img_shape_y, img_shape_x). - - See Also - -------- - :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model - wrapper that provides preconditioning for super-resolution diffusion - models and implements the required interface for this sampler. - """ - - # Adjust noise levels based on what's supported by the network. - # Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution. - sigma_min = max(sigma_min, net.sigma_min) - sigma_max = min(sigma_max, net.sigma_max) - - if patching is not None and not isinstance(patching, GridPatching2D): - raise ValueError("patching must be an instance of GridPatching2D.") - - # Safety check: if patching is used then img_lr and latents must have same - # height and width, otherwise there is mismatch in the number - # of patches extracted to form the final batch_size. - if patching: - if img_lr.shape[-2:] != latents.shape[-2:]: - raise ValueError( - f"img_lr and latents must have the same height and width, " - f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. " - ) - # img_lr and latents must also have the same batch_size, otherwise mismatch - # when processed by the network - if img_lr.shape[0] != latents.shape[0]: - raise ValueError( - f"img_lr and latents must have the same batch size, but found " - f"{img_lr.shape[0]} vs {latents.shape[0]}." - ) - - # Time step discretization. - step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) - t_steps = ( - sigma_max ** (1 / rho) - + step_indices - / (num_steps - 1) - * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) - ) ** rho - t_steps = torch.cat( - [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] - ) # t_N = 0 - - batch_size = img_lr.shape[0] - - # conditioning = [mean_hr, img_lr, global_lr, pos_embd] - x_lr = img_lr - if mean_hr is not None: - if mean_hr.shape[-2:] != img_lr.shape[-2:]: - raise ValueError( - f"mean_hr and img_lr must have the same height and width, " - f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}." - ) - x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) - - # input and position padding + patching - if patching: - # Patched conditioning [x_lr, mean_hr] - # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) - x_lr = patching.apply(input=x_lr, additional_input=img_lr) - - # Function to select the correct positional embedding for each patch - def patch_embedding_selector(emb): - # emb: (N_pe, image_shape_y, image_shape_x) - # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) - return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) - - else: - patch_embedding_selector = None - - # Main sampling loop. - x_next = latents.to(torch.float64) * t_steps[0] - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 - x_cur = x_next - # Increase noise temporarily. - gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 - t_hat = net.round_sigma(t_cur + gamma * t_cur) - - x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) - - # Euler step. Perform patching operation on score tensor if patch-based - # generation is used denoised = net(x_hat, t_hat, - # class_labels,lead_time_label=lead_time_label).to(torch.float64) - - x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to( - latents.device - ) - x_lr = x_lr.to(latents.device) - - if lead_time_label is not None: - denoised = net( - x_hat_batch, - x_lr, - t_hat, - class_labels, - lead_time_label=lead_time_label, - embedding_selector=patch_embedding_selector, - ).to(torch.float64) - else: - # print("Sizes") - # print(x_hat_batch.shape) - # print(x_lr.shape) - # print(t_hat) - # print(class_labels) - # print(global_index) - denoised = net( - x_hat_batch, - x_lr, - t_hat, - class_labels, - embedding_selector=patch_embedding_selector, - ).to(torch.float64) - if patching: - # Un-patch the denoised image - # (batch_size, C_out, img_shape_y, img_shape_x) - denoised = patching.fuse(input=denoised, batch_size=batch_size) - - d_cur = (x_hat - denoised) / t_hat - x_next = x_hat + (t_next - t_hat) * d_cur - - # Apply 2nd order correction. - if i < num_steps - 1: - # Patched input - # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x) - x_next_batch = (patching.apply(input=x_next) if patching else x_next).to( - latents.device - ) - - if lead_time_label is not None: - denoised = net( - x_next_batch, - x_lr, - t_next, - class_labels, - lead_time_label=lead_time_label, - embedding_selector=patch_embedding_selector, - ).to(torch.float64) - else: - denoised = net( - x_next_batch, - x_lr, - t_next, - class_labels, - embedding_selector=patch_embedding_selector, - ).to(torch.float64) - if patching: - # Un-patch the denoised image - # (batch_size, C_out, img_shape_y, img_shape_x) - denoised = patching.fuse(input=denoised, batch_size=batch_size) - - d_prime = (x_next - denoised) / t_next - x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) - return x_next From ae6d8ffb836e0878eb84eb26e668026300fdf6e5 Mon Sep 17 00:00:00 2001 From: David Leutwyler <14977216+leuty@users.noreply.github.com> Date: Mon, 27 Oct 2025 11:53:47 +0100 Subject: [PATCH 176/189] Wind eval streaming (#24) --- src/hirad/eval/map_wind_stats.py | 454 ++++++++++++++++--------------- src/hirad/eval_wind.sh | 2 +- 2 files changed, 242 insertions(+), 214 deletions(-) diff --git a/src/hirad/eval/map_wind_stats.py b/src/hirad/eval/map_wind_stats.py index ee6855a..ac63224 100644 --- a/src/hirad/eval/map_wind_stats.py +++ b/src/hirad/eval/map_wind_stats.py @@ -1,12 +1,10 @@ import logging -from datetime import datetime from pathlib import Path import hydra import numpy as np import torch from omegaconf import DictConfig, OmegaConf -import xarray as xr from hirad.datasets import get_dataset_and_sampler_inference from hirad.distributed import DistributedManager @@ -15,150 +13,123 @@ def compute_wind_speed(u, v): - """Compute wind speed from U and V components.""" + """Compute wind speed from U and V.""" return np.hypot(u, v) -def compute_wind_direction(u, v): - """Compute wind direction in degrees (meteorological convention: direction FROM which wind blows). - - Returns angle in degrees: 0° = North, 90° = East, 180° = South, 270° = West - """ - # atan2(u, v) gives mathematical angle, convert to meteorological - direction = np.rad2deg(np.arctan2(u, v)) + 180 - direction = np.mod(direction, 360) - return direction - - -def circular_mean_direction(directions, weights=None): - """Calculate circular mean of wind directions in degrees. - - Args: - directions: Array of directions in degrees - weights: Optional weights (e.g., wind speeds) - """ - # Convert to radians - rad = np.deg2rad(directions) - - # Calculate weighted mean of sin and cos components - if weights is not None: - sin_mean = np.average(np.sin(rad), weights=weights, axis=0) - cos_mean = np.average(np.cos(rad), weights=weights, axis=0) - else: - sin_mean = np.mean(np.sin(rad), axis=0) - cos_mean = np.mean(np.cos(rad), axis=0) - - # Calculate mean direction - mean_dir = np.arctan2(sin_mean, cos_mean) - mean_dir_deg = np.rad2deg(mean_dir) - mean_dir_deg = np.mod(mean_dir_deg, 360) - - return mean_dir_deg - - -def circular_std(directions): - """Calculate circular standard deviation of wind directions in degrees. - - Returns values from 0 (perfect alignment) to ~81.03 degrees (uniform distribution) - """ - rad = np.deg2rad(directions) - - # Calculate mean resultant length - sin_mean = np.mean(np.sin(rad), axis=0) - cos_mean = np.mean(np.cos(rad), axis=0) - R = np.hypot(sin_mean, cos_mean) - - # Circular standard deviation - # Handle R=0 case to avoid log(0) - R = np.clip(R, 1e-10, 1.0) - circ_std = np.rad2deg(np.sqrt(-2 * np.log(R))) - - return circ_std +def compute_wind_direction(u, v, calm_threshold=0.0): + """Compute wind direction in degrees from N.""" + dir_deg = (np.degrees(np.arctan2(-u, -v)) % 360) + if calm_threshold > 0: + speed = np.hypot(u, v) + dir_deg = np.where(speed <= calm_threshold, np.nan, dir_deg) + return dir_deg -def apply_wind_statistic(u_data, v_data, stat_type, stat_param=None): - """Apply a wind statistic to U and V components along the time dimension. +def apply_wind_statistic_streaming(times, out_root, mode, u_channel, v_channel, stat_type, stat_param=None): + """Compute wind statistic by streaming through timesteps.""" + accumulator = None + count = 0 + sin_acc = cos_acc = speed_acc = None - Args: - u_data: xarray DataArray of U wind component - v_data: xarray DataArray of V wind component - stat_type: Type of statistic to compute - stat_param: Optional parameter for the statistic (e.g., quantile value) - - Returns: - Result as numpy array - """ - # Compute wind speed - speed = compute_wind_speed(u_data.values, v_data.values) - - if stat_type == 'mean_speed': - return np.mean(speed, axis=0) - - if stat_type == 'quantile_speed': - return np.quantile(speed, stat_param, axis=0) - - if stat_type == 'max_speed': - return np.max(speed, axis=0) - - if stat_type == 'wind_power': - # Wind power density is proportional to cube of wind speed - return np.mean(speed**3, axis=0) - - if stat_type == 'calm_freq': - # Frequency of calm conditions (< 2 m/s, Beaufort 0-1) - calm_threshold = 2.0 - return np.mean(speed < calm_threshold, axis=0) * 100 - - if stat_type == 'light_breeze_freq': - # Frequency of light breeze (> 1.6 m/s, Beaufort 2+) - light_breeze_threshold = 1.6 - return np.mean(speed > light_breeze_threshold, axis=0) * 100 - - if stat_type == 'moderate_breeze_freq': - # Frequency of moderate breeze (> 5.5 m/s, Beaufort 4+) - moderate_breeze_threshold = 5.5 - return np.mean(speed > moderate_breeze_threshold, axis=0) * 100 - - if stat_type == 'strong_breeze_freq': - # Frequency of strong breeze (> 10.8 m/s, Beaufort 6+) - strong_breeze_threshold = 10.8 - return np.mean(speed > strong_breeze_threshold, axis=0) * 100 - - if stat_type == 'gale_freq': - # Frequency of fresh gale (> 17.2 m/s, Beaufort 8+) - gale_threshold = 17.2 - return np.mean(speed > gale_threshold, axis=0) * 100 + for ts in times: + data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) + u = data[u_channel].cpu().numpy() if torch.is_tensor(data[u_channel]) else data[u_channel] + v = data[v_channel].cpu().numpy() if torch.is_tensor(data[v_channel]) else data[v_channel] + + if stat_type == 'mean_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed + elif stat_type == 'max_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.full_like(speed, -np.inf) + accumulator = np.maximum(accumulator, speed) + elif stat_type == 'wind_power': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed**3 + elif stat_type == 'mean_u': + if accumulator is None: + accumulator = np.zeros_like(u) + accumulator += u + elif stat_type == 'mean_v': + if accumulator is None: + accumulator = np.zeros_like(v) + accumulator += v + elif stat_type in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', + 'strong_breeze_freq', 'gale_freq']: + speed = compute_wind_speed(u, v) + thresholds = { + 'calm_freq': 2.0, + 'light_breeze_freq': 1.6, + 'moderate_breeze_freq': 5.5, + 'strong_breeze_freq': 10.8, + 'gale_freq': 17.2 + } + threshold = thresholds[stat_type] + if accumulator is None: + accumulator = np.zeros_like(speed) + if stat_type == 'calm_freq': + accumulator += (speed < threshold).astype(float) + else: + accumulator += (speed > threshold).astype(float) + elif stat_type == 'prevailing_direction': + speed = compute_wind_speed(u, v) + direction = compute_wind_direction(u, v, calm_threshold=1.0) + rad = np.deg2rad(direction) + weighted_sin = np.sin(rad) * speed + weighted_cos = np.cos(rad) * speed + + if sin_acc is None: + sin_acc = np.zeros_like(weighted_sin) + cos_acc = np.zeros_like(weighted_cos) + speed_acc = np.zeros_like(speed) + + sin_acc += np.nan_to_num(weighted_sin, 0) + cos_acc += np.nan_to_num(weighted_cos, 0) + speed_acc += speed + elif stat_type == 'direction_variability': + direction = compute_wind_direction(u, v) + rad = np.deg2rad(direction) + + if sin_acc is None: + sin_acc = np.zeros_like(np.sin(rad)) + cos_acc = np.zeros_like(np.cos(rad)) + + sin_acc += np.sin(rad) + cos_acc += np.cos(rad) + + count += 1 + del data, u, v if stat_type == 'prevailing_direction': - # Compute wind directions - direction = compute_wind_direction(u_data.values, v_data.values) - # Weight by wind speed for more meaningful prevailing direction - return circular_mean_direction(direction, weights=speed) - - if stat_type == 'direction_variability': - # Circular standard deviation of wind direction - direction = compute_wind_direction(u_data.values, v_data.values) - return circular_std(direction) - - if stat_type == 'mean_u': - return np.mean(u_data.values, axis=0) - - if stat_type == 'mean_v': - return np.mean(v_data.values, axis=0) - - raise ValueError(f"Unsupported wind statistic type: {stat_type}") + mean_dir = np.arctan2(sin_acc / (speed_acc + 1e-10), cos_acc / (speed_acc + 1e-10)) + return np.mod(np.rad2deg(mean_dir), 360) + elif stat_type == 'direction_variability': + R = np.clip(np.hypot(sin_acc / count, cos_acc / count), 1e-10, 1.0) + return np.rad2deg(np.sqrt(-2 * np.log(R))) + elif stat_type == 'max_speed': + return accumulator + elif stat_type in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', + 'strong_breeze_freq', 'gale_freq']: + return (accumulator / count) * 100 + else: + return accumulator / count def plot_wind_stat_map(data, filename, stat_config, label): - """Plot a single wind statistic map with appropriate styling.""" - + """Plot wind statistic map.""" if stat_config['type'] == 'mean_speed': plot_map( data, filename, title=f'{label}: {stat_config["title_stat"]}', label='Wind Speed [m/s]', vmin=0, vmax=10, cmap='inferno', extend='max' ) - elif stat_config['type'] in ['quantile_speed', 'max_speed']: + elif stat_config['type'] == 'max_speed': plot_map( data, filename, title=f'{label}: {stat_config["title_stat"]}', @@ -192,7 +163,7 @@ def plot_wind_stat_map(data, filename, stat_config, label): plot_map( data, filename, title=f'{label}: {stat_config["title_stat"]}', - label='Wind Component [m/s]', vmin=-10, vmax=10, cmap='RdBu_r', extend='both' + label='Wind Component [m/s]', vmin=-5, vmax=5, cmap='RdBu_r', extend='both' ) else: plot_map( @@ -204,7 +175,6 @@ def plot_wind_stat_map(data, filename, stat_config, label): @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig): - # Setup and config DistributedManager.initialize() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -220,24 +190,17 @@ def main(cfg: DictConfig): out_root = Path(cfg.generation.io.output_path or './outputs') indices = get_channel_indices(dataset) - # Get U and V wind component indices u10_out = indices['output'].get('10u') v10_out = indices['output'].get('10v') u10_in = indices['input'].get('10u', u10_out) v10_in = indices['input'].get('10v', v10_out) - - # Wind statistic configurations + WIND_STATISTICS_CONFIG = { 'mean_speed': { 'type': 'mean_speed', 'title': 'Mean Wind Speed' }, - 'p90_speed': { - 'type': 'quantile_speed', - 'param': 0.90, - 'title': '90th Percentile Wind Speed' - }, 'max_speed': { 'type': 'max_speed', 'title': 'Maximum Wind Speed' @@ -294,111 +257,176 @@ def main(cfg: DictConfig): for name, config in WIND_STATISTICS_CONFIG.items() ] - # Target and baseline modes basic_modes = { 'target': ((u10_out, v10_out), 'COSMO-2 Analysis'), 'baseline': ((u10_in, v10_in), 'ERA5'), 'regression-prediction': ((u10_out, v10_out), 'Regression Prediction') } - logger.info(f"Generating {len(stat_configs)} wind statistics for {len(basic_modes)} basic modes + predictions") + logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} modes + predictions") for mode, (wind_channels, label) in basic_modes.items(): logger.info(f"Processing mode: {mode}") u_channel, v_channel = wind_channels - # Load all timesteps for this mode - u_data_list = [] - v_data_list = [] try: - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}") - data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False) - u_data_list.append(data[u_channel]) - v_data_list.append(data[v_channel]) + test_data = torch.load(out_root/times[0]/f"{times[0]}-{mode}", weights_only=False) + del test_data except Exception as e: - logger.warning(f"{mode} not available, skipping: {e}") + logger.warning(f"{mode} not available: {e}") continue - # Create xarray DataArrays - u_mode_data = xr.DataArray( - np.stack(u_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - v_mode_data = xr.DataArray( - np.stack(v_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - - # Compute and plot all statistics for this mode for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for {mode}...") - result = apply_wind_statistic( - u_mode_data, v_mode_data, - stat_config['type'], stat_config['param'] - ) - - map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" - map_output_dir.mkdir(parents=True, exist_ok=True) - plot_wind_stat_map( - result, - str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), - stat_config, - label - ) + try: + result = apply_wind_statistic_streaming( + times, out_root, mode, u_channel, v_channel, + stat_config['type'], stat_config.get('param') + ) + + map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + plot_wind_stat_map( + result, + str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), + stat_config, + label + ) + del result + except Exception as e: + logger.error(f"Failed {stat_config['title_stat']} for {mode}: {e}") + continue - # Predictions mode: process each member separately to save memory logger.info("Processing predictions mode...") try: data = torch.load(out_root/times[0]/f"{times[0]}-predictions", weights_only=False) n_members = data.shape[0] + del data logger.info(f"Found {n_members} ensemble members") for member_idx in range(n_members): - logger.info(f"Processing prediction member {member_idx+1}/{n_members}") - - # Load all timesteps for this member - u_data_list = [] - v_data_list = [] - for i, ts in enumerate(times): - if i % LOG_INTERVAL == 0: - logger.info(f"Loading prediction member {member_idx} timestep {i+1}/{len(times)}: {ts}") - pred_data = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) - u_data_list.append(pred_data[member_idx, u10_out]) - v_data_list.append(pred_data[member_idx, v10_out]) - - u_member_data = xr.DataArray( - np.stack(u_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) - v_member_data = xr.DataArray( - np.stack(v_data_list, axis=0), - dims=['time', 'lat', 'lon'], - coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]} - ) + logger.info(f"Processing member {member_idx+1}/{n_members}") - # Compute and plot all statistics for this member for stat_config in stat_configs: logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...") - member_result = apply_wind_statistic( - u_member_data, v_member_data, - stat_config['type'], stat_config['param'] - ) + try: + def load_member_data(ts): + pred_data = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) + u_data = pred_data[member_idx, u10_out] + v_data = pred_data[member_idx, v10_out] + u = u_data.cpu().numpy() if torch.is_tensor(u_data) else u_data + v = v_data.cpu().numpy() if torch.is_tensor(v_data) else v_data + del pred_data + return u, v + + accumulator = None + count = 0 + sin_acc = cos_acc = speed_acc = None + + for i, ts in enumerate(times): + if i % LOG_INTERVAL == 0: + logger.info(f"Loading prediction member {member_idx} timestep {i+1}/{len(times)}: {ts}") + + u, v = load_member_data(ts) + + if stat_config['type'] == 'mean_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed + elif stat_config['type'] == 'max_speed': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.full_like(speed, -np.inf) + accumulator = np.maximum(accumulator, speed) + elif stat_config['type'] == 'wind_power': + speed = compute_wind_speed(u, v) + if accumulator is None: + accumulator = np.zeros_like(speed) + accumulator += speed**3 + elif stat_config['type'] == 'mean_u': + if accumulator is None: + accumulator = np.zeros_like(u) + accumulator += u + elif stat_config['type'] == 'mean_v': + if accumulator is None: + accumulator = np.zeros_like(v) + accumulator += v + elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', + 'moderate_breeze_freq', 'strong_breeze_freq', + 'gale_freq']: + speed = compute_wind_speed(u, v) + thresholds = { + 'calm_freq': 2.0, + 'light_breeze_freq': 1.6, + 'moderate_breeze_freq': 5.5, + 'strong_breeze_freq': 10.8, + 'gale_freq': 17.2 + } + threshold = thresholds[stat_config['type']] + if accumulator is None: + accumulator = np.zeros_like(speed) + if stat_config['type'] == 'calm_freq': + accumulator += (speed < threshold).astype(float) + else: + accumulator += (speed > threshold).astype(float) + elif stat_config['type'] == 'prevailing_direction': + speed = compute_wind_speed(u, v) + direction = compute_wind_direction(u, v, calm_threshold=1.0) + rad = np.deg2rad(direction) + weighted_sin = np.sin(rad) * speed + weighted_cos = np.cos(rad) * speed + + if sin_acc is None: + sin_acc = np.zeros_like(weighted_sin) + cos_acc = np.zeros_like(weighted_cos) + speed_acc = np.zeros_like(speed) + + sin_acc += np.nan_to_num(weighted_sin, 0) + cos_acc += np.nan_to_num(weighted_cos, 0) + speed_acc += speed + elif stat_config['type'] == 'direction_variability': + direction = compute_wind_direction(u, v) + rad = np.deg2rad(direction) + + if sin_acc is None: + sin_acc = np.zeros_like(np.sin(rad)) + cos_acc = np.zeros_like(np.cos(rad)) + + sin_acc += np.sin(rad) + cos_acc += np.cos(rad) + + count += 1 + del u, v + + if stat_config['type'] == 'prevailing_direction': + mean_dir = np.arctan2(sin_acc / (speed_acc + 1e-10), cos_acc / (speed_acc + 1e-10)) + member_result = np.mod(np.rad2deg(mean_dir), 360) + elif stat_config['type'] == 'direction_variability': + R = np.clip(np.hypot(sin_acc / count, cos_acc / count), 1e-10, 1.0) + member_result = np.rad2deg(np.sqrt(-2 * np.log(R))) + elif stat_config['type'] == 'max_speed': + member_result = accumulator + elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', + 'moderate_breeze_freq', 'strong_breeze_freq', + 'gale_freq']: + member_result = (accumulator / count) * 100 + else: + member_result = accumulator / count + + map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" + map_output_dir.mkdir(parents=True, exist_ok=True) + member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') + plot_wind_stat_map(member_result, member_filename, stat_config, f'CorrDiff Member {member_idx+1}') + del member_result - # Create map - map_output_dir = out_root / f"maps_wind_{stat_config['stat_name']}" - map_output_dir.mkdir(parents=True, exist_ok=True) - member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}') - member_label = f'CorrDiff Member {member_idx+1}' - plot_wind_stat_map(member_result, member_filename, stat_config, member_label) + except Exception as e: + logger.error(f"Failed {stat_config['title_stat']} for member {member_idx+1}: {e}") + continue except Exception as e: - logger.warning(f"Predictions not available, skipping: {e}") + logger.warning(f"Predictions not available: {e}") - logger.info("All wind statistics maps generated successfully") + logger.info("Wind statistics generation complete") if __name__ == '__main__': diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh index 37194a6..df113e8 100644 --- a/src/hirad/eval_wind.sh +++ b/src/hirad/eval_wind.sh @@ -8,7 +8,7 @@ #SBATCH --ntasks-per-node=2 #SBATCH --gpus-per-node=2 #SBATCH --cpus-per-task=72 -#SBATCH --time=6:00:00 +#SBATCH --time=12:00:00 #SBATCH --no-requeue #SBATCH --exclusive From 68c3ae810278344ca9ea6bc04815b03ba7ca4398 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 28 Oct 2025 14:52:13 +0100 Subject: [PATCH 177/189] some updates to regridding (still not working) --- src/hirad/input_data/interpolate_basic.py | 12 ++-- src/hirad/input_data/regrid_realch1.py | 73 +++++++++++++++-------- 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py index 604d5e4..b68219b 100644 --- a/src/hirad/input_data/interpolate_basic.py +++ b/src/hirad/input_data/interpolate_basic.py @@ -158,19 +158,19 @@ def _get_plot_indices(era: Dataset, cosmo: Dataset) -> np.ndarray[np.intp]: indices = np.where(box_lon*box_lat) return indices -def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None): - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) +def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None, s = None): + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax, s=s) ax.coastlines() - ax.gridlines(draw_labels=False) + ax.gridlines(draw_labels=True) plt.colorbar(p, orientation="horizontal") -def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): +def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, projection=ccrs.PlateCarree(), cmap=None, vmin = None, vmax = None, s = None): """Plot observed or interpolated data in a scatter plot.""" # TODO: Refactor this somehow, it's not really generalizing well across variables. fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + fig, ax = plt.subplots(subplot_kw={"projection": projection}) logging.info(f'plotting values to {filename}') - plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax) + plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax, s) #p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) #ax.coastlines() #ax.gridlines(draw_labels=True) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index e81d12b..6f4c680 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -35,7 +35,6 @@ def anemoi_to_xarray(anemoi_data: Dataset, variable): attrs=dict(description=f'xarray from anemoi dataset for {variable}', metadata=metadata), ) - print(ds) return ds def getMetadataFromOGD(): @@ -58,34 +57,60 @@ def generate_times(anemoi_data: Dataset): curr_time = curr_time + anemoi_data.frequency return times -def get_geo_coords(regridded_data: xr.Dataset): - xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees") - xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees") - dx = regridded_data.metadata.get("iDirectionIncrementInDegrees") - ymin = regridded_data.metadata.get("latitudeOfFirstGridPointInDegrees") - ymax = regridded_data.metadata.get("latitudeOfLastGridPointInDegrees") - dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") - y = np.arange(ymin,ymax+dy,dy) - x = np.arange(xmin,xmax+dx,dx) +def get_geo_coords(regridded_data: xr.Dataset, sp_lat=None, sp_lon=None): + xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees") + xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees") + dx = regridded_data.metadata.get("iDirectionIncrementInDegrees") + ymin = regridded_data.metadata.get("latitudeOfFirstGridPointInDegrees") + ymax = regridded_data.metadata.get("latitudeOfLastGridPointInDegrees") + dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") + y = np.arange(ymin,ymax+dy,dy) + x = np.arange(xmin,xmax+dx,dx) # TODO, this parameter is not producing what I want it to. - sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") - sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") - xcoords = np.meshgrid(x,y)[0].flatten() - ycoords = np.meshgrid(x,y)[1].flatten() - geo_coords = unrotate(ycoords, xcoords, sp_lat, sp_lon) - return geo_coords + if not sp_lat or not sp_lon: + sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") # -43.0. north_pole_lat = 43.0 + sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") # 10.0. north_pole_lon = 190.0 + xcoords = np.meshgrid(x,y)[0].flatten() + ycoords = np.meshgrid(x,y)[1].flatten() + logging.info(f'sp_lat = {sp_lat}, sp_lon = {sp_lon}') + #geo_coords = unrotate(ycoords, xcoords, sp_lat*-1, sp_lon-180) + #geo_coords = unrotate(ycoords, xcoords, 43, 190) + #geo_coords = unrotate(ycoords, xcoords, -43, 190) # close + #geo_coords = unrotate(ycoords, xcoords, -43, 10) # produces range: 26W-8W, 41S-51.5S. image correct orientation. + geo_coords = unrotate(ycoords, xcoords, sp_lat, sp_lon) + + return geo_coords, ycoords, xcoords +logging.basicConfig(level=logging.INFO) realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') -myxarray = anemoi_to_xarray(realch1, "TOT_PREC").to_dataarray() +var = 'TD_2M' +var_index = realch1.variables.index(var) +myxarray = anemoi_to_xarray(realch1, var).to_dataarray() regridded=regrid.icon2rotlatlon(myxarray) + plot_and_save_projection(realch1.longitudes, realch1.latitudes, - realch1[0,56,0,:], "anemoi.png") + realch1[0,var_index,0,:], f'{var}-icon.png', s=0.01) plot_and_save_projection(myxarray.lon, myxarray.lat, - myxarray[0,0,0,:], "xarray.png") + myxarray[0,0,0,:], f'{var}-xarray.png', s=0.005) + +geo_coords, ycoords, xcoords = get_geo_coords(regridded) +logging.info(geo_coords) # South pole rotation of lon=10, latitude=-43 -#rotated_crs = ccrs.RotatedPole( -# pole_longitude=190, pole_latitude=43 -#) -geo_coords = get_geo_coords(regridded) +rotated_crs = ccrs.RotatedPole( + pole_longitude=190, pole_latitude=43 +) +plot_and_save_projection(xcoords, ycoords, + regridded[0,0,0,:], f'{var}-regridded-rotated-projection.png', + projection=rotated_crs) # picture looks accurate. plot_and_save_projection(geo_coords[1], geo_coords[0], - regridded[0,0,0,:], "regridded.png") \ No newline at end of file + regridded[0,0,0,:], f'{var}-regridded.png') + +# None of these work. +for sp_lat in [-43, 43]: + for sp_lon in [10, 170, 190, 350]: + geo_coords, ycoords, xcoords = get_geo_coords(regridded, sp_lat=sp_lat, sp_lon=sp_lon) + logging.info(geo_coords) + plot_and_save_projection(geo_coords[1], geo_coords[0], + regridded[0,0,0,:], f'{var}-regridded_{sp_lat}_{sp_lon}.png') + plot_and_save_projection(geo_coords[0], geo_coords[1], + regridded[0,0,0,:], f'{var}-regridded-reversed_{sp_lat}_{sp_lon}.png') \ No newline at end of file From 603f3d82af5d49a39c0d277934024b869eb9e2f3 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 28 Oct 2025 18:05:56 +0100 Subject: [PATCH 178/189] regridding is actually working now. --- src/hirad/input_data/regrid_realch1.py | 142 ++++++++++++++----------- 1 file changed, 77 insertions(+), 65 deletions(-) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index 6f4c680..7975c9d 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -1,6 +1,7 @@ import logging +import sys from anemoi.datasets import open_dataset from anemoi.datasets.data.dataset import Dataset @@ -9,34 +10,43 @@ import xarray as xr from meteodatalab import ogd_api from hirad.input_data.interpolate_basic import plot_and_save_projection +import yaml import matplotlib.pyplot as plt import cartopy.crs as ccrs from earthkit.geo.rotate import unrotate -def anemoi_to_xarray(anemoi_data: Dataset, variable): - lon = anemoi_data.longitudes - lat = anemoi_data.latitudes - eps = [0] # deterministic - time = generate_times(anemoi_data) - var_index = anemoi_data.variables.index(variable) - metadata = getMetadataFromOGD() - - ds = xr.Dataset( - data_vars=dict( - variable=(["time", "eps", "cell"], np.array(anemoi_data.data[:,var_index,:,:])), - ), - coords=dict( - eps=eps, - time=time, - lon=("cell", lon), - lat=("cell", lat), - ), - attrs=dict(description=f'xarray from anemoi dataset for {variable}', - metadata=metadata), - ) - return ds +# Take anemoi dataset and provide xarray dataarrays for a set of variables. +# returns: list of xarray dataarrays, and list of variable indices (anemoi) +def anemoi_to_xarray(anemoi_data: Dataset, variables): + lon = anemoi_data.longitudes + lat = anemoi_data.latitudes + eps = [0] # deterministic + time = generate_times(anemoi_data) + metadata = getMetadataFromOGD() + dataarrays = [] + var_indices = [] + for variable in variables: + var_index = anemoi_data.variables.index(variable) + var_indices.append(var_index) + + ds = xr.Dataset( + data_vars=dict( + variable=(["time", "eps", "cell"], np.array(anemoi_data.data[:,var_index,:,:])), + ), + coords=dict( + eps=eps, + time=time, + lon=("cell", lon), + lat=("cell", lat), + ), + attrs=dict(description=f'xarray from anemoi dataset for {variable}', + metadata=metadata), + ) + dataarrays.append(ds.to_dataarray()) + return dataarrays, var_indices +# Run a request to get the metadata, so that we can fake out an xarray. def getMetadataFromOGD(): lead_times = ["P0DT0H"] req = ogd_api.Request( @@ -49,6 +59,7 @@ def getMetadataFromOGD(): tot_prec = ogd_api.get_from_ogd(req) return tot_prec.metadata +# Get array of times from the anemoi dataset def generate_times(anemoi_data: Dataset): times = [] curr_time = anemoi_data.start_date.item() @@ -57,7 +68,9 @@ def generate_times(anemoi_data: Dataset): curr_time = curr_time + anemoi_data.frequency return times -def get_geo_coords(regridded_data: xr.Dataset, sp_lat=None, sp_lon=None): +# get the geo coordinates for the rotated lat/lon dataset. +# returns np.array of lats and array of lons +def get_geo_coords(regridded_data: xr.Dataset): xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees") xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees") dx = regridded_data.metadata.get("iDirectionIncrementInDegrees") @@ -66,51 +79,50 @@ def get_geo_coords(regridded_data: xr.Dataset, sp_lat=None, sp_lon=None): dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") y = np.arange(ymin,ymax+dy,dy) x = np.arange(xmin,xmax+dx,dx) - # TODO, this parameter is not producing what I want it to. - if not sp_lat or not sp_lon: - sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") # -43.0. north_pole_lat = 43.0 - sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") # 10.0. north_pole_lon = 190.0 + sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") # -43.0. north_pole_lat = 43.0 + sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") # 10.0. north_pole_lon = 190.0 xcoords = np.meshgrid(x,y)[0].flatten() ycoords = np.meshgrid(x,y)[1].flatten() + # Expect south pole rotation of lon=10, latitude=-43 logging.info(f'sp_lat = {sp_lat}, sp_lon = {sp_lon}') - #geo_coords = unrotate(ycoords, xcoords, sp_lat*-1, sp_lon-180) - #geo_coords = unrotate(ycoords, xcoords, 43, 190) - #geo_coords = unrotate(ycoords, xcoords, -43, 190) # close - #geo_coords = unrotate(ycoords, xcoords, -43, 10) # produces range: 26W-8W, 41S-51.5S. image correct orientation. - geo_coords = unrotate(ycoords, xcoords, sp_lat, sp_lon) - - return geo_coords, ycoords, xcoords + rotated_crs = ccrs.RotatedPole( + pole_longitude=(sp_lon + 180) % 360, pole_latitude=sp_lat * -1 # 190, 43 + ) + # Project onto PlateCarree. Geodetic produces similar coordinates (within 10 nanometers) + dst_grid = ccrs.PlateCarree() + geo_coords = dst_grid.transform_points(rotated_crs, xcoords, ycoords) + lats = geo_coords[:,1] + lons = geo_coords[:,0] + return lats, lons + +def main(): + # yml format + realch1_config_file = sys.argv[1] + output_directory = sys.argv[2] + + with open(realch1_config_file) as realch1_file: + realch1_config = yaml.safe_load(realch1_file) + realch1 = open_dataset(realch1_config) logging.basicConfig(level=logging.INFO) + realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') -var = 'TD_2M' -var_index = realch1.variables.index(var) -myxarray = anemoi_to_xarray(realch1, var).to_dataarray() -regridded=regrid.icon2rotlatlon(myxarray) - -plot_and_save_projection(realch1.longitudes, realch1.latitudes, - realch1[0,var_index,0,:], f'{var}-icon.png', s=0.01) -plot_and_save_projection(myxarray.lon, myxarray.lat, - myxarray[0,0,0,:], f'{var}-xarray.png', s=0.005) - -geo_coords, ycoords, xcoords = get_geo_coords(regridded) -logging.info(geo_coords) -# South pole rotation of lon=10, latitude=-43 -rotated_crs = ccrs.RotatedPole( - pole_longitude=190, pole_latitude=43 -) -plot_and_save_projection(xcoords, ycoords, - regridded[0,0,0,:], f'{var}-regridded-rotated-projection.png', - projection=rotated_crs) # picture looks accurate. -plot_and_save_projection(geo_coords[1], geo_coords[0], - regridded[0,0,0,:], f'{var}-regridded.png') - -# None of these work. -for sp_lat in [-43, 43]: - for sp_lon in [10, 170, 190, 350]: - geo_coords, ycoords, xcoords = get_geo_coords(regridded, sp_lat=sp_lat, sp_lon=sp_lon) - logging.info(geo_coords) - plot_and_save_projection(geo_coords[1], geo_coords[0], - regridded[0,0,0,:], f'{var}-regridded_{sp_lat}_{sp_lon}.png') - plot_and_save_projection(geo_coords[0], geo_coords[1], - regridded[0,0,0,:], f'{var}-regridded-reversed_{sp_lat}_{sp_lon}.png') \ No newline at end of file +variables = ['TD_2M', 'TOT_PREC'] +myxarrays, var_indices = anemoi_to_xarray(realch1, variables) + +for i in range(len(variables)): + myxarray = myxarrays[i] + regridded=regrid.icon2rotlatlon(myxarray) + plot_and_save_projection(realch1.longitudes, realch1.latitudes, + realch1[0,var_indices[i],0,:], f'{variables[i]}-icon.png', s=0.005) + plot_and_save_projection(myxarray.lon, myxarray.lat, + myxarray[0,0,0,:], f'{variables[i]}-xarray.png', s=0.005) + + lats, lons = get_geo_coords(regridded) + + plot_and_save_projection(lons, lats, + regridded[0,0,0,:], f'{variables[i]}-regridded.png', s=0.005) + + +if __name__ == "__main__": + main() \ No newline at end of file From 87c5206d01c9094164ecad28e9d1d87240d016ab Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 28 Oct 2025 19:05:14 +0100 Subject: [PATCH 179/189] skeleton of full regrid script --- src/hirad/input_data/regrid_realch1.py | 71 ++++++++++++++++---------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index 7975c9d..b0fb0fa 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -1,6 +1,7 @@ import logging +import os import sys from anemoi.datasets import open_dataset @@ -17,22 +18,20 @@ from earthkit.geo.rotate import unrotate # Take anemoi dataset and provide xarray dataarrays for a set of variables. -# returns: list of xarray dataarrays, and list of variable indices (anemoi) -def anemoi_to_xarray(anemoi_data: Dataset, variables): +# returns: list of xarray dataarrays +def anemoi_to_xarray(anemoi_data: Dataset): lon = anemoi_data.longitudes lat = anemoi_data.latitudes eps = [0] # deterministic - time = generate_times(anemoi_data) + time = generate_times(anemoi_data) # anemoi_data.dates? metadata = getMetadataFromOGD() dataarrays = [] - var_indices = [] - for variable in variables: - var_index = anemoi_data.variables.index(variable) - var_indices.append(var_index) + variables = anemoi_data.variables + for var_index in range(anemoi_data.shape[1]): ds = xr.Dataset( data_vars=dict( - variable=(["time", "eps", "cell"], np.array(anemoi_data.data[:,var_index,:,:])), + variable=(["time", "eps", "cell"], np.array(anemoi_data[:,var_index,:,:])), ), coords=dict( eps=eps, @@ -40,18 +39,18 @@ def anemoi_to_xarray(anemoi_data: Dataset, variables): lon=("cell", lon), lat=("cell", lat), ), - attrs=dict(description=f'xarray from anemoi dataset for {variable}', + attrs=dict(description=f'xarray from anemoi dataset for {variables[var_index]}', metadata=metadata), ) dataarrays.append(ds.to_dataarray()) - return dataarrays, var_indices + return dataarrays # Run a request to get the metadata, so that we can fake out an xarray. def getMetadataFromOGD(): lead_times = ["P0DT0H"] req = ogd_api.Request( collection="ogd-forecasting-icon-ch1", - variable="TOT_PREC", + variable="TOT_PREC", #assuming this won't cause problems; we're only using grid info ref_time="latest", perturbed=False, lead_time=lead_times, @@ -99,29 +98,47 @@ def main(): # yml format realch1_config_file = sys.argv[1] output_directory = sys.argv[2] + if not os.path.exists(output_directory): + os.mkdir(output_directory) + for subdir in ['info', 'plots', 'realch1']: + if not os.path.exists(os.path.join(output_directory, subdir)): + os.mkdir(os.path.join(output_directory, subdir)) with open(realch1_config_file) as realch1_file: realch1_config = yaml.safe_load(realch1_file) realch1 = open_dataset(realch1_config) + variables = realch1.variables -logging.basicConfig(level=logging.INFO) - -realch1 = open_dataset('/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr') -variables = ['TD_2M', 'TOT_PREC'] -myxarrays, var_indices = anemoi_to_xarray(realch1, variables) - -for i in range(len(variables)): - myxarray = myxarrays[i] - regridded=regrid.icon2rotlatlon(myxarray) - plot_and_save_projection(realch1.longitudes, realch1.latitudes, - realch1[0,var_indices[i],0,:], f'{variables[i]}-icon.png', s=0.005) - plot_and_save_projection(myxarray.lon, myxarray.lat, - myxarray[0,0,0,:], f'{variables[i]}-xarray.png', s=0.005) + logging.basicConfig(level=logging.INFO) + xarrays = anemoi_to_xarray(realch1) + + # Get the lat/lon info by regridding first variable + regridded=regrid.icon2rotlatlon(xarrays[0]) lats, lons = get_geo_coords(regridded) - - plot_and_save_projection(lons, lats, - regridded[0,0,0,:], f'{variables[i]}-regridded.png', s=0.005) + logging.info(regridded) + logging.info(regridded.data) + logging.info(regridded.data.shape) + # TODO: Save lat/lon info + + # regridded is in shape (eps, time, variable, x, y) + # want this in shape (time,channel,ensemble,grid) + torch_data = np.zeros([len(realch1.dates), len(realch1.variables), 1, len(lats)]) + torch_data[:,0,:,:] = regridded.data.reshape(regridded.shape[1], regridded.shape[0], regridded.shape[3]*regridded.shape[4]) + + for i in range(1, len(xarrays)): + xarray = xarrays[i] + regridded=regrid.icon2rotlatlon(xarray) + torch_data[:,i,:,:] = regridded.data.reshape(regridded.shape[1], regridded.shape[0], regridded.shape[3]*regridded.shape[4]) + + # TODO: output each time point into torch file + + # Output plots + for i in range(torch_data.shape[1]): + plot_and_save_projection(realch1.longitudes, realch1.latitudes, + realch1[0,i,0,:], f'{variables[i]}-icon.png', s=0.005) + plot_and_save_projection(lons, lats, + torch_data[0,i,0,:], f'{variables[i]}-regridded.png', s=0.005) if __name__ == "__main__": From dd79f5e673d4323ebc02616bb508864078c4dfda Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 28 Oct 2025 19:25:26 +0100 Subject: [PATCH 180/189] Regridding REA-L-CH1 script complete --- src/hirad/input_data/realch1-all.yaml | 26 +++++++++++++++++++++ src/hirad/input_data/realch1.yaml | 24 +------------------- src/hirad/input_data/regrid_realch1.py | 31 +++++++++++++++++++------- 3 files changed, 50 insertions(+), 31 deletions(-) create mode 100644 src/hirad/input_data/realch1-all.yaml diff --git a/src/hirad/input_data/realch1-all.yaml b/src/hirad/input_data/realch1-all.yaml new file mode 100644 index 0000000..6b0f776 --- /dev/null +++ b/src/hirad/input_data/realch1-all.yaml @@ -0,0 +1,26 @@ +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr' +select: ['CLCH', 'CLCL', 'CLCM', 'CLCT', + 'FI_100', 'FI_1000', 'FI_150', 'FI_200', 'FI_250', 'FI_300', 'FI_400', + 'FI_50', 'FI_500', 'FI_600', 'FI_700', 'FI_850', 'FI_925', + 'FR_LAND', 'HSURF', + 'OMEGA_100', 'OMEGA_1000', 'OMEGA_150', 'OMEGA_200', 'OMEGA_250', + 'OMEGA_300', 'OMEGA_400', 'OMEGA_50', 'OMEGA_500', 'OMEGA_600', + 'OMEGA_700', 'OMEGA_850', 'OMEGA_925', + 'PLCOV', 'PMSL', 'PS', + 'QV_100', 'QV_1000', 'QV_150', 'QV_200', 'QV_250', 'QV_300', 'QV_400', + 'QV_50', 'QV_500', 'QV_600', 'QV_700', 'QV_850', 'QV_925', + 'SKC', 'SKT', 'SOILTYP', 'SSO_GAMMA', 'SSO_SIGMA', 'SSO_STDH', + 'SSO_THETA', 'TD_2M', 'TOT_PREC', 'TOT_PREC_6H', + 'T_100', 'T_1000', 'T_150', 'T_200', 'T_250', 'T_2M', 'T_300', 'T_400', + 'T_50', 'T_500', 'T_600', 'T_700', 'T_850', 'T_925', + 'U_100', 'U_1000', 'U_10M', 'U_150', 'U_200', 'U_250', 'U_300', 'U_400', + 'U_50', 'U_500', 'U_600', 'U_700', 'U_850', 'U_925', + 'V_100', 'V_1000', 'V_10M', 'V_150', 'V_200', 'V_250', 'V_300', 'V_400', + 'V_50', 'V_500', 'V_600', 'V_700', 'V_850', 'V_925', + 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', + 'insolation', 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude'] +# ALL REALCH1 CHANNELS. +start: 2020-01-01 +# start: 2015-11-29 +end: 2020-01-01 +# end: 2020-12-31 diff --git a/src/hirad/input_data/realch1.yaml b/src/hirad/input_data/realch1.yaml index 6b0f776..d6ce166 100644 --- a/src/hirad/input_data/realch1.yaml +++ b/src/hirad/input_data/realch1.yaml @@ -1,26 +1,4 @@ dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr' -select: ['CLCH', 'CLCL', 'CLCM', 'CLCT', - 'FI_100', 'FI_1000', 'FI_150', 'FI_200', 'FI_250', 'FI_300', 'FI_400', - 'FI_50', 'FI_500', 'FI_600', 'FI_700', 'FI_850', 'FI_925', - 'FR_LAND', 'HSURF', - 'OMEGA_100', 'OMEGA_1000', 'OMEGA_150', 'OMEGA_200', 'OMEGA_250', - 'OMEGA_300', 'OMEGA_400', 'OMEGA_50', 'OMEGA_500', 'OMEGA_600', - 'OMEGA_700', 'OMEGA_850', 'OMEGA_925', - 'PLCOV', 'PMSL', 'PS', - 'QV_100', 'QV_1000', 'QV_150', 'QV_200', 'QV_250', 'QV_300', 'QV_400', - 'QV_50', 'QV_500', 'QV_600', 'QV_700', 'QV_850', 'QV_925', - 'SKC', 'SKT', 'SOILTYP', 'SSO_GAMMA', 'SSO_SIGMA', 'SSO_STDH', - 'SSO_THETA', 'TD_2M', 'TOT_PREC', 'TOT_PREC_6H', - 'T_100', 'T_1000', 'T_150', 'T_200', 'T_250', 'T_2M', 'T_300', 'T_400', - 'T_50', 'T_500', 'T_600', 'T_700', 'T_850', 'T_925', - 'U_100', 'U_1000', 'U_10M', 'U_150', 'U_200', 'U_250', 'U_300', 'U_400', - 'U_50', 'U_500', 'U_600', 'U_700', 'U_850', 'U_925', - 'V_100', 'V_1000', 'V_10M', 'V_150', 'V_200', 'V_250', 'V_300', 'V_400', - 'V_50', 'V_500', 'V_600', 'V_700', 'V_850', 'V_925', - 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', - 'insolation', 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude'] -# ALL REALCH1 CHANNELS. +select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC'] start: 2020-01-01 -# start: 2015-11-29 end: 2020-01-01 -# end: 2020-12-31 diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index b0fb0fa..7d0b65d 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -2,6 +2,7 @@ import logging import os +import shutil import sys from anemoi.datasets import open_dataset @@ -12,6 +13,8 @@ from meteodatalab import ogd_api from hirad.input_data.interpolate_basic import plot_and_save_projection import yaml +import torch +from pandas import to_datetime import matplotlib.pyplot as plt import cartopy.crs as ccrs @@ -104,6 +107,9 @@ def main(): if not os.path.exists(os.path.join(output_directory, subdir)): os.mkdir(os.path.join(output_directory, subdir)) + # Copy the realch1.yml file to the info directory + shutil.copy(realch1_config_file, os.path.join(output_directory, 'info')) + with open(realch1_config_file) as realch1_file: realch1_config = yaml.safe_load(realch1_file) realch1 = open_dataset(realch1_config) @@ -116,13 +122,14 @@ def main(): # Get the lat/lon info by regridding first variable regridded=regrid.icon2rotlatlon(xarrays[0]) lats, lons = get_geo_coords(regridded) - logging.info(regridded) - logging.info(regridded.data) - logging.info(regridded.data.shape) - # TODO: Save lat/lon info + + # Save lat/lon info + grid = np.column_stack((lats, lons)) + torch.save(grid, os.path.join(output_directory, 'info', 'realch1-lat-lon')) # regridded is in shape (eps, time, variable, x, y) # want this in shape (time,channel,ensemble,grid) + # nervous about the reshaping screwing things up, but that's why we plot the interpolated data to visually check. torch_data = np.zeros([len(realch1.dates), len(realch1.variables), 1, len(lats)]) torch_data[:,0,:,:] = regridded.data.reshape(regridded.shape[1], regridded.shape[0], regridded.shape[3]*regridded.shape[4]) @@ -131,14 +138,22 @@ def main(): regridded=regrid.icon2rotlatlon(xarray) torch_data[:,i,:,:] = regridded.data.reshape(regridded.shape[1], regridded.shape[0], regridded.shape[3]*regridded.shape[4]) - # TODO: output each time point into torch file + # Output each time point into torch file + for t in range(torch_data.shape[0]): + fmtdate = to_datetime(realch1.dates[t]).strftime('%Y%m%d-%H%M') + torch.save(torch_data[t,:], os.path.join(output_directory, 'realch1', fmtdate)) + - # Output plots + # Output plots for each variable, for first time point for i in range(torch_data.shape[1]): plot_and_save_projection(realch1.longitudes, realch1.latitudes, - realch1[0,i,0,:], f'{variables[i]}-icon.png', s=0.005) + realch1[0,i,0,:], + os.path.join(output_directory, 'plots', f'{variables[i]}-iconnative.png'), + s=0.005) plot_and_save_projection(lons, lats, - torch_data[0,i,0,:], f'{variables[i]}-regridded.png', s=0.005) + torch_data[0,i,0,:], + os.path.join(output_directory, 'plots', f'{variables[i]}-rotlatlon.png'), + s=0.005) if __name__ == "__main__": From a5acccd8398eb4e2da94d193d078b85b693b9b5b Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 29 Oct 2025 17:19:52 +0100 Subject: [PATCH 181/189] Add trim_edge functionality to regridding realch1 --- src/hirad/input_data/regrid_realch1.py | 59 +++++++++++++++++--------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index 7d0b65d..c4f8bc1 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -1,5 +1,5 @@ - +import datetime import logging import os import shutil @@ -20,18 +20,19 @@ import cartopy.crs as ccrs from earthkit.geo.rotate import unrotate +TRIM_EDGE = 41 + # Take anemoi dataset and provide xarray dataarrays for a set of variables. # returns: list of xarray dataarrays def anemoi_to_xarray(anemoi_data: Dataset): lon = anemoi_data.longitudes lat = anemoi_data.latitudes eps = [0] # deterministic - time = generate_times(anemoi_data) # anemoi_data.dates? + time = anemoi_data.dates metadata = getMetadataFromOGD() dataarrays = [] variables = anemoi_data.variables for var_index in range(anemoi_data.shape[1]): - ds = xr.Dataset( data_vars=dict( variable=(["time", "eps", "cell"], np.array(anemoi_data[:,var_index,:,:])), @@ -61,18 +62,9 @@ def getMetadataFromOGD(): tot_prec = ogd_api.get_from_ogd(req) return tot_prec.metadata -# Get array of times from the anemoi dataset -def generate_times(anemoi_data: Dataset): - times = [] - curr_time = anemoi_data.start_date.item() - while curr_time <= anemoi_data.end_date: - times.append(curr_time) - curr_time = curr_time + anemoi_data.frequency - return times - # get the geo coordinates for the rotated lat/lon dataset. # returns np.array of lats and array of lons -def get_geo_coords(regridded_data: xr.Dataset): +def get_geo_coords(regridded_data: xr.Dataset, trim_edge=0): xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees") xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees") dx = regridded_data.metadata.get("iDirectionIncrementInDegrees") @@ -81,6 +73,11 @@ def get_geo_coords(regridded_data: xr.Dataset): dy = regridded_data.metadata.get("jDirectionIncrementInDegrees") y = np.arange(ymin,ymax+dy,dy) x = np.arange(xmin,xmax+dx,dx) + # trim x and y according to trim_edge. + # (Have manually verified that when doing this, the outputs are the same as + # trimming post-projection) + y = y[trim_edge:len(y)-trim_edge] + x = x[trim_edge:len(x)-trim_edge] sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") # -43.0. north_pole_lat = 43.0 sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") # 10.0. north_pole_lon = 190.0 xcoords = np.meshgrid(x,y)[0].flatten() @@ -97,6 +94,17 @@ def get_geo_coords(regridded_data: xr.Dataset): lons = geo_coords[:,0] return lats, lons +def regridded_to_numpy(regridded: xr.DataArray, trim_edge=0): + # regridded is in shape (eps, time, variable, x, y) + # want this in shape (time, channel, ensemble, grid) + # First, trim the edge + data = regridded.data[:,:,:, + trim_edge:regridded.data.shape[3]-trim_edge, + trim_edge:regridded.data.shape[4]-trim_edge] + # reshape to (time,channel,ensemble,grid) + data = data.reshape(data.shape[1], data.shape[0], data.shape[3]*data.shape[4]) + return data + def main(): # yml format realch1_config_file = sys.argv[1] @@ -107,6 +115,12 @@ def main(): if not os.path.exists(os.path.join(output_directory, subdir)): os.mkdir(os.path.join(output_directory, subdir)) + logging.basicConfig( + filename=os.path.join(output_directory, 'regrid_realch1.log'), + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + # Copy the realch1.yml file to the info directory shutil.copy(realch1_config_file, os.path.join(output_directory, 'info')) @@ -115,13 +129,16 @@ def main(): realch1 = open_dataset(realch1_config) variables = realch1.variables - logging.basicConfig(level=logging.INFO) - xarrays = anemoi_to_xarray(realch1) # Get the lat/lon info by regridding first variable + logging.info(f'regridding {variables[0]} for time {realch1.start_date} to {realch1.end_date}') + start = datetime.datetime.now() regridded=regrid.icon2rotlatlon(xarrays[0]) - lats, lons = get_geo_coords(regridded) + end = datetime.datetime.now() + logging.info(f' regridding took {end-start} seconds') + logging.info('getting geo coords') + lats, lons = get_geo_coords(regridded, trim_edge=TRIM_EDGE) # Save lat/lon info grid = np.column_stack((lats, lons)) @@ -131,19 +148,23 @@ def main(): # want this in shape (time,channel,ensemble,grid) # nervous about the reshaping screwing things up, but that's why we plot the interpolated data to visually check. torch_data = np.zeros([len(realch1.dates), len(realch1.variables), 1, len(lats)]) - torch_data[:,0,:,:] = regridded.data.reshape(regridded.shape[1], regridded.shape[0], regridded.shape[3]*regridded.shape[4]) + torch_data[:,0,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE) for i in range(1, len(xarrays)): + logging.info(f'regridding {variables[i]} for time {realch1.start_date} to {realch1.end_date}') xarray = xarrays[i] + start = datetime.datetime.now() regridded=regrid.icon2rotlatlon(xarray) - torch_data[:,i,:,:] = regridded.data.reshape(regridded.shape[1], regridded.shape[0], regridded.shape[3]*regridded.shape[4]) + end = datetime.datetime.now() + logging.info(f' regridding took {end-start} seconds') + torch_data[:,i,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE) # Output each time point into torch file + logging.info('saving torch data') for t in range(torch_data.shape[0]): fmtdate = to_datetime(realch1.dates[t]).strftime('%Y%m%d-%H%M') torch.save(torch_data[t,:], os.path.join(output_directory, 'realch1', fmtdate)) - # Output plots for each variable, for first time point for i in range(torch_data.shape[1]): plot_and_save_projection(realch1.longitudes, realch1.latitudes, From 93b5a5fdb3195cf7822958463b653d9df3c195f6 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 30 Oct 2025 16:21:24 +0100 Subject: [PATCH 182/189] do rea-l-ch1 regridding in batches to improve performance/memory usage --- src/hirad/input_data/regrid_realch1.py | 66 +++++++++++++------------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py index c4f8bc1..8e84593 100644 --- a/src/hirad/input_data/regrid_realch1.py +++ b/src/hirad/input_data/regrid_realch1.py @@ -21,21 +21,28 @@ from earthkit.geo.rotate import unrotate TRIM_EDGE = 41 +XARRAY_BATCH = 4 # Take anemoi dataset and provide xarray dataarrays for a set of variables. # returns: list of xarray dataarrays -def anemoi_to_xarray(anemoi_data: Dataset): +def anemoi_to_xarray(anemoi_data: Dataset, start_date_index=-1, end_date_index=-1): + if start_date_index == -1: + start_date_index = 0 + if end_date_index == -1: + end_date_index = len(anemoi_data.dates) lon = anemoi_data.longitudes lat = anemoi_data.latitudes eps = [0] # deterministic - time = anemoi_data.dates + time = anemoi_data.dates[start_date_index:end_date_index] metadata = getMetadataFromOGD() dataarrays = [] variables = anemoi_data.variables for var_index in range(anemoi_data.shape[1]): + logging.info(f'building xarray for {variables[var_index]}') ds = xr.Dataset( data_vars=dict( - variable=(["time", "eps", "cell"], np.array(anemoi_data[:,var_index,:,:])), + variable=(["time", "eps", "cell"], + np.array(anemoi_data[start_date_index:end_date_index,var_index,:,:])), ), coords=dict( eps=eps, @@ -129,41 +136,36 @@ def main(): realch1 = open_dataset(realch1_config) variables = realch1.variables - xarrays = anemoi_to_xarray(realch1) - - # Get the lat/lon info by regridding first variable - logging.info(f'regridding {variables[0]} for time {realch1.start_date} to {realch1.end_date}') - start = datetime.datetime.now() + # Get the lat/lon info by regridding one variable + xarrays = anemoi_to_xarray(realch1, 0, 1) regridded=regrid.icon2rotlatlon(xarrays[0]) - end = datetime.datetime.now() - logging.info(f' regridding took {end-start} seconds') logging.info('getting geo coords') lats, lons = get_geo_coords(regridded, trim_edge=TRIM_EDGE) - # Save lat/lon info + # Save grid to file grid = np.column_stack((lats, lons)) torch.save(grid, os.path.join(output_directory, 'info', 'realch1-lat-lon')) - - # regridded is in shape (eps, time, variable, x, y) - # want this in shape (time,channel,ensemble,grid) - # nervous about the reshaping screwing things up, but that's why we plot the interpolated data to visually check. - torch_data = np.zeros([len(realch1.dates), len(realch1.variables), 1, len(lats)]) - torch_data[:,0,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE) - - for i in range(1, len(xarrays)): - logging.info(f'regridding {variables[i]} for time {realch1.start_date} to {realch1.end_date}') - xarray = xarrays[i] - start = datetime.datetime.now() - regridded=regrid.icon2rotlatlon(xarray) - end = datetime.datetime.now() - logging.info(f' regridding took {end-start} seconds') - torch_data[:,i,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE) - - # Output each time point into torch file - logging.info('saving torch data') - for t in range(torch_data.shape[0]): - fmtdate = to_datetime(realch1.dates[t]).strftime('%Y%m%d-%H%M') - torch.save(torch_data[t,:], os.path.join(output_directory, 'realch1', fmtdate)) + + # Split regridding into batches; too many time points seems to not scale well. + for i in range(0, len(realch1.dates), XARRAY_BATCH): + start_index = i + end_index = min(i+XARRAY_BATCH, len(realch1.dates)) + torch_data = np.zeros([end_index-start_index, len(realch1.variables), 1, len(lats)]) + + xarrays = anemoi_to_xarray(realch1, start_index, end_index) + for j in range(len(xarrays)): + logging.info(f'regridding {variables[j]} for time {realch1.dates[start_index]} to {realch1.dates[end_index]}') + xarray = xarrays[j] + start = datetime.datetime.now() + regridded=regrid.icon2rotlatlon(xarray) + end = datetime.datetime.now() + logging.info(f' regridding took {end-start} seconds') + torch_data[0:end_index-start_index,j,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE) + + logging.info('saving torch data') + for k in range(torch_data.shape[0]): + fmtdate = to_datetime(realch1.dates[start_index + k]).strftime('%Y%m%d-%H%M') + torch.save(torch_data[k,:], os.path.join(output_directory, 'realch1', fmtdate)) # Output plots for each variable, for first time point for i in range(torch_data.shape[1]): From 9f42baada836d14a2c67dabc91a2a0aa2ed22ad7 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 30 Oct 2025 19:21:04 +0100 Subject: [PATCH 183/189] small refactoring and cleanup so it is easier to re-use era5-cosmo interpolation methods --- src/hirad/input_data/interpolate_basic.py | 96 +++++++---------------- 1 file changed, 29 insertions(+), 67 deletions(-) diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py index b68219b..f88b5e1 100644 --- a/src/hirad/input_data/interpolate_basic.py +++ b/src/hirad/input_data/interpolate_basic.py @@ -20,34 +20,39 @@ # Margin to use for ERA dataset (to avoid nans from interpolation at boundary) ERA_MARGIN_DEGREES = 1.0 -def _read_input(era_config_file: str, cosmo_config_file: str, bound_to_cosmo_area=True) -> tuple[Dataset, Dataset]: +def _read_era5_cosmo(era_config_file: str, cosmo_config_file: str) -> tuple[Dataset, Dataset]: """ Read both ERA and COSMO data, optionally bounding to the COSMO data area, and return the 2m temperature values for the time range under COSMO. """ # trim edge removes boundary + cosmo = read_cosmo_anemoi(cosmo_config_file) + # area = N, W, S, E + min_lat = min(cosmo.latitudes) - ERA_MARGIN_DEGREES + max_lat = max(cosmo.latitudes) + ERA_MARGIN_DEGREES + min_lon = min(cosmo.longitudes) - ERA_MARGIN_DEGREES + max_lon = max(cosmo.longitudes) + ERA_MARGIN_DEGREES + start_date = cosmo.metadata()['start_date'] + end_date = cosmo.metadata()['end_date'] + era = read_era5_anemoi(era_config_file, + start_date = start_date, end_date = end_date, + area=(max_lat, min_lon, min_lat, max_lon)) + return (era, cosmo) + +def read_cosmo_anemoi(cosmo_config_file: str): with open(cosmo_config_file) as cosmo_file: cosmo_config = yaml.safe_load(cosmo_file) cosmo = open_dataset(cosmo_config) + return cosmo + +def read_era5_anemoi(era_config_file: str, start_date = None, + end_date = None, area=None): with open(era_config_file) as era_file: era_config = yaml.safe_load(era_file) era = open_dataset(era_config) - # Subset the ERA dataset to have COSMO area/dates. - start_date = cosmo.metadata()['start_date'] - end_date = cosmo.metadata()['end_date'] - # load era5 2m-temperature in the time-range of cosmo - # area = N, W, S, E - if bound_to_cosmo_area: - min_lat = min(cosmo.latitudes) - ERA_MARGIN_DEGREES - max_lat = max(cosmo.latitudes) + ERA_MARGIN_DEGREES - min_lon = min(cosmo.longitudes) - ERA_MARGIN_DEGREES - max_lon = max(cosmo.longitudes) + ERA_MARGIN_DEGREES - era = open_dataset(era, start=start_date, end=end_date, - area=(max_lat, min_lon, min_lat, max_lon)) - else: - era = open_dataset(era, start=start_date, end=end_date) - - return (era, cosmo) + era = open_dataset(era, start=start_date, end=end_date, + area=area) + return era def regrid(era_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.ndarray): # shape (channel, ensemble, grid) @@ -58,7 +63,7 @@ def regrid(era_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.nda interpolated_data[j,0,:] = regrid return interpolated_data -def _interpolate_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarray, output_grid: np.ndarray, intermediate_files_path: str, outfile_plots_path: str = None, plot_indices=[0]): +def _interpolate_era5_cosmo_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarray, output_grid: np.ndarray, intermediate_files_path: str, outfile_plots_path: str = None, plot_indices=[0]): logging.info('interpolating time point ' + _format_date(cosmo.dates[i])) interpolated_data = np.empty([era.shape[1], 1, cosmo.shape[3]]) for j in range(era.shape[1]): @@ -85,7 +90,7 @@ def _interpolate_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarr -def _interpolate_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: str, threaded = True, outfile_plots_path: str =None, plot_indices=[0]): +def _interpolate_era5_cosmo_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: str, threaded = True, outfile_plots_path: str =None, plot_indices=[0]): """Perform simple interpolation from ERA5 to COSMO grid for all data points in the COSMO date range. Parameters: @@ -113,13 +118,13 @@ def _interpolate_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: st if (threaded): pool = multiprocessing.Pool() for i in dates: - pool.apply_async(_interpolate_task, (i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices)) + pool.apply_async(_interpolate_era5_cosmo_task, (i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices)) pool.close() pool.join() else: for i in dates: - _interpolate_task(i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices) + _interpolate_era5_cosmo_task(i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices) return @@ -138,26 +143,6 @@ def _save_latlon_grid(dataset: Dataset, filename: str): def _save_stats(dataset: Dataset, filename: str): torch.save(dataset.statistics, filename) -def _save_interpolation(values: np.ndarray[np.intp], filename: str): - """Output interpolated data to a given filename, in PyTorch tensor format.""" - torch_data = torch.from_numpy(values) - torch.save(torch_data, filename) - -def _get_plot_indices(era: Dataset, cosmo: Dataset) -> np.ndarray[np.intp]: - """ - Get indices of ERA5 data that is in the bounding rectangle of COSMO data. - This is useful for plotting in the case where read_input(..., bound_to_cosmo_area=False) was used. - In this case, one would then feed e.g. era.latitudes[indices] into _plot_projection. - """ - min_lat_cosmo = min(cosmo.latitudes) - max_lat_cosmo = max(cosmo.latitudes) - min_lon_cosmo = min(cosmo.longitudes) - max_lon_cosmo = max(cosmo.longitudes) - box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) - box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) - indices = np.where(box_lon*box_lat) - return indices - def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None, s = None): p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax, s=s) ax.coastlines() @@ -171,14 +156,10 @@ def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: fig, ax = plt.subplots(subplot_kw={"projection": projection}) logging.info(f'plotting values to {filename}') plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax, s) - #p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - #ax.coastlines() - #ax.gridlines(draw_labels=True) - #plt.colorbar(p, orientation="horizontal") plt.savefig(filename) plt.close('all') -def interpolate_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: str, threaded=True, outfile_plots_path: str = None, plot_indices=[0]): +def interpolate_era5_cosmo_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: str, threaded=True, outfile_plots_path: str = None, plot_indices=[0]): """Read both ERA and COSMO data and perform basic interpolation. Save output into Pytorch format, and (optionally) plot ERA, COSMO, and interpolated data. @@ -221,25 +202,8 @@ def interpolate_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: shutil.copy(infile_era, os.path.join(outfile_data_path, "info/era.yaml")) # generate interpolated data - _interpolate_basic(era, cosmo, outfile_data_path, threaded=threaded, outfile_plots_path=outfile_plots_path, plot_indices=plot_indices) - -def plot_tp(path_6h: str, path_1h: str): - fig, axs = plt.subplots(2, 3, subplot_kw={"projection": ccrs.PlateCarree()}) - - - logging.info(f'plotting values to {filename}') - p = ax.scatter(x=longitudes, y=latitudes, c=values) - ax.coastlines() - ax.gridlines(draw_labels=True) - plt.colorbar(p, label="absolute error", orientation="horizontal") - plt.savefig(filename) - plt.close('all') + _interpolate_era5_cosmo_basic(era, cosmo, outfile_data_path, threaded=threaded, outfile_plots_path=outfile_plots_path, plot_indices=plot_indices) -def load_static(infile_era: str, infile_cosmo: str, output_directory: str): - _, cosmo = _read_input(infile_era, infile_cosmo, bound_to_cosmo_area=True) - - torch.save(cosmo[0,:,:,:], os.path.join(output_directory, 'cosmo-static')) - shutil.copy(infile_cosmo, os.path.join(output_directory, "cosmo-static.yaml")) def main(): # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs. @@ -255,9 +219,7 @@ def main(): level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') - #load_static(infile_era, infile_cosmo, output_directory) - #interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=os.path.join(output_directory, "plots/")) - interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=None) + interpolate_era5_cosmo_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=None) if __name__ == "__main__": main() From 7b9364f84836b8474daaa985a6a166a82814f6cf Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 3 Nov 2025 10:38:49 +0100 Subject: [PATCH 184/189] bit more skeleton to new inteprolation --- src/hirad/input_data/interpolate_realch1.py | 82 ++++++++++++++------ src/hirad/input_data/regrid_copernicus_tp.py | 23 +++--- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/hirad/input_data/interpolate_realch1.py b/src/hirad/input_data/interpolate_realch1.py index 1334865..2a2e1d7 100644 --- a/src/hirad/input_data/interpolate_realch1.py +++ b/src/hirad/input_data/interpolate_realch1.py @@ -1,4 +1,5 @@ - +import hirad.input_data.interpolate_basic as interpolate_basic +import hirad.input_data.regrid_copernicus_tp as regrid_copernicus_tp import datetime import logging @@ -25,37 +26,29 @@ '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2017-2018.nc', '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2019-2020.nc'] -def _read_input(era_config_file: str, realch1_config_file: str, ) -> tuple[Dataset, Dataset, array.array]: +def _read_input(era_config_file: str, realch1_latlon_file: str) -> tuple[Dataset, Dataset, array.array, np.ndarray]: """ - Read both ERA and REA-L-CH1 data, and return the 2m - temperature values for the time range under COSMO. + Read ERA data, and return the values for the area under REA-L-CH1 (plus a margin). """ - # trim edge removes boundary, we will use the same - with open(realch1_config_file) as realch1_file: - realch1_config = yaml.safe_load(realch1_file) - realch1 = open_dataset(realch1_config) + # read the lat/lon data for REA-L-CH1 + realch1_latlon = torch.load(realch1_latlon_file) + # we expect the start and end dates to be specified in config. with open(era_config_file) as era_file: era_config = yaml.safe_load(era_file) era = open_dataset(era_config) - # Subset the ERA dataset to have REAL-CH-1 area/dates. - start_date = realch1.metadata()['start_date'] - end_date = realch1.metadata()['end_date'] - # load era5 2m-temperature in the time-range of cosmo + # Subset the ERA dataset to have REAL-CH-1 area. # area = N, W, S, E - min_lat = min(realch1.latitudes) - ERA_MARGIN_DEGREES - max_lat = max(realch1.latitudes) + ERA_MARGIN_DEGREES - min_lon = min(realch1.longitudes) - ERA_MARGIN_DEGREES - max_lon = max(realch1.longitudes) + ERA_MARGIN_DEGREES - era = open_dataset(era, start=start_date, end=end_date, + min_lat = min(realch1_latlon[:,0]) - ERA_MARGIN_DEGREES + max_lat = max(realch1_latlon[:,0]) + ERA_MARGIN_DEGREES + min_lon = min(realch1_latlon[:,1]) - ERA_MARGIN_DEGREES + max_lon = max(realch1_latlon[:,1]) + ERA_MARGIN_DEGREES + era = open_dataset(era, area=(max_lat, min_lon, min_lat, max_lon)) - - copernicus_netcdf = [] for f in COPERNICUS_FILES: netcdf_data = netCDF4.Dataset(f) copernicus_netcdf.append(netcdf_data) - - return (era, realch1, copernicus_netcdf) + return (era, copernicus_netcdf, realch1_latlon) def regrid_all(era: Dataset, realch1: Dataset, copernicus: array.array): @@ -77,4 +70,49 @@ def regrid_era(): # Take the output grid from realch1-regrid (rotated lat lon). # regrid all variables *except* tp directly from era5 data # regrid the pt variable from the netcdf data - # save the output as torch \ No newline at end of file + # save the output as torch + pass + +def main(): + # read REA-L-CH1 latlon grid + era_config_file = sys.argv[1] + realch1_latlon_file = sys.argv[2] + netcdf_file = sys.argv[3] + output_directory = sys.argv[4] + + realch1_grid = torch.load(realch1_latlon_file, weights_only=False) + # read ERA input + min_lat = min(realch1_grid[:,0]) - interpolate_basic.ERA_MARGIN_DEGREES + max_lat = min(realch1_grid[:,0]) + interpolate_basic.ERA_MARGIN_DEGREES + min_lon = min(realch1_grid[:,1]) - interpolate_basic.ERA_MARGIN_DEGREES + max_lon = min(realch1_grid[:,1]) + interpolate_basic.ERA_MARGIN_DEGREES + era = interpolate_basic.read_era5_anemoi(era_config_file, + area=(max_lat, min_lon, min_lat, max_lon)) + era_grid = np.column_stack((era.longitudes, era.latitudes)) + + # read copernicus input for tp variable + netcdf_data = netCDF4.Dataset(netcdf_file) + logging.info('processing netcdf data') + netcdf_latitudes, netcdf_longitudes = regrid_copernicus_tp.extract_lat_lon_025(netcdf_data) + netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + # TODO: start and end date functionality + netcdf_tp_values = regrid_copernicus_tp.extract_values(netcdf_data, 'tp', start_date=era.start_date, end_date=era.end_date) + assert(netcdf_tp_values.shape[0] == era.shape[0]) + + + # Iterate over ERA time range, which should be subsetted in configuration. + for i in range(era.shape[0]): + t = era.dates[i] + # Get everything but the tp variable + tp_index = era.variables.index('tp') + logging.info('tp index' + tp_index) + # shape time, channel, ensemble, grid + era_for_time = np.delete(era[i,:,:,:], tp_index, axis=0) + era_regridded = interpolate_basic.regrid(era_for_time, era_grid, realch1_grid) + copernicus_regridded = interpolate_basic.regrid(netcdf_data[0,:], netcdf_grid, realch1_grid) + output=np.stack((era_regridded, copernicus_regridded), axis=1) + filename = os.path.join(output_directory, 'era-copernicus-interpolated', + interpolate_basic._format_date(t)) + torch.save(output, filename) + + return 0 \ No newline at end of file diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py index 7e3f14f..9686601 100644 --- a/src/hirad/input_data/regrid_copernicus_tp.py +++ b/src/hirad/input_data/regrid_copernicus_tp.py @@ -62,20 +62,21 @@ def extract_lat_lon_n320(data): logging.info('extracting lat/lon') logging.info(f'lat lon shapes {lat.shape} {lon.shape}') -def extract_values(data, variable, area=None): +def extract_values(data: netCDF4.Dataset, variable, start_date=None, end_date=None, area=None): values = data[variable][:] print(values.shape) - if area: - lat = data['latitude'][:] - lon = data['longitude'][:] + #if area: + # Not sure this is working. + # lat = data['latitude'][:] + # lon = data['longitude'][:] # https://stackoverflow.com/questions/29135885/netcdf4-extract-for-subset-of-lat-lon - latli = np.argmin( np.abs(lat - area[2])) - latui = np.argmin( np.abs(lat - area[0])) - lonli = np.argmin( np.abs(lon - area[1])) - lonui = np.argmin( np.abs(lon - area[3])) - lat = data['latitude'][latli:latui] - lon = data['longitude'][lonli:lonui] - values = data[variable][latli:latui,lonli:lonui] + # latli = np.argmin( np.abs(lat - area[2])) + # latui = np.argmin( np.abs(lat - area[0])) + # lonli = np.argmin( np.abs(lon - area[1])) + # lonui = np.argmin( np.abs(lon - area[3])) + # lat = data['latitude'][latli:latui] + # lon = data['longitude'][lonli:lonui] + # values = data[variable][latli:latui,lonli:lonui] return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) def reshape_to_cosmo(vals): From d2ca5e50f9a4e7d99f1ad38f7792703db21f3ae4 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 4 Nov 2025 16:59:07 +0100 Subject: [PATCH 185/189] working script --- src/hirad/input_data/interpolate_basic.py | 12 +-- src/hirad/input_data/interpolate_realch1.py | 107 +++++++++++++------- 2 files changed, 77 insertions(+), 42 deletions(-) diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py index f88b5e1..5a77429 100644 --- a/src/hirad/input_data/interpolate_basic.py +++ b/src/hirad/input_data/interpolate_basic.py @@ -136,11 +136,11 @@ def _save_datetime_file(values: np.ndarray[np.intp], variables: np.ndarray, date filename = filepath + _format_date(date) torch.save(values, filename) -def _save_latlon_grid(dataset: Dataset, filename: str): +def save_anemoi_latlon_grid(dataset: Dataset, filename: str): grid = np.column_stack((dataset.latitudes, dataset.longitudes)) torch.save(grid, filename) -def _save_stats(dataset: Dataset, filename: str): +def save_anemoi_stats(dataset: Dataset, filename: str): torch.save(dataset.statistics, filename) def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None, s = None): @@ -192,10 +192,10 @@ def interpolate_era5_cosmo_and_save(infile_era: str, infile_cosmo: str, outfile_ logging.info('Successfully read input') # Output stats and grid - _save_stats(era, os.path.join(outfile_data_path, "info/era-stats")) - _save_stats(cosmo, os.path.join(outfile_data_path, "info/cosmo-stats")) - _save_latlon_grid(cosmo, os.path.join(outfile_data_path, "info/cosmo-lat-lon")) - _save_latlon_grid(era, os.path.join(outfile_data_path, "info/era-lat-lon")) + save_anemoi_stats(era, os.path.join(outfile_data_path, "info/era-stats")) + save_anemoi_stats(cosmo, os.path.join(outfile_data_path, "info/cosmo-stats")) + save_anemoi_latlon_grid(cosmo, os.path.join(outfile_data_path, "info/cosmo-lat-lon")) + save_anemoi_latlon_grid(era, os.path.join(outfile_data_path, "info/era-lat-lon")) # Copy the .yaml files over for recording purposes shutil.copy(infile_cosmo, os.path.join(outfile_data_path, "info/cosmo.yaml")) diff --git a/src/hirad/input_data/interpolate_realch1.py b/src/hirad/input_data/interpolate_realch1.py index 2a2e1d7..e1a17bc 100644 --- a/src/hirad/input_data/interpolate_realch1.py +++ b/src/hirad/input_data/interpolate_realch1.py @@ -51,28 +51,6 @@ def _read_input(era_config_file: str, realch1_latlon_file: str) -> tuple[Dataset return (era, copernicus_netcdf, realch1_latlon) -def regrid_all(era: Dataset, realch1: Dataset, copernicus: array.array): - # iterate through the dates - realch1 = regrid_realch1 - # convert to xarray.dataarray - regrid.icon2rotlatlon - - pass - -def regrid_realch1(): - # Use the meteodatalab functions to regrid the realch1 anemoi data (one time point) - # onto the rotated lat lon - # save the output as torch - # return the np array - pass - -def regrid_era(): - # Take the output grid from realch1-regrid (rotated lat lon). - # regrid all variables *except* tp directly from era5 data - # regrid the pt variable from the netcdf data - # save the output as torch - pass - def main(): # read REA-L-CH1 latlon grid era_config_file = sys.argv[1] @@ -80,39 +58,96 @@ def main(): netcdf_file = sys.argv[3] output_directory = sys.argv[4] - realch1_grid = torch.load(realch1_latlon_file, weights_only=False) + logging.basicConfig( + filename=os.path.join(output_directory, 'interpolate_realch1.log'), + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + + # Copy ERA yml file + shutil.copy(era_config_file, os.path.join(output_directory, 'info')) + + logging.info('reading realch1 lat/lon') + realch1_latlon = torch.load(realch1_latlon_file, weights_only=False) + realch1_lat = realch1_latlon[:,0] + realch1_lon = realch1_latlon[:,1] # read ERA input - min_lat = min(realch1_grid[:,0]) - interpolate_basic.ERA_MARGIN_DEGREES - max_lat = min(realch1_grid[:,0]) + interpolate_basic.ERA_MARGIN_DEGREES - min_lon = min(realch1_grid[:,1]) - interpolate_basic.ERA_MARGIN_DEGREES - max_lon = min(realch1_grid[:,1]) + interpolate_basic.ERA_MARGIN_DEGREES + min_lat = min(realch1_lat) - interpolate_basic.ERA_MARGIN_DEGREES + max_lat = max(realch1_lat) + interpolate_basic.ERA_MARGIN_DEGREES + min_lon = min(realch1_lon) - interpolate_basic.ERA_MARGIN_DEGREES + max_lon = max(realch1_lon) + interpolate_basic.ERA_MARGIN_DEGREES + logging.info('reading era') + era = interpolate_basic.read_era5_anemoi(era_config_file, area=(max_lat, min_lon, min_lat, max_lon)) era_grid = np.column_stack((era.longitudes, era.latitudes)) + realch1_grid = np.column_stack((realch1_lon, realch1_lat)) + logging.info(f'lat lon area is {min_lat}-{max_lat} {min_lon}-{max_lon}') + + # save era stats and lat lon + interpolate_basic.save_anemoi_latlon_grid(era, os.path.join(output_directory, 'info', 'era-lat-lon')) + interpolate_basic.save_anemoi_stats(era, os.path.join(output_directory, 'info', 'era-stats')) # read copernicus input for tp variable + logging.info('reading copernicus') netcdf_data = netCDF4.Dataset(netcdf_file) logging.info('processing netcdf data') netcdf_latitudes, netcdf_longitudes = regrid_copernicus_tp.extract_lat_lon_025(netcdf_data) netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes)) + # TODO: start and end date functionality netcdf_tp_values = regrid_copernicus_tp.extract_values(netcdf_data, 'tp', start_date=era.start_date, end_date=era.end_date) assert(netcdf_tp_values.shape[0] == era.shape[0]) + # todo: incorporate this somehow + netcdf_tp_values = netcdf_tp_values.reshape((netcdf_tp_values.shape[0], 1,1, netcdf_tp_values.shape[1])) + # save copernicus stats and lat lon + torch.save(np.column_stack((netcdf_grid[:,1], netcdf_grid[:,0])), + os.path.join(output_directory, 'info', 'copernicus-lat-lon')) + regrid_copernicus_tp.make_stats(os.path.join(output_directory, 'info'), + os.path.join(output_directory, 'info'), + netcdf_tp_values) # Iterate over ERA time range, which should be subsetted in configuration. + tp_index = era.variables.index('tp') + logging.info(f'tp index {tp_index}') + + plot_indices = {12} + + logging.info('interpolating') + #for i in plot_indices: for i in range(era.shape[0]): + # T t = era.dates[i] # Get everything but the tp variable - tp_index = era.variables.index('tp') - logging.info('tp index' + tp_index) - # shape time, channel, ensemble, grid - era_for_time = np.delete(era[i,:,:,:], tp_index, axis=0) + era_for_time = era[i,:,:,:] era_regridded = interpolate_basic.regrid(era_for_time, era_grid, realch1_grid) - copernicus_regridded = interpolate_basic.regrid(netcdf_data[0,:], netcdf_grid, realch1_grid) - output=np.stack((era_regridded, copernicus_regridded), axis=1) + # Regrid TP from copernicus + copernicus_regridded = interpolate_basic.regrid(netcdf_tp_values[i,:], netcdf_grid, realch1_grid) + # Concatenate and save + era_regridded[tp_index,:] = copernicus_regridded + #output=np.concatenate((era_regridded, copernicus_regridded), axis=0) + datefmt = interpolate_basic._format_date(t) filename = os.path.join(output_directory, 'era-copernicus-interpolated', - interpolate_basic._format_date(t)) - torch.save(output, filename) + datefmt) + torch.save(era_regridded, filename) + + if i in plot_indices: + realch1var = ['t2m', '10u', '10v', 'tp'] + realch1_data = torch.load(os.path.join(output_directory, 'realch1', datefmt), weights_only=False) + for j in range(realch1_data.shape[0]): + interpolate_basic.plot_and_save_projection(realch1_lon, realch1_lat, realch1_data[j,:], + os.path.join(output_directory, 'plots', + f'{datefmt}-{realch1var[j]}-realch1')) + for j in range(era_regridded.shape[0]): + interpolate_basic.plot_and_save_projection(era.longitudes, era.latitudes, era_for_time[j,:], + os.path.join(output_directory, 'plots', + f'{datefmt}-{era.variables[j]}-era')) + interpolate_basic.plot_and_save_projection(realch1_lon, realch1_lat, + era_regridded[j,0,:], + os.path.join(output_directory, 'plots', + f'{datefmt}-{era.variables[j]}-interpolated')) + return 0 - return 0 \ No newline at end of file +if __name__ == "__main__": + main() From d0b9c1fd65a365f1297018a239532b273777ec6e Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 5 Nov 2025 13:07:20 +0100 Subject: [PATCH 186/189] changes to copernicus regrid --- src/hirad/input_data/regrid_copernicus_tp.py | 37 +++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py index 9686601..adcea16 100644 --- a/src/hirad/input_data/regrid_copernicus_tp.py +++ b/src/hirad/input_data/regrid_copernicus_tp.py @@ -62,9 +62,9 @@ def extract_lat_lon_n320(data): logging.info('extracting lat/lon') logging.info(f'lat lon shapes {lat.shape} {lon.shape}') +# Get values for a given date range (inclusive) def extract_values(data: netCDF4.Dataset, variable, start_date=None, end_date=None, area=None): values = data[variable][:] - print(values.shape) #if area: # Not sure this is working. # lat = data['latitude'][:] @@ -77,6 +77,15 @@ def extract_values(data: netCDF4.Dataset, variable, start_date=None, end_date=No # lat = data['latitude'][latli:latui] # lon = data['longitude'][lonli:lonui] # values = data[variable][latli:latui,lonli:lonui] + date_indices = range(values.shape[0]) + if start_date and end_date: + date_indices = np.intersect1d( + np.where(data['valid_time'][:] >= start_date.astype(np.int64)), + np.where(data['valid_time'][:] <= end_date.astype(np.int64))) + values = values[date_indices,:] + if len(date_indices) == 0: + raise KeyError(f'{start_date} and {end_date} not valid range') + return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2])) def reshape_to_cosmo(vals): @@ -200,12 +209,7 @@ def process_era(netcdf_data, netcdf_tp_values): torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA, date_filename)) t4 = datetime.datetime.now() -def make_stats(): - #cosmo_files = os.listdir(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'cosmo')) - #era_files = os.listdir(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'cosmo')) - - stats = torch.load(os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, 'info', 'era-stats'), weights_only=False) - print(stats) +def extract_all_values(): set1 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc") set2 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018.nc") set3 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc") @@ -214,17 +218,24 @@ def make_stats(): set3_tp = extract_values(set3, 'tp') all_tp = np.row_stack((set1_tp, set2_tp, set3_tp)) print(all_tp.shape) - all_tp = all_tp.reshape(all_tp.shape[0] * all_tp.shape[1], 1) - mean = np.mean(all_tp) - max = np.max(all_tp) - min = np.min(all_tp) - stdev = np.std(all_tp) + return all_tp + +# Get stats from ERA and replace the TP variable with stats from Copernicus +def make_stats(input_stats_directory: str, output_stats_directory: str, extracted_tp_values: np.ndarray): + stats = torch.load(os.path.join(input_stats_directory, 'era-stats'), weights_only=False) + print(stats) + #extracted_tp_values = extracted_tp_values.reshape(extracted_tp_values.shape[0] * extracted_tp_values.shape[1], 1) + flat_values = extracted_tp_values.flatten() + mean = np.mean(flat_values) + max = np.max(flat_values) + min = np.min(flat_values) + stdev = np.std(flat_values) stats['mean'][TP_INDEX] = mean stats['maximum'][TP_INDEX] = max stats['minimum'][TP_INDEX] = min stats['stdev'][TP_INDEX] = stdev print(stats) - torch.save(stats, os.path.join(OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED, 'era-stats')) + torch.save(stats, os.path.join(output_stats_directory, 'era-copernicus-stats')) #process_era(netcdf_data, netcdf_tp_values) From c9b0f37baaa2276193a5ff79cba13d4bff27423d Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 5 Nov 2025 13:07:50 +0100 Subject: [PATCH 187/189] config updates --- src/hirad/input_data/era.yaml | 7 +++++-- src/hirad/input_data/realch1.yaml | 8 ++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/hirad/input_data/era.yaml b/src/hirad/input_data/era.yaml index 3234321..9018031 100644 --- a/src/hirad/input_data/era.yaml +++ b/src/hirad/input_data/era.yaml @@ -1,4 +1,7 @@ -dataset: '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' +#dataset: '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' +dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr' select: ['2t', '10u', '10v', 'tcw', 't_850', 'z_850', 'u_850', 'v_850', 't_500', 'z_500', 'u_500', 'v_500', 'tp'] # See table S2 from corrdiff paper for the inputs. -# Note: Bounding dates/area will be done in .py code. \ No newline at end of file +# Note: Bounding dates/area will be done in .py code. +start: 2020-01-01 +end: 2020-01-31 \ No newline at end of file diff --git a/src/hirad/input_data/realch1.yaml b/src/hirad/input_data/realch1.yaml index d6ce166..c536704 100644 --- a/src/hirad/input_data/realch1.yaml +++ b/src/hirad/input_data/realch1.yaml @@ -1,4 +1,4 @@ -dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr' -select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC'] -start: 2020-01-01 -end: 2020-01-01 +dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr' +select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1 +#start: 2020-01-01 +#end: 2020-01-01 From 72ecbdec6e68b88bee6dca40d08cc0d4577a37ff Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Fri, 7 Nov 2025 21:34:43 +0100 Subject: [PATCH 188/189] updates to CRPS eval --- src/hirad/eval/compute_eval.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py index 8a01db6..00e524e 100644 --- a/src/hirad/eval/compute_eval.py +++ b/src/hirad/eval/compute_eval.py @@ -47,9 +47,10 @@ def main(cfg: DictConfig) -> None: dataset, sampler = get_dataset_and_sampler_inference( dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time ) - output_path = getattr(cfg.generation.io, "output_path", "./outputs") + pred_path = getattr(cfg.generation.io, "output_path", "./outputs") + output_path = './plots/analysis202511' - compute_crps_per_time(times, dataset, output_path) + compute_crps_per_time(times, dataset, pred_path, output_path) compute_crps_over_time_and_area(times, output_path) plot_crps_over_time_and_area(times, dataset, output_path) @@ -66,14 +67,14 @@ def save_data(data, output_path, time=None, filename=None): path = _get_data_path(output_path, time, filename) torch.save(data, path) -def compute_crps_per_time(times, dataset, output_path): +def compute_crps_per_time(times, dataset, pred_path, output_path): logging.info('Computing CRPS for each time point') input_channels = dataset.input_channels() output_channels = dataset.output_channels() start_time=times[0] # Load one prediction ensemble to get the shape - prediction_ensemble = torch.load(os.path.join(output_path, start_time, f'{start_time}-predictions'), weights_only=False) + prediction_ensemble = torch.load(os.path.join(pred_path, start_time, f'{start_time}-predictions'), weights_only=False) # Get a map of output to input channel, for building baseline errors output_to_input_channel_map = {} @@ -89,9 +90,9 @@ def compute_crps_per_time(times, dataset, output_path): curr_time = times[i] if i % (24*5) == 0: logging.info(f'on time {curr_time}') - prediction_ensemble = load_data(output_path, time=curr_time, filename=f'{curr_time}-predictions') - baseline = load_data(output_path, time=curr_time, filename=f'{curr_time}-baseline') - target = load_data(output_path, time=curr_time, filename=f'{curr_time}-target') + prediction_ensemble = load_data(pred_path, time=curr_time, filename=f'{curr_time}-predictions') + baseline = load_data(pred_path, time=curr_time, filename=f'{curr_time}-baseline') + target = load_data(pred_path, time=curr_time, filename=f'{curr_time}-target') # Calculate ensemble mean error ensemble_mean = np.mean(prediction_ensemble, 0) @@ -107,12 +108,12 @@ def compute_crps_per_time(times, dataset, output_path): # Calculate persistence error (baseline #2) persistence_error = np.zeros(target.shape) if i > 0: - prev = load_data(output_path, time=times[i-1], filename=f'{times[i-1]}-target') + prev = load_data(pred_path, time=times[i-1], filename=f'{times[i-1]}-target') persistence_error = absolute_error(prev, target) else: # for the first time point, persist the next-time-point target. # This is fiction but it keeps the plots from looking weird. - prev = load_data(output_path, time=times[i+1], filename=f'{times[i+1]}-target') + prev = load_data(pred_path, time=times[i+1], filename=f'{times[i+1]}-target') persistence_error = absolute_error(prev, target) From 0047a68481d249f71204b23a720fce3ec48dc6e4 Mon Sep 17 00:00:00 2001 From: David Leutwyler Date: Mon, 10 Nov 2025 10:40:22 +0100 Subject: [PATCH 189/189] fix shading --- src/hirad/eval/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index e863053..e83a625 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -67,7 +67,7 @@ def concat_and_group_diurnal(list_of_da, is_member=False, scale=1.0): if is_member: timmean = da.mean(dim='time') * scale mean = timmean.mean(dim='member') - std = timmean.std(dim='member') + std = da.std(dim='member').mean(dim='time') * scale else: mean = da.mean(dim='time') * scale std = None