diff --git a/graph_weather/models/cafa/__init__.py b/graph_weather/models/cafa/__init__.py
new file mode 100644
index 0000000..5301617
--- /dev/null
+++ b/graph_weather/models/cafa/__init__.py
@@ -0,0 +1,14 @@
+"""
+Architectural design of CaFA (Climate-Aware Factorized Attention):
+- Transformer-based weather forecasting built for computational efficiency
+- Uses factorized attention to reduce the cost of the attention mechanism
+- A three-part system for efficient forecasting: Encoder, Factorized Transformer, Decoder
+"""
+
+from .decoder import CaFADecoder
+from .encoder import CaFAEncoder
+from .factorize import AxialAttention, FactorizedAttention, FactorizedTransformerBlock
+from .model import CaFAForecaster
+from .processor import CaFAProcessor
+
+__version__ = "0.1.0"
diff --git a/graph_weather/models/cafa/decoder.py b/graph_weather/models/cafa/decoder.py
new file mode 100644
index 0000000..3a320db
--- /dev/null
+++ b/graph_weather/models/cafa/decoder.py
@@ -0,0 +1,38 @@
+import torch
+from torch import nn
+
+
+class CaFADecoder(nn.Module):
+    """
+    Decoder for CaFA.
+    After the processor (the factorized transformer) has produced a prediction
+    in the latent space, the decoder's role is to translate this abstract
+    representation back into a physical-space prediction.
+    """
+
+    def __init__(self, model_dim: int, output_channels: int, upsampling_factor: int = 1):
+        """
+        Args:
+            model_dim: Dimension of the model's hidden layers (input channels)
+            output_channels: No. of channels/features in the output prediction
+            upsampling_factor: Factor to upsample the spatial dimensions by.
+                Must match the downsampling factor in the encoder.
+        """
+        super().__init__()
+        self.decoder = nn.ConvTranspose2d(
+            in_channels=model_dim,
+            out_channels=output_channels,
+            kernel_size=upsampling_factor,
+            stride=upsampling_factor,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: Input tensor of shape (batch, model_dim, height, width).
+
+        Returns:
+            Output tensor of shape (batch, output_channels, height*factor, width*factor)
+        """
+        x = self.decoder(x)
+        return x
diff --git a/graph_weather/models/cafa/encoder.py b/graph_weather/models/cafa/encoder.py
new file mode 100644
index 0000000..a2b78be
--- /dev/null
+++ b/graph_weather/models/cafa/encoder.py
@@ -0,0 +1,37 @@
+import torch
+from torch import nn
+
+
+class CaFAEncoder(nn.Module):
+    """
+    Encoder for CaFA.
+    Projects the complex, high-resolution input weather state
+    into a lower-resolution, high-dimensional latent
+    representation that the processor can work with.
+    """
+
+    def __init__(self, input_channels: int, model_dim: int, downsampling_factor: int = 1):
+        """
+        Args:
+            input_channels: No. of channels/features in the raw input data
+            model_dim: Dimension of the model's hidden layers (output channels)
+            downsampling_factor: Factor to downsample the spatial dimensions by (e.g., 2 means H/2, W/2)
+        """
+        super().__init__()
+        self.encoder = nn.Conv2d(
+            in_channels=input_channels,
+            out_channels=model_dim,
+            kernel_size=downsampling_factor,
+            stride=downsampling_factor,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: Input tensor of shape (batch, channels, height, width)
+
+        Returns:
+            Encoded tensor of shape (batch, model_dim, height/downsampling_factor, width/downsampling_factor)
+        """
+        x = self.encoder(x)
+        return x
diff --git a/graph_weather/models/cafa/factorize.py b/graph_weather/models/cafa/factorize.py
new file mode 100644
index 0000000..f230476
--- /dev/null
+++ b/graph_weather/models/cafa/factorize.py
@@ -0,0 +1,124 @@
+"""
+Core components for the factorized attention mechanism,
+based on the principles of axial attention.
+"""
+
+from einops import rearrange
+from torch import einsum, nn
+
+
+def FeedForward(dim, multiply=4, dropout=0.0):
+    """
+    Standard feed-forward block used in transformer architectures.
+    Two linear layers with a GELU activation and dropout in between.
+    """
+    inner_dim = int(dim * multiply)
+    return nn.Sequential(
+        nn.Linear(dim, inner_dim),
+        nn.GELU(),
+        nn.Dropout(dropout),
+        nn.Linear(inner_dim, dim),
+        nn.Dropout(dropout),
+    )
+
+
+class AxialAttention(nn.Module):
+    """
+    Performs multi-head self-attention along a single axis of a 2D feature map.
+    Core building block for factorized attention.
+    """
+
+    def __init__(self, dim, heads, dim_head=64, dropout=0.0):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim_head**-0.5
+        inner_dim = dim_head * heads
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, axis):
+        """
+        Forward pass for axial attention.
+        Args:
+            x: Input tensor of shape (batch, height, width, channels)
+            axis: Axis to perform attention on (1 for height, 2 for width)
+        """
+        b, h, w, d = x.shape
+
+        # rearrange the tensor to isolate the attention axis as the sequence dim
+        if axis == 1:
+            x = rearrange(x, "b h w d -> (b w) h d")  # attention along height
+        elif axis == 2:
+            x = rearrange(x, "b h w d -> (b h) w d")  # attention along width
+        else:
+            raise ValueError("Axis must be 1 (height) or 2 (width)")
+
+        # project to query, key and value tensors
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+        )  # reshape for multi-head attention
+
+        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale  # attention scores
+        attn = sim.softmax(dim=-1)
+        attn = self.dropout(attn)
+
+        # apply the attention weights to the value tensors
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.to_out(out)
+
+        # restore the original 2D grid format
+        if axis == 1:
+            out = rearrange(out, "(b w) h d -> b h w d", w=w)
+        elif axis == 2:
+            out = rearrange(out, "(b h) w d -> b h w d", h=h)
+
+        return out
+
+
+class FactorizedAttention(nn.Module):
+    """
+    Combines two AxialAttention blocks to perform full factorized attention
+    over a 2D feature map, first along the height axis, then along the width axis.
+    """
+
+    def __init__(self, dim, heads, dim_head=64, dropout=0.0):
+        super().__init__()
+        self.attn_height = AxialAttention(dim, heads, dim_head, dropout)
+        self.attn_width = AxialAttention(dim, heads, dim_head, dropout)
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x: Input tensor of shape (batch, height, width, channels)
+        """
+        x = x + self.attn_height(self.norm1(x), axis=1)
+        x = x + self.attn_width(self.norm2(x), axis=2)
+        return x
+
+
+class FactorizedTransformerBlock(nn.Module):
+    """
+    Standalone transformer block using factorized attention.
+    """
+
+    def __init__(self, dim, heads, dim_head=64, ff_mult=4, dropout=0.0):
+        super().__init__()
+        self.attn = FactorizedAttention(dim, heads, dim_head, dropout)
+        self.ffn = FeedForward(dim, ff_mult, dropout)
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x: Input tensor of shape (batch, height, width, channels)
+        """
+        x = x + self.attn(self.norm1(x))
+        x = x + self.ffn(self.norm2(x))
+        return x
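A quick illustrative sketch of the cost argument behind AxialAttention (not part of the patch; the grid and model sizes are arbitrary example values): full self-attention over an H x W grid builds an (H*W) x (H*W) score matrix, whereas attending along each axis separately only builds H x H matrices per column and W x W matrices per row, and FactorizedTransformerBlock preserves the channels-last grid shape.

import torch

from graph_weather.models.cafa.factorize import FactorizedTransformerBlock

# Example grid: 32 x 64 latent cells (arbitrary illustrative values)
H, W, DIM, HEADS = 32, 64, 128, 4

# Full attention compares every cell with every other cell: (H*W)^2 score entries.
full_attn_pairs = (H * W) ** 2              # 4,194,304
# Factorized attention runs attention per column and per row instead:
# W columns of length H, plus H rows of length W.
factorized_pairs = W * H**2 + H * W**2      # 196,608
print(f"score-matrix entries: full={full_attn_pairs:,}, factorized={factorized_pairs:,}")

# Shape check: the block operates channels-last and preserves the grid shape.
block = FactorizedTransformerBlock(dim=DIM, heads=HEADS, dim_head=64, ff_mult=4)
x = torch.randn(2, H, W, DIM)               # (batch, height, width, channels)
assert block(x).shape == x.shape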
+ """ + + def __init__(self, dim, heads, dim_head=64, dropout=0.0): + super().__init__() + self.attn_height = AxialAttention(dim, heads, dim_head, dropout) + self.attn_width = AxialAttention(dim, heads, dim_head, dropout) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x: Input tensor of shape (batch, height, width, channels) + """ + x = x + self.attn_height(self.norm1(x), axis=1) + x = x + self.attn_width(self.norm2(x), axis=2) + return x + + +class FactorizedTransformerBlock(nn.Module): + """ + Standalone transformer block using Factorized attention + """ + + def __init__(self, dim, heads, dim_head=64, ff_mult=4, dropout=0.0): + super().__init__() + self.attn = FactorizedAttention(dim, heads, dim_head, dropout) + self.ffn = FeedFoward(dim, ff_mult, dropout) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x: Input tensor of shape (batch, height, width, channels) + """ + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x diff --git a/graph_weather/models/cafa/model.py b/graph_weather/models/cafa/model.py new file mode 100644 index 0000000..145882a --- /dev/null +++ b/graph_weather/models/cafa/model.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from .decoder import CaFADecoder +from .encoder import CaFAEncoder +from .processor import CaFAProcessor + + +class CaFAForecaster(nn.Module): + """ + CaFA (Climate-Aware Factorized Attention) model + Puts together Encoder, Processor and Decoder into an end-to-end model + """ + + def __init__( + self, + input_channels: int, + output_channels: int, + model_dim: int = 256, + downsampling_factor: int = 2, + processor_depth: int = 6, + num_heads: int = 8, + dim_head: int = 64, + ff_mult: int = 4, + dropout: float = 0.0, + ): + """ + Args: + input_channels: No. of input channels/features + output_channels: No. of channels to predict + model_dim: Internal dimensions of the model + downsampling_factor: Down/Up-sampling factor in the encoder-decoder + processor_depth: No. of transformer blocks in the processor + num_heads: No. 
diff --git a/graph_weather/models/cafa/processor.py b/graph_weather/models/cafa/processor.py
new file mode 100644
index 0000000..51dd6aa
--- /dev/null
+++ b/graph_weather/models/cafa/processor.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from .factorize import FactorizedTransformerBlock
+
+
+class CaFAProcessor(nn.Module):
+    """
+    Processor module for CaFA.
+    Refines the latent feature map through multiple layers of factorized self-attention,
+    allowing information to propagate across the entire global grid.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        heads: int,
+        dim_head: int = 64,
+        ff_mult: int = 4,
+        dropout: float = 0.0,
+    ):
+        """
+        Args:
+            dim: No. of input channels/features
+            depth: No. of FactorizedTransformerBlocks to stack
+            heads: No. of attention heads in each block
+            dim_head: Dimension of each attention head
+            ff_mult: Multiplier for the feed-forward network dimension
+            dropout: Dropout rate
+        """
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [
+                FactorizedTransformerBlock(dim, heads, dim_head, ff_mult, dropout)
+                for _ in range(depth)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: Input tensor of shape (batch, channels, height, width)
+
+        Returns:
+            Refined tensor of the same shape
+        """
+        x = rearrange(x, "b c h w -> b h w c")
+        for block in self.blocks:
+            x = block(x)
+        x = rearrange(x, "b h w c -> b c h w")
+        return x
diff --git a/tests/test_cafa.py b/tests/test_cafa.py
new file mode 100644
index 0000000..2a746df
--- /dev/null
+++ b/tests/test_cafa.py
@@ -0,0 +1,84 @@
+import torch
+import pytest
+
+from graph_weather.models.cafa.encoder import CaFAEncoder
+from graph_weather.models.cafa.processor import CaFAProcessor
+from graph_weather.models.cafa.decoder import CaFADecoder
+from graph_weather.models.cafa.model import CaFAForecaster
+
+# common parameters for the tests
+BATCH_SIZE = 2
+HEIGHT = 32
+WIDTH = 64
+MODEL_DIM = 128
+INPUT_CHANNELS = 3
+OUTPUT_CHANNELS = 3
+HEADS = 4
+DEPTH = 2
+DOWNSAMPLING = 2
+
+
+def test_encoder():
+    """Tests the CaFAEncoder for correct shape transformation."""
+    x = torch.randn(BATCH_SIZE, INPUT_CHANNELS, HEIGHT, WIDTH)
+    encoder = CaFAEncoder(
+        input_channels=INPUT_CHANNELS, model_dim=MODEL_DIM, downsampling_factor=DOWNSAMPLING
+    )
+    output = encoder(x)
+
+    assert output.shape == (BATCH_SIZE, MODEL_DIM, HEIGHT // DOWNSAMPLING, WIDTH // DOWNSAMPLING)
+
+
+def test_decoder():
+    """Tests the CaFADecoder for correct shape transformation."""
+    x = torch.randn(BATCH_SIZE, MODEL_DIM, HEIGHT // DOWNSAMPLING, WIDTH // DOWNSAMPLING)
+    decoder = CaFADecoder(
+        model_dim=MODEL_DIM, output_channels=OUTPUT_CHANNELS, upsampling_factor=DOWNSAMPLING
+    )
+    output = decoder(x)
+
+    assert output.shape == (BATCH_SIZE, OUTPUT_CHANNELS, HEIGHT, WIDTH)
+
+
+def test_processor():
+    """Tests that the CaFAProcessor preserves the input shape."""
+    x = torch.randn(BATCH_SIZE, MODEL_DIM, HEIGHT, WIDTH)
+    processor = CaFAProcessor(dim=MODEL_DIM, depth=DEPTH, heads=HEADS)
+    output = processor(x)
+
+    assert output.shape == x.shape
+
+
+def test_cafa_forecaster_end_to_end():
+    """Tests the full CaFAForecaster model end-to-end."""
+    x = torch.randn(BATCH_SIZE, INPUT_CHANNELS, HEIGHT, WIDTH)
+    model = CaFAForecaster(
+        input_channels=INPUT_CHANNELS,
+        output_channels=OUTPUT_CHANNELS,
+        model_dim=MODEL_DIM,
+        downsampling_factor=DOWNSAMPLING,
+        processor_depth=DEPTH,
+        num_heads=HEADS,
+    )
+    output = model(x)
+
+    assert output.shape == (BATCH_SIZE, OUTPUT_CHANNELS, HEIGHT, WIDTH)
+
+
+def test_cafa_forecaster_odd_dimensions():
+    """Tests that the model's internal padding handles odd-sized inputs correctly."""
+
+    # Use odd dimensions for height and width
+    x = torch.randn(BATCH_SIZE, INPUT_CHANNELS, HEIGHT + 1, WIDTH + 1)
+    model = CaFAForecaster(
+        input_channels=INPUT_CHANNELS,
+        output_channels=OUTPUT_CHANNELS,
+        model_dim=MODEL_DIM,
+        downsampling_factor=DOWNSAMPLING,
+        processor_depth=DEPTH,
+        num_heads=HEADS,
+    )
+    output = model(x)
+
+    # The model should return a tensor with the original odd-sized dimensions
+    assert output.shape == x.shape
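Beyond the shape tests above, one plausible way to use the forecaster is an autoregressive rollout. This is a hypothetical sketch only, assuming input_channels == output_channels so predictions can be fed back in as inputs; it is not exercised by this patch:

import torch

from graph_weather.models.cafa import CaFAForecaster

# Hypothetical multi-step rollout: feed each prediction back in as the next input.
# Sizes are arbitrary example values.
model = CaFAForecaster(input_channels=3, output_channels=3, model_dim=128,
                       processor_depth=2, num_heads=4)
model.eval()

state = torch.randn(1, 3, 32, 64)            # initial (batch, channels, height, width)
trajectory = []
with torch.no_grad():
    for _ in range(4):                        # 4 forecast steps
        state = model(state)
        trajectory.append(state)

forecast = torch.stack(trajectory, dim=1)     # (batch, steps, channels, height, width)
print(forecast.shape)                         # torch.Size([1, 4, 3, 32, 64])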