14 changes: 14 additions & 0 deletions graph_weather/models/cafa/__init__.py
@@ -0,0 +1,14 @@
"""
CaFA (Climate-Aware Factorized Attention) architectural design:
- Transformer-based weather forecasting built for computational efficiency
- Uses factorized (axial) attention to reduce the cost of the attention mechanism
- A three-part system for efficient forecasting: Encoder, Factorized Transformer processor, Decoder
"""

from .decoder import CaFADecoder
from .encoder import CaFAEncoder
from .factorize import AxialAttention, FactorizedAttention, FactorizedTransformerBlock
from .model import CaFAForecaster
from .processor import CaFAProcessor

__version__ = "0.1.0"
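
As a quick orientation, here is a minimal usage sketch of the package (not part of the diff); the channel counts and grid size are illustrative assumptions, and the constructor arguments follow model.py below.

import torch

from graph_weather.models.cafa import CaFAForecaster

# Hypothetical configuration: 8 variables in, 8 out, on a 181 x 360 lat/lon grid.
model = CaFAForecaster(
    input_channels=8,
    output_channels=8,
    model_dim=128,
    downsampling_factor=2,
    processor_depth=2,
    num_heads=4,
)
x = torch.randn(1, 8, 181, 360)  # (batch, channels, height, width)
y = model(x)                     # odd height is padded internally, then cropped back
assert y.shape == (1, 8, 181, 360)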
38 changes: 38 additions & 0 deletions graph_weather/models/cafa/decoder.py
@@ -0,0 +1,38 @@
import torch
from torch import nn


class CaFADecoder(nn.Module):
"""
Decoder for CaFA.
After the processor's FactorizedTransformerBlocks have produced a prediction
in the latent space, the decoder translates this abstract
representation back into a physical prediction.
"""

def __init__(self, model_dim: int, output_channels: int, upsampling_factor: int = 1):
"""
Args:
model_dim: Dimension of the model's hidden layers (the decoder's input channels)
output_channels: No. of channels/features in the output prediction
upsampling_factor: Factor to upsample the spatial dimensions.
Must match the downsampling factor in encoder.
"""
super().__init__()
self.decoder = nn.ConvTranspose2d(
in_channels=model_dim,
out_channels=output_channels,
kernel_size=upsampling_factor,
stride=upsampling_factor,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input tensor of shape (batch, model_dim, height, width).

Returns:
Output tensor of shape (batch, output_channels, height*factor, width*factor)
"""
x = self.decoder(x)
return x
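
A small shape sketch for the decoder (illustrative values, not part of the module): with upsampling_factor equal to the encoder's downsampling_factor, the transposed convolution restores the pre-downsampling resolution.

import torch

from graph_weather.models.cafa import CaFADecoder

decoder = CaFADecoder(model_dim=64, output_channels=5, upsampling_factor=4)
latent = torch.randn(2, 64, 45, 90)   # (batch, model_dim, height, width)
pred = decoder(latent)                # each spatial dim is multiplied by 4
assert pred.shape == (2, 5, 180, 360)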
37 changes: 37 additions & 0 deletions graph_weather/models/cafa/encoder.py
@@ -0,0 +1,37 @@
import torch
from torch import nn


class CaFAEncoder(nn.Module):
"""
Encoder for CaFA.
Projects the complex, high-resolution input weather state
into a lower-resolution, high-dimensional latent
representation that the processor can work with.
"""

def __init__(self, input_channels: int, model_dim: int, downsampling_factor: int = 1):
"""
Args:
input_channels: No. of channels/features in the raw input data
model_dim: Dimension of the model's hidden layers (the encoder's output channels)
downsampling_factor: Factor to downsample the spatial dimensions by (e.g., 2 means H/2, W/2)
"""
super().__init__()
self.encoder = nn.Conv2d(
in_channels=input_channels,
out_channels=model_dim,
kernel_size=downsampling_factor,
stride=downsampling_factor,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input tensor of shape (batch, channels, height, width)

Returns:
Encoded tensor of shape (batch, model_dim, height/downsampling_factor, width/downsampling_factor)
"""
x = self.encoder(x)
return x
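
A corresponding shape sketch for the encoder (again with assumed, illustrative sizes): channels are lifted to model_dim while the grid is coarsened by the downsampling factor.

import torch

from graph_weather.models.cafa import CaFAEncoder

encoder = CaFAEncoder(input_channels=8, model_dim=128, downsampling_factor=2)
x = torch.randn(1, 8, 180, 360)  # (batch, channels, height, width)
z = encoder(x)                   # grid halved, channels lifted to model_dim
assert z.shape == (1, 128, 90, 180)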
124 changes: 124 additions & 0 deletions graph_weather/models/cafa/factorize.py
@@ -0,0 +1,124 @@
"""
Core components for the Factorized Attention mechanism,
based on the principles of Axial Attention.
"""

from einops import rearrange
from torch import einsum, nn


def FeedForward(dim, multiply=4, dropout=0.0):
"""
Standard feed-forward block used in transformer architectures.
Consists of two linear layers with a GELU activation and dropout in between.
"""
inner_dim = int(dim * multiply)
return nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout),
)


class AxialAttention(nn.Module):
"""
Performs multi-head self-attention on a single axis of a 2D feature map.
Core building block for Factorized Attention.
"""

def __init__(self, dim, heads, dim_head=64, dropout=0.0):
super().__init__()
self.heads = heads
self.scale = dim_head**-0.5
inner_dim = dim_head * heads

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x, axis):
"""
Forward pass for axial attention
Args:
x: Input tensor of shape (batch, height, width, channels)
axis: Axis to perform attention on (1 for height, 2 for width)
"""
b, h, w, d = x.shape

# rearrange tensor to isolate attention axis as the sequence dim
if axis == 1:
x = rearrange(x, "b h w d -> (b w) h d") # attention along height
elif axis == 2:
x = rearrange(x, "b h w d -> (b h) w d") # attention along width
else:
raise ValueError("Axis must be 1 (height) or 2 (width)")

# project to query, key and value tensors
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
) # reshape for multi-head attn

sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale # attention scores
attn = sim.softmax(dim=-1)
attn = self.dropout(attn)

# apply the attention weights to the value tensors
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out(out)

# restore the original 2D grid layout
if axis == 1:
out = rearrange(out, "(b w) h d -> b h w d", w=w)
elif axis == 2:
out = rearrange(out, "(b h) w d -> b h w d", h=h)

return out


class FactorizedAttention(nn.Module):
"""
Combines 2 AxialAttention blocks to perform full factorized attention
over a 2D feature map, first along height then along width.
"""

def __init__(self, dim, heads, dim_head=64, dropout=0.0):
super().__init__()
self.attn_height = AxialAttention(dim, heads, dim_head, dropout)
self.attn_width = AxialAttention(dim, heads, dim_head, dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)

def forward(self, x):
"""
Args:
x: Input tensor of shape (batch, height, width, channels)
"""
x = x + self.attn_height(self.norm1(x), axis=1)
x = x + self.attn_width(self.norm2(x), axis=2)
return x


class FactorizedTransformerBlock(nn.Module):
"""
Standalone transformer block combining factorized attention with a feed-forward network
"""

def __init__(self, dim, heads, dim_head=64, ff_mult=4, dropout=0.0):
super().__init__()
self.attn = FactorizedAttention(dim, heads, dim_head, dropout)
self.ffn = FeedForward(dim, ff_mult, dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)

def forward(self, x):
"""
Args:
x: Input tensor of shape (batch, height, width, channels)
"""
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
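
The point of factorizing is cost: full self-attention over a flattened H x W grid scores (H*W)^2 pairs, while the two axial passes score only H^2 pairs per column and W^2 pairs per row, roughly H*W*(H+W) in total. A minimal sketch of a block call, with assumed sizes:

import torch

from graph_weather.models.cafa import FactorizedTransformerBlock

block = FactorizedTransformerBlock(dim=64, heads=4, dim_head=32)
x = torch.randn(2, 32, 48, 64)  # (batch, height, width, channels), channels == dim
y = block(x)                    # attention along height, then width, then feed-forward
assert y.shape == x.shape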
92 changes: 92 additions & 0 deletions graph_weather/models/cafa/model.py
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F
from torch import nn

from .decoder import CaFADecoder
from .encoder import CaFAEncoder
from .processor import CaFAProcessor


class CaFAForecaster(nn.Module):
"""
CaFA (Climate-Aware Factorized Attention) model
Puts together Encoder, Processor and Decoder into an end-to-end model
"""

def __init__(
self,
input_channels: int,
output_channels: int,
model_dim: int = 256,
downsampling_factor: int = 2,
processor_depth: int = 6,
num_heads: int = 8,
dim_head: int = 64,
ff_mult: int = 4,
dropout: float = 0.0,
):
"""
Args:
input_channels: No. of input channels/features
output_channels: No. of channels to predict
model_dim: Internal dimensions of the model
downsampling_factor: Down/Up-sampling factor in the encoder-decoder
processor_depth: No. of transformer blocks in the processor
num_heads: No. of attention heads in each block
dim_head: Dimension of each attention head
ff_mult: Multiplier for the feedforward network's inner dimension
dropout: Dropout rate
"""
super().__init__()

self.downsampling_factor = downsampling_factor

self.encoder = CaFAEncoder(
input_channels=input_channels,
model_dim=model_dim,
downsampling_factor=downsampling_factor,
)

self.processor = CaFAProcessor(
dim=model_dim,
depth=processor_depth,
heads=num_heads,
dim_head=dim_head,
ff_mult=ff_mult,
dropout=dropout,
)

self.decoder = CaFADecoder(
model_dim=model_dim,
output_channels=output_channels,
upsampling_factor=downsampling_factor,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input tensor of shape (batch, input_channels, height, width)

Returns:
Output tensor of shape (batch, output_channels, height, width)
"""

# pad the input so its spatial dims are divisible by the downsampling factor (handles odd-sized grids)
_, _, h, w = x.shape
pad_h = (
self.downsampling_factor - (h % self.downsampling_factor)
) % self.downsampling_factor
pad_w = (
self.downsampling_factor - (w % self.downsampling_factor)
) % self.downsampling_factor
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, pad_w, 0, pad_h))

x = self.encoder(x)
x = self.processor(x)
x = self.decoder(x)

if pad_h > 0 or pad_w > 0:
x = x[:, :, :h, :w]

return x
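
Since the forecaster preserves the input grid, it can be rolled out autoregressively when output_channels matches input_channels; a minimal sketch under that assumption, with illustrative sizes:

import torch

from graph_weather.models.cafa import CaFAForecaster

model = CaFAForecaster(
    input_channels=4,
    output_channels=4,  # matching channel counts is an assumption made for the rollout
    model_dim=64,
    downsampling_factor=2,
    processor_depth=2,
    num_heads=4,
)
state = torch.randn(1, 4, 91, 180)  # odd height: pad_h = (2 - 91 % 2) % 2 = 1
with torch.no_grad():
    for _ in range(3):              # three forecast steps
        state = model(state)        # padded to 92 x 180 internally, cropped back to 91 x 180
assert state.shape == (1, 4, 91, 180)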
53 changes: 53 additions & 0 deletions graph_weather/models/cafa/processor.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from einops import rearrange

from .factorize import FactorizedTransformerBlock


class CaFAProcessor(nn.Module):
"""
Processor module for CaFA.
Refines the latent feature map through multiple layers of factorized
self-attention, allowing information to propagate across the entire global grid.
"""

def __init__(
self,
dim: int,
depth: int,
heads: int,
dim_head: int = 64,
ff_mult: int = 4,
dropout: float = 0.0,
):
"""
Args:
dim: No. of input channels/features
depth: No. of FactorizedTransformerBlocks to stack
heads: No. of attention heads in each block
dim_head: Dimension of each attention head
ff_mult: Multiplier for the feedforward network dimension
dropout: Dropout rate
"""
super().__init__()
self.blocks = nn.ModuleList(
[
FactorizedTransformerBlock(dim, heads, dim_head, ff_mult, dropout)
for _ in range(depth)
]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input tensor of shape (batch, channels, height, width)

Returns:
Refined tensor of the same shape
"""
x = rearrange(x, "b c h w -> b h w c")
for block in self.blocks:
x = block(x)
x = rearrange(x, "b h w c -> b c h w")
return x
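
Note the layout handling: the processor takes the encoder's channel-first latents, rearranges to channel-last for the attention blocks, and returns channel-first again. A minimal sketch with assumed sizes:

import torch

from graph_weather.models.cafa import CaFAProcessor

processor = CaFAProcessor(dim=64, depth=2, heads=4, dim_head=32)
z = torch.randn(1, 64, 45, 90)  # (batch, model_dim, height, width), as produced by the encoder
out = processor(z)              # same channel-first shape after refinement
assert out.shape == z.shape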