Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions graph_weather/configs/features_default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# config replicates the original 78+24 feature set
# var names based on ERA5 dataset gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr
features:
# --------------------------------------------------------------------------
# 78 Dynamic Features (6 Variables @ 13 Pressure Levels)
# --------------------------------------------------------------------------
- name: "geopotential"
type: "dynamic"
source: "era5"
levels:
- 50
- 100
- 150
- 200
- 250
- 300
- 400
- 500
- 600
- 700
- 850
- 925
- 1000
- name: "specific_humidity"
type: "dynamic"
source: "era5"
levels:
- 50
- 100
- 150
- 200
- 250
- 300
- 400
- 500
- 600
- 700
- 850
- 925
- 1000
- name: "temperature"
type: "dynamic"
source: "era5"
levels:
- 50
- 100
- 150
- 200
- 250
- 300
- 400
- 500
- 600
- 700
- 850
- 925
- 1000
- name: "u_component_of_wind"
type: "dynamic"
source: "era5"
levels:
- 50
- 100
- 150
- 200
- 250
- 300
- 400
- 500
- 600
- 700
- 850
- 925
- 1000
- name: "v_component_of_wind"
type: "dynamic"
source: "era5"
levels:
- 50
- 100
- 150
- 200
- 250
- 300
- 400
- 500
- 600
- 700
- 850
- 925
- 1000
- name: "vertical_velocity"
type: "dynamic"
source: "era5"
levels:
- 50
- 100
- 150
- 200
- 250
- 300
- 400
- 500
- 600
- 700
- 850
- 925
- 1000
# --------------------------------------------------------------------------
# 24 Static & Single-Level Features ("non-NWP")
# --------------------------------------------------------------------------
# 14 Static constants
- name: "angle_of_sub_gridscale_orography"
type: "static"
source: "constants"
- name: "anisotropy_of_sub_gridscale_orography"
type: "static"
source: "constants"
- name: "geopotential_at_surface"
type: "static"
source: "constants"
- name: "high_vegetation_cover"
type: "static"
source: "constants"
- name: "lake_cover"
type: "static"
source: "constants"
- name: "lake_depth"
type: "static"
source: "constants"
- name: "land_sea_mask"
type: "static"
source: "constants"
- name: "low_vegetation_cover"
type: "static"
source: "constants"
- name: "slope_of_sub_gridscale_orography"
type: "static"
source: "constants"
- name: "soil_type"
type: "static"
source: "constants"
- name: "standard_deviation_of_filtered_subgrid_orography"
type: "static"
source: "constants"
- name: "standard_deviation_of_orography"
type: "static"
source: "constants"
- name: "type_of_high_vegetation"
type: "static"
source: "constants"
- name: "type_of_low_vegetation"
type: "static"
source: "constants"

# 6 Single-Level Variables
- name: "10m_u_component_of_wind"
type: "static" # Treated as static input for the model's purpose
source: "era5"
- name: "10m_v_component_of_wind"
type: "static"
source: "era5"
- name: "2m_temperature"
type: "static"
source: "era5"
- name: "mean_sea_level_pressure"
type: "static"
source: "era5"
- name: "surface_pressure"
type: "static"
source: "era5"
- name: "toa_incident_solar_radiation"
type: "static"
source: "era5"

