diff --git a/graph_weather/configs/features_default.yaml b/graph_weather/configs/features_default.yaml new file mode 100644 index 00000000..7e286740 --- /dev/null +++ b/graph_weather/configs/features_default.yaml @@ -0,0 +1,188 @@ +# config replicates the original 78+24 feature set +# var names based on ERA5 dataset gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr +features: + # -------------------------------------------------------------------------- + # 78 Dynamic Features (6 Variables @ 13 Pressure Levels) + # -------------------------------------------------------------------------- + - name: "geopotential" + type: "dynamic" + source: "era5" + levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 + - name: "specific_humidity" + type: "dynamic" + source: "era5" + levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 + - name: "temperature" + type: "dynamic" + source: "era5" + levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 + - name: "u_component_of_wind" + type: "dynamic" + source: "era5" + levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 + - name: "v_component_of_wind" + type: "dynamic" + source: "era5" + levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 + - name: "vertical_velocity" + type: "dynamic" + source: "era5" + levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 + # -------------------------------------------------------------------------- + # 24 Static & Single-Level Features ("non-NWP") + # -------------------------------------------------------------------------- + # 14 Static constants + - name: "angle_of_sub_gridscale_orography" + type: "static" + source: 
"constants" + - name: "anisotropy_of_sub_gridscale_orography" + type: "static" + source: "constants" + - name: "geopotential_at_surface" + type: "static" + source: "constants" + - name: "high_vegetation_cover" + type: "static" + source: "constants" + - name: "lake_cover" + type: "static" + source: "constants" + - name: "lake_depth" + type: "static" + source: "constants" + - name: "land_sea_mask" + type: "static" + source: "constants" + - name: "low_vegetation_cover" + type: "static" + source: "constants" + - name: "slope_of_sub_gridscale_orography" + type: "static" + source: "constants" + - name: "soil_type" + type: "static" + source: "constants" + - name: "standard_deviation_of_filtered_subgrid_orography" + type: "static" + source: "constants" + - name: "standard_deviation_of_orography" + type: "static" + source: "constants" + - name: "type_of_high_vegetation" + type: "static" + source: "constants" + - name: "type_of_low_vegetation" + type: "static" + source: "constants" + + # 6 Single-Level Variables + - name: "10m_u_component_of_wind" + type: "static" # Treated as static input for the model's purpose + source: "era5" + - name: "10m_v_component_of_wind" + type: "static" + source: "era5" + - name: "2m_temperature" + type: "static" + source: "era5" + - name: "mean_sea_level_pressure" + type: "static" + source: "era5" + - name: "surface_pressure" + type: "static" + source: "era5" + - name: "toa_incident_solar_radiation" + type: "static" + source: "era5" + + # 4 Derived Temporal Encodings + - name: "year_progress_sin" + type: "static" + source: "derived" + - name: "year_progress_cos" + type: "static" + source: "derived" + - name: "day_progress_sin" + type: "static" + source: "derived" + - name: "day_progress_cos" + type: "static" + source: "derived" diff --git a/graph_weather/configs/test_features.yaml b/graph_weather/configs/test_features.yaml new file mode 100644 index 00000000..db478793 --- /dev/null +++ b/graph_weather/configs/test_features.yaml @@ -0,0 +1,4 @@ 
class FeatureManager:
    """Parse a feature config and expose feature dimensions and tensor assembly.

    After construction the following attributes are available:

    * ``feature_order``: one entry per output channel, in config order. A
      multi-level variable contributes ``"<name>_L<level>"`` per level; every
      other feature contributes its plain name.
    * ``dynamic_variable_names``: names of ``type == "dynamic"`` features
      (each listed once, whether or not it has levels).
    * ``static_feature_names``: names of ``type == "static"`` features.
    * ``level_map``: levels for each dynamic variable that declares them.
    * ``num_features`` / ``num_dynamic_features`` / ``num_static_features``:
      channel counts derived from the lists above.
    """

    def __init__(self, config_path: str):
        """Load the feature config at ``config_path`` and derive the feature layout."""
        self.config = load_feature_config(config_path)
        self._process_config()

    def _process_config(self):
        """Derive feature names, channel ordering and counts from ``self.config``.

        Raises:
            ValueError: if a feature declares a type other than "dynamic" or "static".
        """
        self.feature_order: List[str] = []
        self.dynamic_variable_names: List[str] = []
        self.static_feature_names: List[str] = []
        self.level_map: Dict[str, List[int]] = {}

        num_dynamic_features = 0

        for feature in self.config.features:
            if feature.type == "dynamic":
                # Record the variable exactly once, regardless of how many levels it has.
                self.dynamic_variable_names.append(feature.name)
                if feature.levels:
                    self.level_map[feature.name] = feature.levels
                    for level in feature.levels:
                        self.feature_order.append(f"{feature.name}_L{level}")
                        num_dynamic_features += 1
                else:
                    self.feature_order.append(feature.name)
                    num_dynamic_features += 1
            elif feature.type == "static":
                self.static_feature_names.append(feature.name)
                self.feature_order.append(feature.name)
            else:
                # Dropping the feature silently would make num_features disagree
                # with what the config author expects.
                raise ValueError(
                    f"Unknown feature type {feature.type!r} for feature {feature.name!r}; "
                    "expected 'dynamic' or 'static'."
                )

        self.num_features = len(self.feature_order)
        self.num_dynamic_features = num_dynamic_features
        self.num_static_features = len(self.static_feature_names)

    def assemble_features(self, data: Dict) -> torch.Tensor:
        """
        Assembles final input tensor from a dictionary of raw data tensors.

        Args:
            data (Dict): mapping feature names to tensors.
                For multi-level variables the levels must be on the last axis,
                e.g. (batch, nodes, levels) or (nodes, levels). Other features
                are concatenated as given; a 1-D tensor gets a trailing
                channel axis added first.

        Returns:
            single tensor with all features concatenated on the last axis, in
            ``feature_order``.

        Raises:
            ValueError: if a configured feature or level is missing from
                ``data``, or if no features are configured.
        """
        feature_tensors = []

        for feature_name in self.feature_order:
            # Multi-level entries are named "<variable>_L<level>" by _process_config.
            base_name, separator, level_str = feature_name.rpartition("_L")

            if separator and base_name in self.level_map and level_str.isdigit():
                level = int(level_str)
                try:
                    level_idx = self.level_map[base_name].index(level)
                    # Ellipsis keeps any leading (batch, nodes) axes intact.
                    tensor_slice = data[base_name][..., level_idx].unsqueeze(-1)
                except (ValueError, KeyError, IndexError) as e:
                    raise ValueError(
                        f"Level {level} for variable {base_name} not found in config or data."
                    ) from e
                feature_tensors.append(tensor_slice)
            elif feature_name in data:
                # single level/ derived feature
                tensor = data[feature_name]
                if tensor.ndim == 1:
                    tensor = tensor.unsqueeze(-1)
                feature_tensors.append(tensor)
            else:
                raise ValueError(
                    f"Feature {feature_name} from config not found in data dictionary."
                )

        if not feature_tensors:
            raise ValueError("No features are configured; nothing to assemble.")

        return torch.cat(feature_tensors, dim=-1)
