Skip to content

Add ying yang grid #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,14 @@ test_grid_visualize.png
*.png
*.jpg
*.jpeg
*.gif
public/

a.out
*.o

# editor backup files
# helix
\#*\#
# emacs
*~
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ New APIs
- ring2double
- to_rotated_pixelization
- to_double_pixelization
- earth2grid.projections. Grids in arbitrary projections
- earth2grid.yingyang

Breaking changes:

- change coordinate transform and shape of lcc grid

## 2025.4.1

Expand Down
10 changes: 9 additions & 1 deletion earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
from typing import Dict, Sequence

import einops
import netCDF4 as nc
import torch

try:
import netCDF4 as nc
except ImportError:
nc = None

from scipy import spatial

from earth2grid.spatial import ang2vec, haversine_distance
Expand Down Expand Up @@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
class TempestRegridder(torch.nn.Module):
def __init__(self, file_path):
super().__init__()
if nc is None:
raise ImportError("netCDF4 not imported. Please install for this feature.")

dataset = nc.Dataset(file_path)
self.lat = dataset["latc_b"][:]
self.lon = dataset["lonc_b"][:]
Expand Down
22 changes: 17 additions & 5 deletions earth2grid/latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@


class LatLonGrid(base.Grid):
def __init__(self, lat: list[float], lon: list[float]):
def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
"""
Args:
lat: center of lat cells
lon: center of lon cells
cylinder: if true, then lon is considered a periodic coordinate
on cylinder so that interpolation wraps around the edge.
Otherwise, it is assumed to be a finite plane.
"""
self._lat = lat
self._lon = lon
self.cylinder = cylinder

@property
def lat(self):
Expand All @@ -48,7 +52,7 @@ def shape(self):

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""
return _RegridFromLatLon(self, lat, lon)
return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)

def _lonb(self):
edges = (self.lon[1:] + self.lon[:-1]) / 2
Expand Down Expand Up @@ -78,15 +82,22 @@ def to_pyvista(self):
class _RegridFromLatLon(torch.nn.Module):
"""Regrid from lat-lon to unstructured grid with bilinear interpolation"""

def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
"""
Args:
cylinder: if True than lon is assumed to be periodic
"""
super().__init__()
self.cylinder = cylinder

lat, lon = np.broadcast_arrays(lat, lon)
self.shape = lat.shape

# TODO add device switching logic (maybe use torch registers for this
# info)
long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
long = src.lon.ravel()
if self.cylinder:
long = np.concatenate([long, [360]], axis=-1)
long_t = torch.from_numpy(long)

# flip the order latg since bilinear only works with increasing coordinate values
Expand All @@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# pad z in lon direction
# only works for a global grid
# TODO generalize this to local grids and add options for padding
x = torch.cat([x, x[..., 0:1]], axis=-1)
if self.cylinder:
x = torch.cat([x, x[..., 0:1]], axis=-1)
out = self._bilinear(x)
return out.view(out.shape[:-1] + self.shape)

Expand Down
80 changes: 9 additions & 71 deletions earth2grid/lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator
from earth2grid import projections

try:
import pyvista as pv
Expand All @@ -31,7 +29,10 @@
]


class LambertConformalConicProjection:
LambertConformalConicGrid = projections.Grid


class LambertConformalConicProjection(projections.Projection):
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
"""

Expand Down Expand Up @@ -80,7 +81,7 @@ def _theta(self, lon):
delta_lon = delta_lon - np.round(delta_lon / 360) * 360 # convert to [-180, 180]
return self.n * np.deg2rad(delta_lon)

def project(self, lat, lon):
def project(self, lon, lat):
"""
Compute the projected x,y from lat,lon.
"""
Expand All @@ -100,90 +101,27 @@ def inverse_project(self, x, y):

lat = np.rad2deg(2 * np.arctan(np.power(self.RF / rho, 1 / self.n))) - 90
lon = self.lon0 + np.rad2deg(theta / self.n)
return lat, lon
return lon, lat


# Projection used by HRRR CONUS (Continental US) data
# https://rapidrefresh.noaa.gov/hrrr/HRRR_conus.domain.txt
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)


class LambertConformalConicGrid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: LambertConformalConicProjection, x, y):
"""
Args:
projection: LambertConformalConicProjection object
x: range of x values
y: range of y values

"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.y), len(self.x))

def __getitem__(self, idxs):
yidxs, xidxs = idxs
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid


def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
# coordinates of point in top-left corner
lat0 = 21.138123
lon0 = 237.280472
# grid length (m)
scale = 3000.0
# coordinates on projected space
x0, y0 = HRRR_CONUS_PROJECTION.project(lat0, lon0)
x0, y0 = HRRR_CONUS_PROJECTION.project(lon0, lat0)

x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]

return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
return projections.Grid(HRRR_CONUS_PROJECTION, x, y)


# Grid used by HRRR CONUS (Continental US) data
Expand Down
112 changes: 112 additions & 0 deletions earth2grid/projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Conventions:

- lon,lat is preferred to lat,lon for projection equations (i.e. x=lon in Plate
Carree)
- projected grids have shape [ny, nx]
"""
import abc

import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator

try:
import pyvista as pv
except ImportError:
pv = None


class Projection(abc.ABC):
@abc.abstractmethod
def project(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the projected x,y from lon,lat.

"""
pass

@abc.abstractmethod
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the lon,lat from the projected x,y.
"""
pass


class Grid(base.Grid):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just renamed your LCC grid class @simonbyrne . Intended to be used like earth2grid.projections.Grid.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if you want to give it a more specific name (i.e. it assumes the underlying grid is rectilinear, not an unstructured mesh).

Copy link
Collaborator Author

@nbren12 nbren12 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "projection" implies rectangle. This is just a style preference for projection.Grid over projection.ProjectionGrid. Same style as healpix.Grid.

# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: Projection, x, y):
"""
Args:
x: range of x values
y: range of y values

"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_y, mesh_x = np.meshgrid(self.y, self.x, indexing='ij')
return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[1]

@property
def lon(self):
return self.lat_lon[0]

@property
def shape(self):
return (len(self.y), len(self.x))

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lon, lat)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def __getitem__(self, idxs) -> "Grid":
yidxs, xidxs = idxs
return Grid(self.projection, x=self.x[xidxs], y=self.y[yidxs])

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid
7 changes: 7 additions & 0 deletions earth2grid/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,10 @@ def ang2vec(lon, lat):
y = torch.cos(lat) * torch.sin(lon)
z = torch.sin(lat)
return (x, y, z)


def vec2ang(x, y, z):
"""convert lon,lat in radians to cartesian coordinates"""
lat = torch.asin(z)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes your inputs are already normalized. If you want it to work with non-normalized inputs, you would need

Suggested change
lat = torch.asin(z)
lat = torch.atan2(z, torch.hypot(y, x))

lon = torch.atan2(y, x)
return lon, lat
Loading