# 4 Derived Temporal Encodings
- name: "year_progress_sin"
type: "static"
source: "derived"
- name: "year_progress_cos"
type: "static"
source: "derived"
- name: "day_progress_sin"
type: "static"
source: "derived"
- name: "day_progress_cos"
type: "static"
source: "derived"
4 changes: 4 additions & 0 deletions graph_weather/configs/test_features.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
features:
- name: "2m_temperature"
type: "dynamic"
source: "era5"
93 changes: 93 additions & 0 deletions graph_weather/models/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Manages the loading, parsing, and assembly of weather features from a configuration file."""

from typing import Dict, List

import torch

from graph_weather.utils.config import load_feature_config


class FeatureManager:
    """Parse a feature config and expose feature dimensions and tensor assembly.

    After construction the following attributes are available:

    * ``feature_order``: output column names, in config order. A variable with
      pressure levels contributes one ``"{name}_L{level}"`` entry per level;
      every other feature contributes its plain name.
    * ``dynamic_variable_names``: names of features whose type is ``"dynamic"``
      (one entry per variable, regardless of how many levels it has).
    * ``static_feature_names``: names of features whose type is ``"static"``.
    * ``level_map``: maps each multi-level variable name to its list of levels.
    * ``num_features``, ``num_dynamic_features``, ``num_static_features``:
      column counts derived from the lists above.
    """

    def __init__(self, config_path: str):
        """Load the feature configuration at ``config_path`` and index it."""
        self.config = load_feature_config(config_path)
        self._process_config()

    def _process_config(self):
        """Build the feature ordering and bookkeeping from ``self.config``.

        Each entry of ``self.config.features`` is read for ``type``, ``name``
        and (for dynamic features) an optional ``levels`` sequence. Entries
        whose type is neither ``"dynamic"`` nor ``"static"`` are ignored.
        """
        self.feature_order: List[str] = []
        self.dynamic_variable_names: List[str] = []
        self.static_feature_names: List[str] = []
        self.level_map: Dict[str, List[int]] = {}

        num_dynamic_features = 0

        for feature in self.config.features:
            if feature.type == "dynamic":
                # Record the variable exactly once, with or without levels.
                self.dynamic_variable_names.append(feature.name)
                levels = getattr(feature, "levels", None)
                if levels:
                    self.level_map[feature.name] = levels
                    # One output column per pressure level.
                    self.feature_order.extend(
                        f"{feature.name}_L{level}" for level in levels
                    )
                    num_dynamic_features += len(levels)
                else:
                    self.feature_order.append(feature.name)
                    num_dynamic_features += 1
            elif feature.type == "static":
                self.static_feature_names.append(feature.name)
                self.feature_order.append(feature.name)

        self.num_features = len(self.feature_order)
        self.num_dynamic_features = num_dynamic_features
        self.num_static_features = len(self.static_feature_names)

    def assemble_features(self, data: Dict) -> torch.Tensor:
        """Assemble the final input tensor from a dictionary of raw tensors.

        Args:
            data (Dict): maps feature names (as written in the config, without
                any level suffix) to tensors. A multi-level variable must have
                its levels on the last axis, ordered as in ``level_map``, e.g.
                ``(batch, nodes, levels)`` or ``(nodes, levels)``. Any other
                feature should already carry a trailing feature axis of size
                one; as a convenience a 1-D tensor gets that axis appended.

        Returns:
            A single tensor with all features concatenated along the last
            axis in ``feature_order`` order, e.g. ``(batch, nodes, num_features)``.

        Raises:
            ValueError: if a configured feature, or a level of a multi-level
                variable, is missing from ``data``.
        """
        feature_tensors = []

        for column_name in self.feature_order:
            # Multi-level columns are named "{variable}_L{level}"; split on the
            # last "_L" so underscores inside the variable name are preserved.
            base_name, separator, level_str = column_name.rpartition("_L")

            if separator and base_name in self.level_map and level_str.isdecimal():
                level = int(level_str)
                try:
                    level_idx = self.level_map[base_name].index(level)
                    # Index the trailing (level) axis so both batched and
                    # unbatched inputs work, then restore a size-1 feature axis.
                    tensor_slice = data[base_name][..., level_idx].unsqueeze(-1)
                except (ValueError, KeyError, IndexError) as e:
                    raise ValueError(
                        f"Level {level} for variable {base_name} not found in config or data."
                    ) from e
                feature_tensors.append(tensor_slice)
            else:
                # Single-level or derived feature: looked up under its own name.
                if column_name not in data:
                    raise ValueError(
                        f"Feature {column_name} from config not found in data dictionary."
                    )
                tensor = data[column_name]
                if tensor.ndim == 1:
                    tensor = tensor.unsqueeze(-1)
                feature_tensors.append(tensor)

        return torch.cat(feature_tensors, dim=-1)
Loading