78 node_dim: int = 256 edge_dim: int = 256 num_blocks: int = 9 @@ -38,10 +36,9 @@ def build(self) -> "GraphWeatherForecaster": """Build GraphWeatherForecaster from this configuration.""" return GraphWeatherForecaster( lat_lons=self.lat_lons, + input_features=self.input_features, resolution=self.resolution, - feature_dim=self.feature_dim, - aux_dim=self.aux_dim, - output_dim=self.output_dim, + output_features=self.output_features, node_dim=self.node_dim, edge_dim=self.edge_dim, num_blocks=self.num_blocks, @@ -64,10 +61,9 @@ class GraphWeatherForecaster(torch.nn.Module, PyTorchModelHubMixin): def __init__( self, lat_lons: list, + input_features: int, + output_features: int, resolution: int = 2, - feature_dim: int = 78, - aux_dim: int = 24, - output_dim: Optional[int] = None, node_dim: int = 256, edge_dim: int = 256, num_blocks: int = 9, @@ -89,9 +85,8 @@ def __init__( lat_lons: List of latitude and longitudes for the grid resolution: Resolution of the H3 grid, prefer even resolutions, as odd ones have octogons and heptagons as well - feature_dim: Input feature size - aux_dim: Number of non-NWP features (i.e. landsea mask, lat/lon, etc) - output_dim: Optional, output feature size, useful if want only subset of variables in + input_features: Input feature size including non-NWP features (i.e. landsea mask, lat/lon, etc) + output_features: Optional, output feature size, useful if want only subset of variables in output node_dim: Node hidden dimension edge_dim: Edge hidden dimension @@ -110,12 +105,10 @@ def __init__( use_thermalizer: Whether to use the thermalizer layer """ super().__init__() - self.feature_dim = feature_dim self.constraint_type = constraint_type self.use_thermalizer = use_thermalizer - if output_dim is None: - output_dim = self.feature_dim - self.output_dim = output_dim + self.input_features = input_features + self.output_features = output_features # Compute the geographical grid shape from lat_lons. 
unique_lats = sorted(set(lat for lat, _ in lat_lons)) @@ -129,7 +122,7 @@ def __init__( self.encoder = Encoder( lat_lons=lat_lons, resolution=resolution, - input_dim=feature_dim + aux_dim, + input_dim=input_features, output_dim=node_dim, output_edge_dim=edge_dim, hidden_dim_processor_edge=hidden_dim_processor_edge, @@ -154,7 +147,7 @@ def __init__( lat_lons=lat_lons, resolution=resolution, input_dim=node_dim, - output_dim=output_dim, + output_dim=output_features, output_edge_dim=edge_dim, hidden_dim_processor_edge=hidden_dim_processor_edge, hidden_layers_processor_node=hidden_layers_processor_node, @@ -225,20 +218,20 @@ def forward(self, features: torch.Tensor, t: int = 0) -> torch.Tensor: """ x, edge_idx, edge_attr = self.encoder(features) x = self.processor(x, edge_idx, edge_attr, t) - x = self.decoder(x, features[..., : self.feature_dim]) + x = self.decoder(x, features[..., : self.output_features]) # Here, assume decoder output x is a 4D tensor, - # e.g. [B, output_dim, H, W] where H and W are grid dimensions. + # e.g. [B, output_features, H, W] where H and W are grid dimensions. # Convert graph output to grid format # Apply physical constraints to decoder output if self.constraint_type != "none": x = rearrange(x, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1]) # Extract the low-res reference from the input. 
def load_feature_config(path: str) -> FeatureSetConfig:
    """Loads and validates the feature configuration YAML file.

    Args:
        path: location of the YAML file; its top level must be a mapping
            with a ``features`` list.

    Returns:
        The validated ``FeatureSetConfig``.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
            Schema violations are reported by ``FeatureSetConfig`` itself.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)

    # safe_load yields None for an empty document; unpacking that with ** would
    # fail with an unhelpful TypeError.
    if not isinstance(config_dict, dict):
        raise ValueError(
            f"Feature config {path!r} must contain a top-level mapping with a "
            f"'features' list, got {type(config_dict).__name__}."
        )
    return FeatureSetConfig(**config_dict)
def get_mean_stds():
    """
    Regroup the per-variable statistics in ``const`` into lookup tables.

    Entries of ``const.FORECAST_MEANS`` / ``const.FORECAST_STD`` are collected
    into pressure-level ("<name>_mb"), surface, cloud-layer, "max_wind" and
    "<height>m_above_ground" groups.

    Returns:
        means and standard deviations dict for forecast variables
    """
    pressure_level_vars = [
        "CLMR",
        "GRLE",
        "VVEL",
        "VGRD",
        "UGRD",
        "O3MR",
        "CAPE",
        "TMP",
        "PLPL",
        "DZDT",
        "CIN",
        "HGT",
        "RH",
        "ICMR",
        "SNMR",
        "SPFH",
        "RWMR",
        "TCDC",
        "ABSV",
    ]
    # Keys containing any of these substrings are ignored by the surface and
    # cloud-layer groupings below.
    skipped_tokens = (
        "level",
        "2e06",
        "below",
        "atmos",
        "tropo",
        "iso",
        "planetary_boundary_layer",
    )
    cloud_layers = (
        ("LCDC", "LCDC.low_cloud_layer"),
        ("MCDC", "MCDC.middle_cloud_layer"),
        ("HCDC", "HCDC.high_cloud_layer"),
    )

    def _is_skipped(key):
        return any(token in key for token in skipped_tokens)

    means, stds = {}, {}

    # Pressure-level variables: average the statistics over all "mb" levels.
    for name in pressure_level_vars:
        levels = sorted(
            float(key.split(".", 1)[-1].split("_")[0])
            for key in const.FORECAST_MEANS
            if "mb" in key and name in key and "-" not in key
        )
        if not levels:
            continue
        level_means, level_stds = [], []
        for level in levels:
            # Levels of 1 and above are spelled as integers in the key;
            # smaller ones keep their float spelling.
            label = int(level) if level >= 1 else level
            key = f"{name}.{label}_mb"
            level_means.append(const.FORECAST_MEANS[key])
            level_stds.append(const.FORECAST_STD[key])
        means[name + "_mb"] = np.mean(np.stack(level_means, axis=-1))
        stds[name + "_mb"] = np.mean(np.stack(level_stds, axis=-1))

    # Surface variables.
    surface_vars = {
        key.split(".", 1)[0]
        for key in const.FORECAST_MEANS
        if "surface" in key and not _is_skipped(key)
    }
    for name in surface_vars:
        means[name] = const.FORECAST_MEANS[name + ".surface"]
        stds[name] = const.FORECAST_STD[name + ".surface"]

    # Cloud-cover layers.
    cloud_candidates = {
        key.split(".", 1)[0]
        for key in const.FORECAST_MEANS
        if "sigma" not in key and not _is_skipped(key)
    }
    for name in cloud_candidates:
        for marker, stat_key in cloud_layers:
            if marker in name:
                means[name] = const.FORECAST_MEANS[stat_key]
                stds[name] = const.FORECAST_STD[stat_key]

    # Max-wind variables, stacked in sorted key order.
    wind_keys = sorted(key for key in const.FORECAST_MEANS if "max_wind" in key)
    means["max_wind"] = np.stack([const.FORECAST_MEANS[key] for key in wind_keys], axis=-1)
    stds["max_wind"] = np.stack([const.FORECAST_STD[key] for key in wind_keys], axis=-1)

    # Height-above-ground variables, one stacked entry per height.
    for height in (2, 10, 20, 30, 40, 50, 80, 100):
        height_keys = sorted(
            key for key in const.FORECAST_MEANS if f"{height}_m_above_ground" in key
        )
        group = f"{height}m_above_ground"
        means[group] = np.stack([const.FORECAST_MEANS[key] for key in height_keys], axis=-1)
        stds[group] = np.stack([const.FORECAST_STD[key] for key in height_keys], axis=-1)

    return means, stds
def process_data(data):
    """Normalise one raw dataset example and build the model input/output arrays.

    Steps:
      1. 2-D arrays whose variable has an entry in the module-level ``means``
         get a trailing channel axis.
      2. ``current_*`` entries become inputs and ``next_*`` entries become
         targets, each normalised with the module-level ``means``/``stds``.
      3. Per-location auxiliary features are appended to the inputs:
         extraterrestrial irradiance at the reference time and at each hour
         from 12 hours before to 12 hours after it, sin/cos of latitude and
         longitude, and sin/cos of the day-of-year fraction.

    Args:
        data: one dataset example; must provide "latitude", "longitude",
            "timestamps" and the ``current_*`` / ``next_*`` variables.
            NOTE(review): array shapes are not visible here; the reshape below
            assumes variables are laid out as (dim0, dim1, channels).

    Returns:
        dict with float32 "input" and "output" arrays and a "has_nans" flag.
        Despite its name, "has_nans" is True when NEITHER array contains NaN,
        so it can be used directly as a keep-filter.
    """
    data.update(
        {
            key: np.expand_dims(np.asarray(value), axis=-1)
            for key, value in data.items()
            if key.replace("current_", "").replace("next_", "") in means.keys()
            and np.asarray(value).ndim == 2
        }
    )
    input_data = {
        key.replace("current_", ""): torch.from_numpy(
            (value - means[key.replace("current_", "")]) / stds[key.replace("current_", "")]
        )
        for key, value in data.items()
        if "current" in key and "time" not in key
    }
    output_data = {
        key.replace("next_", ""): torch.from_numpy(
            (value - means[key.replace("next_", "")]) / stds[key.replace("next_", "")]
        )
        for key, value in data.items()
        if "next" in key and "time" not in key
    }
    lat_lons = np.array(
        np.meshgrid(np.asarray(data["latitude"]).flatten(), np.asarray(data["longitude"]).flatten())
    ).T.reshape((-1, 2))
    sin_lat_lons = np.sin(lat_lons * np.pi / 180.0)
    cos_lat_lons = np.cos(lat_lons * np.pi / 180.0)
    date = pd.to_datetime(data["timestamps"][0], utc=True)
    # First row: irradiance at the reference time itself.
    solar_times = [
        np.array(
            [
                extraterrestrial_irrad(
                    when=date.to_pydatetime(), latitude_deg=lat, longitude_deg=lon
                )
                for lat, lon in lat_lons
            ]
        )
    ]
    # Remaining rows: hourly irradiance from 12 hours before to 12 hours after.
    for when in pd.date_range(
        date - pd.Timedelta("12 hours"), date + pd.Timedelta("12 hours"), freq="h"
    ):
        solar_times.append(
            np.array(
                [
                    extraterrestrial_irrad(
                        when=when.to_pydatetime(), latitude_deg=lat, longitude_deg=lon
                    )
                    for lat, lon in lat_lons
                ]
            )
        )
    solar_times = np.array(solar_times)
    # Standardise with the dataset-wide solar statistics.
    solar_times -= const.SOLAR_MEAN
    solar_times /= const.SOLAR_STD
    input_data = torch.concat([value for _, value in input_data.items()], dim=-1)
    output_data = torch.concat([value for _, value in output_data.items()], dim=-1)
    input_data = input_data.transpose(0, 1).reshape(-1, input_data.shape[-1])
    # Flatten the targets with their OWN channel count; using the input's
    # channel count is only correct when the two happen to be equal.
    output_data = output_data.transpose(0, 1).reshape(-1, output_data.shape[-1])

    # Use only the first timestamp for the example
    day_of_year = pd.to_datetime(data["timestamps"][0], utc=True).dayofyear / 366.0

    # NOTE(review): sin/cos are applied to the raw [0, 1] fraction, not to an
    # angle scaled by 2*pi; confirm this is intended before changing it, since
    # it alters the model inputs.
    sin_of_year = np.ones_like(lat_lons)[:, 0] * np.sin(day_of_year)
    cos_of_year = np.ones_like(lat_lons)[:, 0] * np.cos(day_of_year)
    to_concat = [
        input_data,
        torch.permute(torch.from_numpy(solar_times), (1, 0)),
        torch.from_numpy(sin_lat_lons),
        torch.from_numpy(cos_lat_lons),
        torch.from_numpy(np.expand_dims(sin_of_year, axis=-1)),
        torch.from_numpy(np.expand_dims(cos_of_year, axis=-1)),
    ]
    input_data = torch.concat(to_concat, dim=-1)
    new_data = {
        "input": input_data.float().numpy(),
        "output": output_data.float().numpy(),
        "has_nans": not np.isnan(input_data.float().numpy()).any()
        and not np.isnan(output_data.float().numpy()).any(),
    }
    return new_data
class LitGraphForecaster(pl.LightningModule):
    """Lightning wrapper that trains a ``GraphWeatherForecaster`` with ``NormalizedMSELoss``."""

    def __init__(
        self,
        lat_lons: list,
        input_features: int,
        output_features: int,
        hidden_dim: int = 64,
        num_blocks: int = 3,
        lr: float = 3e-4,
    ):
        """
        Build the model and loss.

        Args:
            lat_lons: grid coordinates passed to both the model and the loss.
            input_features: number of input channels per node.
            output_features: number of predicted channels per node.
            hidden_dim: used for the decoder and for the processor's node and
                edge hidden dimensions.
            num_blocks: number of processor blocks.
            lr: AdamW learning rate.
        """
        super().__init__()
        self.model = GraphWeatherForecaster(
            lat_lons,
            input_features=input_features,
            output_features=output_features,
            hidden_dim_decoder=hidden_dim,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            num_blocks=num_blocks,
        )
        # Unit variance for every output channel.
        self.criterion = NormalizedMSELoss(
            lat_lons=lat_lons, feature_variance=np.ones((output_features,))
        )
        self.lr = lr
        self.save_hyperparameters()

    def forward(self, x):
        """Run the wrapped forecaster on ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one batch; batches containing NaN are skipped."""
        x, y = batch["input"], batch["output"]
        if torch.isnan(x).any() or torch.isnan(y).any():
            # Returning None makes Lightning skip the optimisation step.
            return None
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        # Log under "loss" so callbacks that monitor that key (such as a
        # ModelCheckpoint with monitor="loss") can find the metric.
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
@click.command()
@click.option(
    "--feature-config",
    default="graph_weather/configs/features_default.yaml",
    help="Path to the feature configuration YAML file.",
    type=click.Path(exists=True),
)
@click.option(
    "--num-blocks",
    default=5,
    help="Number of processor blocks in the model.",
    type=click.INT,
)
@click.option(
    "--hidden",
    default=32,
    help="Hidden dimension size.",
    type=click.INT,
)
@click.option(
    "--batch",
    default=1,
    help="Batch size for training.",
    type=click.INT,
)
@click.option(
    "--gpus",
    default=1,
    help="Number of GPUs to use.",
    type=click.INT,
)
@click.option("--fast-dev-run", is_flag=True, default=False, help="Run a single batch for testing.")
def run(feature_config, num_blocks, hidden, batch, gpus, fast_dev_run):
    """
    Training process.

    Loads the feature config, derives the lat/lon grid from one dataset
    example, builds the data module and Lightning model, and fits for up to
    100 epochs. Passing ``--gpus 0`` trains on CPU in full precision.
    """
    # Parse the feature config up front so a malformed file fails before any
    # dataset is downloaded. Its feature counts are not used yet because this
    # dataset has fixed dimensions (see the note below).
    FeatureManager(config_path=feature_config)

    # Get lat_lons from a sample of the dataset
    hf_ds = datasets.load_dataset(
        "openclimatefix/gfs-surface-pressure-2deg", split="train", streaming=False
    )

    if fast_dev_run:
        hf_ds = hf_ds.select(range(16))  # Use a smaller subset for fast dev run

    example_batch = next(iter(hf_ds))
    lat_lons = np.array(
        np.meshgrid(
            np.asarray(example_batch["latitude"]).flatten(),
            np.asarray(example_batch["longitude"]).flatten(),
        )
    ).T.reshape((-1, 2))

    checkpoint_callback = ModelCheckpoint(dirpath="./", save_top_k=2, monitor="loss")

    dset = GraphDataModule(batch_size=batch)

    # NOTE: This specific dataset has hardcoded dimensions, so they are used
    # directly here. A dataloader driven by the feature config would pass
    # FeatureManager.num_features instead.
    model = LitGraphForecaster(
        lat_lons=lat_lons,
        input_features=637,  # Hardcoded dimension from this specific dataset
        output_features=605,  # Hardcoded dimension from this specific dataset
        num_blocks=num_blocks,
        hidden_dim=hidden,
    )

    # Any non-zero value (including -1 for "all devices") keeps the original
    # GPU / 16-bit setup; 0 selects a single CPU process in 32-bit precision.
    use_gpu = gpus != 0
    trainer = pl.Trainer(
        accelerator="gpu" if use_gpu else "cpu",
        devices=gpus if use_gpu else 1,
        max_epochs=100,
        precision=16 if use_gpu else 32,
        callbacks=[checkpoint_callback],
        fast_dev_run=fast_dev_run,
    )
    trainer.fit(model, dset)