Skip to content

Commit 6a1fd7e

Browse files
authored
Add ying yang grid (#20)
* Add ying yang grid This PR adds a new projection based grid, the ying yang grid. It restructures some of the lambert conformal logic a bit, so Simon should take a look. * fix lcc test, and change convention * update change log
1 parent c51739d commit 6a1fd7e

12 files changed

+424
-78
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,14 @@ test_grid_visualize.png
116116
*.png
117117
*.jpg
118118
*.jpeg
119+
*.gif
119120
public/
120121

121122
a.out
122123
*.o
124+
125+
# editor backup files
126+
# helix
127+
\#*\#
128+
# emacs
129+
*~

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ New APIs
88
- ring2double
99
- to_rotated_pixelization
1010
- to_double_pixelization
11+
- earth2grid.projections. Grids in arbitrary projections
12+
- earth2grid.yingyang
13+
14+
Breaking changes:
15+
16+
- change coordinate transform and shape of lcc grid
1117

1218
## 2025.4.1
1319

earth2grid/_regrid.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
from typing import Dict, Sequence
1717

1818
import einops
19-
import netCDF4 as nc
2019
import torch
20+
21+
try:
22+
import netCDF4 as nc
23+
except ImportError:
24+
nc = None
25+
2126
from scipy import spatial
2227

2328
from earth2grid.spatial import ang2vec, haversine_distance
@@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
5964
class TempestRegridder(torch.nn.Module):
6065
def __init__(self, file_path):
6166
super().__init__()
67+
if nc is None:
68+
raise ImportError("netCDF4 not imported. Please install for this feature.")
69+
6270
dataset = nc.Dataset(file_path)
6371
self.lat = dataset["latc_b"][:]
6472
self.lon = dataset["lonc_b"][:]

earth2grid/latlon.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,18 @@
2525

2626

2727
class LatLonGrid(base.Grid):
28-
def __init__(self, lat: list[float], lon: list[float]):
28+
def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
2929
"""
3030
Args:
3131
lat: center of lat cells
3232
lon: center of lon cells
33+
cylinder: if true, then lon is considered a periodic coordinate
34+
on cylinder so that interpolation wraps around the edge.
35+
Otherwise, it is assumed to be a finite plane.
3336
"""
3437
self._lat = lat
3538
self._lon = lon
39+
self.cylinder = cylinder
3640

3741
@property
3842
def lat(self):
@@ -48,7 +52,7 @@ def shape(self):
4852

4953
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
5054
"""Get regridder to the specified lat and lon points"""
51-
return _RegridFromLatLon(self, lat, lon)
55+
return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)
5256

5357
def _lonb(self):
5458
edges = (self.lon[1:] + self.lon[:-1]) / 2
@@ -78,15 +82,22 @@ def to_pyvista(self):
7882
class _RegridFromLatLon(torch.nn.Module):
7983
"""Regrid from lat-lon to unstructured grid with bilinear interpolation"""
8084

81-
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
85+
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
86+
"""
87+
Args:
88+
cylinder: if True than lon is assumed to be periodic
89+
"""
8290
super().__init__()
91+
self.cylinder = cylinder
8392

8493
lat, lon = np.broadcast_arrays(lat, lon)
8594
self.shape = lat.shape
8695

8796
# TODO add device switching logic (maybe use torch registers for this
8897
# info)
89-
long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
98+
long = src.lon.ravel()
99+
if self.cylinder:
100+
long = np.concatenate([long, [360]], axis=-1)
90101
long_t = torch.from_numpy(long)
91102

92103
# flip the order latg since bilinear only works with increasing coordinate values
@@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
104115
# pad z in lon direction
105116
# only works for a global grid
106117
# TODO generalize this to local grids and add options for padding
107-
x = torch.cat([x, x[..., 0:1]], axis=-1)
118+
if self.cylinder:
119+
x = torch.cat([x, x[..., 0:1]], axis=-1)
108120
out = self._bilinear(x)
109121
return out.view(out.shape[:-1] + self.shape)
110122

earth2grid/lcc.py

Lines changed: 9 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import numpy as np
16-
import torch
1716

18-
from earth2grid import base
19-
from earth2grid._regrid import BilinearInterpolator
17+
from earth2grid import projections
2018

2119
try:
2220
import pyvista as pv
@@ -31,7 +29,10 @@
3129
]
3230

3331

34-
class LambertConformalConicProjection:
32+
LambertConformalConicGrid = projections.Grid
33+
34+
35+
class LambertConformalConicProjection(projections.Projection):
3536
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
3637
"""
3738
@@ -80,7 +81,7 @@ def _theta(self, lon):
8081
delta_lon = delta_lon - np.round(delta_lon / 360) * 360 # convert to [-180, 180]
8182
return self.n * np.deg2rad(delta_lon)
8283

83-
def project(self, lat, lon):
84+
def project(self, lon, lat):
8485
"""
8586
Compute the projected x,y from lat,lon.
8687
"""
@@ -100,90 +101,27 @@ def inverse_project(self, x, y):
100101

101102
lat = np.rad2deg(2 * np.arctan(np.power(self.RF / rho, 1 / self.n))) - 90
102103
lon = self.lon0 + np.rad2deg(theta / self.n)
103-
return lat, lon
104+
return lon, lat
104105

105106

106107
# Projection used by HRRR CONUS (Continental US) data
107108
# https://rapidrefresh.noaa.gov/hrrr/HRRR_conus.domain.txt
108109
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)
109110

110111

111-
class LambertConformalConicGrid(base.Grid):
112-
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
113-
def __init__(self, projection: LambertConformalConicProjection, x, y):
114-
"""
115-
Args:
116-
projection: LambertConformalConicProjection object
117-
x: range of x values
118-
y: range of y values
119-
120-
"""
121-
self.projection = projection
122-
123-
self.x = np.array(x)
124-
self.y = np.array(y)
125-
126-
@property
127-
def lat_lon(self):
128-
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
129-
return self.projection.inverse_project(mesh_x, mesh_y)
130-
131-
@property
132-
def lat(self):
133-
return self.lat_lon[0]
134-
135-
@property
136-
def lon(self):
137-
return self.lat_lon[1]
138-
139-
@property
140-
def shape(self):
141-
return (len(self.y), len(self.x))
142-
143-
def __getitem__(self, idxs):
144-
yidxs, xidxs = idxs
145-
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])
146-
147-
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
148-
"""Get regridder to the specified lat and lon points"""
149-
150-
x, y = self.projection.project(lat, lon)
151-
152-
return BilinearInterpolator(
153-
x_coords=torch.from_numpy(self.x),
154-
y_coords=torch.from_numpy(self.y),
155-
x_query=torch.from_numpy(x),
156-
y_query=torch.from_numpy(y),
157-
)
158-
159-
def visualize(self, data):
160-
raise NotImplementedError()
161-
162-
def to_pyvista(self):
163-
if pv is None:
164-
raise ImportError("Need to install pyvista")
165-
166-
lat, lon = self.lat_lon
167-
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
168-
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
169-
z = np.sin(np.deg2rad(lat))
170-
grid = pv.StructuredGrid(x, y, z)
171-
return grid
172-
173-
174112
def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
175113
# coordinates of point in top-left corner
176114
lat0 = 21.138123
177115
lon0 = 237.280472
178116
# grid length (m)
179117
scale = 3000.0
180118
# coordinates on projected space
181-
x0, y0 = HRRR_CONUS_PROJECTION.project(lat0, lon0)
119+
x0, y0 = HRRR_CONUS_PROJECTION.project(lon0, lat0)
182120

183121
x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
184122
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]
185123

186-
return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
124+
return projections.Grid(HRRR_CONUS_PROJECTION, x, y)
187125

188126

189127
# Grid used by HRRR CONUS (Continental US) data

earth2grid/projections.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
Conventions:
17+
18+
- lon,lat is preferred to lat,lon for projection equations (i.e. x=lon in Plate
19+
Carree)
20+
- projected grids have shape [ny, nx]
21+
"""
22+
import abc
23+
24+
import numpy as np
25+
import torch
26+
27+
from earth2grid import base
28+
from earth2grid._regrid import BilinearInterpolator
29+
30+
try:
31+
import pyvista as pv
32+
except ImportError:
33+
pv = None
34+
35+
36+
class Projection(abc.ABC):
37+
@abc.abstractmethod
38+
def project(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
39+
"""
40+
Compute the projected x,y from lon,lat.
41+
42+
"""
43+
pass
44+
45+
@abc.abstractmethod
46+
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
47+
"""
48+
Compute the lon,lat from the projected x,y.
49+
"""
50+
pass
51+
52+
53+
class Grid(base.Grid):
54+
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
55+
def __init__(self, projection: Projection, x, y):
56+
"""
57+
Args:
58+
x: range of x values
59+
y: range of y values
60+
61+
"""
62+
self.projection = projection
63+
64+
self.x = np.array(x)
65+
self.y = np.array(y)
66+
67+
@property
68+
def lat_lon(self):
69+
mesh_y, mesh_x = np.meshgrid(self.y, self.x, indexing='ij')
70+
return self.projection.inverse_project(mesh_x, mesh_y)
71+
72+
@property
73+
def lat(self):
74+
return self.lat_lon[1]
75+
76+
@property
77+
def lon(self):
78+
return self.lat_lon[0]
79+
80+
@property
81+
def shape(self):
82+
return (len(self.y), len(self.x))
83+
84+
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
85+
"""Get regridder to the specified lat and lon points"""
86+
87+
x, y = self.projection.project(lon, lat)
88+
89+
return BilinearInterpolator(
90+
x_coords=torch.from_numpy(self.x),
91+
y_coords=torch.from_numpy(self.y),
92+
x_query=torch.from_numpy(x),
93+
y_query=torch.from_numpy(y),
94+
)
95+
96+
def __getitem__(self, idxs) -> "Grid":
97+
yidxs, xidxs = idxs
98+
return Grid(self.projection, x=self.x[xidxs], y=self.y[yidxs])
99+
100+
def visualize(self, data):
101+
raise NotImplementedError()
102+
103+
def to_pyvista(self):
104+
if pv is None:
105+
raise ImportError("Need to install pyvista")
106+
107+
lat, lon = self.lat_lon
108+
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
109+
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
110+
z = np.sin(np.deg2rad(lat))
111+
grid = pv.StructuredGrid(x, y, z)
112+
return grid

earth2grid/spatial.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,10 @@ def ang2vec(lon, lat):
4444
y = torch.cos(lat) * torch.sin(lon)
4545
z = torch.sin(lat)
4646
return (x, y, z)
47+
48+
49+
def vec2ang(x, y, z):
50+
"""convert lon,lat in radians to cartesian coordinates"""
51+
lat = torch.asin(z)
52+
lon = torch.atan2(y, x)
53+
return lon, lat

0 commit comments

Comments
 (0)