Skip to content

Commit 8fbe981

Browse files
committed
Add ying yang grid
This PR adds a new projection based grid, the ying yang grid. It restructures some of the lambert conformal logic a bit, so Simon should take a look.
1 parent c9cb58d commit 8fbe981

File tree

10 files changed

+415
-74
lines changed

10 files changed

+415
-74
lines changed

earth2grid/_regrid.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
from typing import Dict, Sequence
1717

1818
import einops
19-
import netCDF4 as nc
2019
import torch
20+
21+
try:
22+
import netCDF4 as nc
23+
except ImportError:
24+
nc = None
25+
2126
from scipy import spatial
2227

2328
from earth2grid.spatial import ang2vec, haversine_distance
@@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
5964
class TempestRegridder(torch.nn.Module):
6065
def __init__(self, file_path):
6166
super().__init__()
67+
if nc is None:
68+
raise ImportError("netCDF4 not imported. Please install for this feature.")
69+
6270
dataset = nc.Dataset(file_path)
6371
self.lat = dataset["latc_b"][:]
6472
self.lon = dataset["lonc_b"][:]

earth2grid/latlon.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,18 @@
2525

2626

2727
class LatLonGrid(base.Grid):
28-
def __init__(self, lat: list[float], lon: list[float]):
28+
def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
2929
"""
3030
Args:
3131
lat: center of lat cells
3232
lon: center of lon cells
33+
cylinder: if true, then lon is considered a periodic coordinate
34+
on cylinder so that interpolation wraps around the edge.
35+
Otherwise, it is assumed to be a finite plane.
3336
"""
3437
self._lat = lat
3538
self._lon = lon
39+
self.cylinder = cylinder
3640

3741
@property
3842
def lat(self):
@@ -48,7 +52,7 @@ def shape(self):
4852

4953
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
5054
"""Get regridder to the specified lat and lon points"""
51-
return _RegridFromLatLon(self, lat, lon)
55+
return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)
5256

5357
def _lonb(self):
5458
edges = (self.lon[1:] + self.lon[:-1]) / 2
@@ -78,15 +82,22 @@ def to_pyvista(self):
7882
class _RegridFromLatLon(torch.nn.Module):
7983
"""Regrid from lat-lon to unstructured grid with bilinear interpolation"""
8084

81-
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
85+
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
86+
"""
87+
Args:
88+
cylinder: if True than lon is assumed to be periodic
89+
"""
8290
super().__init__()
91+
self.cylinder = cylinder
8392

8493
lat, lon = np.broadcast_arrays(lat, lon)
8594
self.shape = lat.shape
8695

8796
# TODO add device switching logic (maybe use torch registers for this
8897
# info)
89-
long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
98+
long = src.lon.ravel()
99+
if self.cylinder:
100+
long = np.concatenate([long, [360]], axis=-1)
90101
long_t = torch.from_numpy(long)
91102

92103
# flip the order latg since bilinear only works with increasing coordinate values
@@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
104115
# pad z in lon direction
105116
# only works for a global grid
106117
# TODO generalize this to local grids and add options for padding
107-
x = torch.cat([x, x[..., 0:1]], axis=-1)
118+
if self.cylinder:
119+
x = torch.cat([x, x[..., 0:1]], axis=-1)
108120
out = self._bilinear(x)
109121
return out.view(out.shape[:-1] + self.shape)
110122

earth2grid/lcc.py

Lines changed: 5 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import numpy as np
16-
import torch
1716

18-
from earth2grid import base
19-
from earth2grid._regrid import BilinearInterpolator
17+
from earth2grid import projections
2018

2119
try:
2220
import pyvista as pv
@@ -31,7 +29,9 @@
3129
]
3230

3331

34-
class LambertConformalConicProjection:
32+
LambertConformalConicGrid = projections.Grid
33+
34+
class LambertConformalConicProjection(projections.Projection):
3535
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
3636
"""
3737
@@ -108,68 +108,6 @@ def inverse_project(self, x, y):
108108
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)
109109

110110

111-
class LambertConformalConicGrid(base.Grid):
112-
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
113-
def __init__(self, projection: LambertConformalConicProjection, x, y):
114-
"""
115-
Args:
116-
projection: LambertConformalConicProjection object
117-
x: range of x values
118-
y: range of y values
119-
120-
"""
121-
self.projection = projection
122-
123-
self.x = np.array(x)
124-
self.y = np.array(y)
125-
126-
@property
127-
def lat_lon(self):
128-
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
129-
return self.projection.inverse_project(mesh_x, mesh_y)
130-
131-
@property
132-
def lat(self):
133-
return self.lat_lon[0]
134-
135-
@property
136-
def lon(self):
137-
return self.lat_lon[1]
138-
139-
@property
140-
def shape(self):
141-
return (len(self.y), len(self.x))
142-
143-
def __getitem__(self, idxs):
144-
yidxs, xidxs = idxs
145-
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])
146-
147-
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
148-
"""Get regridder to the specified lat and lon points"""
149-
150-
x, y = self.projection.project(lat, lon)
151-
152-
return BilinearInterpolator(
153-
x_coords=torch.from_numpy(self.x),
154-
y_coords=torch.from_numpy(self.y),
155-
x_query=torch.from_numpy(x),
156-
y_query=torch.from_numpy(y),
157-
)
158-
159-
def visualize(self, data):
160-
raise NotImplementedError()
161-
162-
def to_pyvista(self):
163-
if pv is None:
164-
raise ImportError("Need to install pyvista")
165-
166-
lat, lon = self.lat_lon
167-
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
168-
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
169-
z = np.sin(np.deg2rad(lat))
170-
grid = pv.StructuredGrid(x, y, z)
171-
return grid
172-
173111

174112
def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
175113
# coordinates of point in top-left corner
@@ -183,7 +121,7 @@ def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
183121
x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
184122
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]
185123

186-
return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
124+
return projections.Grid(HRRR_CONUS_PROJECTION, x, y)
187125

188126

189127
# Grid used by HRRR CONUS (Continental US) data

earth2grid/projections.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import abc
16+
17+
import numpy as np
18+
import torch
19+
20+
from earth2grid import base
21+
from earth2grid._regrid import BilinearInterpolator
22+
23+
try:
24+
import pyvista as pv
25+
except ImportError:
26+
pv = None
27+
28+
29+
class Projection(abc.ABC):
30+
@abc.abstractmethod
31+
def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
32+
"""
33+
Compute the projected x,y from lat,lon.
34+
"""
35+
pass
36+
37+
@abc.abstractmethod
38+
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
39+
"""
40+
Compute the lat,lon from the projected x,y.
41+
"""
42+
pass
43+
44+
45+
class Grid(base.Grid):
46+
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
47+
def __init__(self, projection: Projection, x, y):
48+
"""
49+
Args:
50+
x: range of x values
51+
y: range of y values
52+
53+
"""
54+
self.projection = projection
55+
56+
self.x = np.array(x)
57+
self.y = np.array(y)
58+
59+
@property
60+
def lat_lon(self):
61+
mesh_x, mesh_y = np.meshgrid(self.x, self.y, indexing='ij')
62+
return self.projection.inverse_project(mesh_x, mesh_y)
63+
64+
@property
65+
def lat(self):
66+
return self.lat_lon[0]
67+
68+
@property
69+
def lon(self):
70+
return self.lat_lon[1]
71+
72+
@property
73+
def shape(self):
74+
return (len(self.x), len(self.y))
75+
76+
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
77+
"""Get regridder to the specified lat and lon points"""
78+
79+
x, y = self.projection.project(lat, lon)
80+
81+
return BilinearInterpolator(
82+
x_coords=torch.from_numpy(self.x),
83+
y_coords=torch.from_numpy(self.y),
84+
x_query=torch.from_numpy(x),
85+
y_query=torch.from_numpy(y),
86+
)
87+
88+
def visualize(self, data):
89+
raise NotImplementedError()
90+
91+
def to_pyvista(self):
92+
if pv is None:
93+
raise ImportError("Need to install pyvista")
94+
95+
lat, lon = self.lat_lon
96+
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
97+
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
98+
z = np.sin(np.deg2rad(lat))
99+
grid = pv.StructuredGrid(x, y, z)
100+
return grid

earth2grid/spatial.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,11 @@ def ang2vec(lon, lat):
4444
y = torch.cos(lat) * torch.sin(lon)
4545
z = torch.sin(lat)
4646
return (x, y, z)
47+
48+
49+
def vec2ang(x, y, z):
50+
"""convert lon,lat in radians to cartesian coordinates"""
51+
lat = torch.asin(z)
52+
lon = torch.atan2(y, x)
53+
return lon, lat
54+

earth2grid/yinyang.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Yin Yang
16+
17+
the ying yang grid is an overset grid for the sphere containing two faces
18+
- Yin: a normal lat lon grid for 2/3 of lon, and 2/3 of lat
19+
- Yang: Yin but with pole along x
20+
21+
22+
Key facts
23+
24+
ying
25+
lon: [-3 pi /4 - delta, 3 pi / 4 + delta ]
26+
lat: [-pi / 4 - delta, pi / 4 + delta]
27+
28+
ying to yang transformation: alpha = 0, beta = 90, gamma = 180
29+
30+
(x, y, z) - > (-x, z, y)
31+
32+
"""
33+
import math
34+
35+
import numpy as np
36+
import torch
37+
38+
from earth2grid import latlon, projections, spatial
39+
40+
41+
def Ying(nlat: int, nlon: int, delta: int):
42+
"""The ying grid
43+
44+
nlat, and nlon are as in the latlon.equiangular_latlon_grid and
45+
refer to full sphere.
46+
47+
``nlat`` includes the poles [90, -90], and ``nlon`` is [0, 2 pi).
48+
49+
``delta`` is the amount of overlap in terms of number of grid points.
50+
51+
"""
52+
# TODO test that min(lat) = -max(lat), and for lon too
53+
54+
dlat = 180 / (nlat - 1)
55+
dlon = 360 / nlon
56+
57+
n = math.ceil(3 * nlon / 8)
58+
lon = np.arange(- n - delta, n + delta + 1) * dlon
59+
lat = np.arange(- (nlat - 1) // 4 - delta, (nlat + 1) // 4 + delta + 1) * dlat
60+
61+
return latlon.LatLonGrid(lat.tolist(), lon.tolist(), cylinder=False)
62+
63+
64+
class YangProjection(projections.Projection):
65+
66+
def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
67+
"""
68+
Compute the projected x,y from lat,lon.
69+
"""
70+
lat = torch.from_numpy(lat)
71+
lon = torch.from_numpy(lon)
72+
73+
lat = torch.deg2rad(lat)
74+
lon = torch.deg2rad(lon)
75+
76+
x, y, z = spatial.ang2vec(lat=lat, lon=lon)
77+
x, y, z = -x, z, y
78+
lon, lat = spatial.vec2ang(x, y ,z)
79+
80+
lat = torch.rad2deg(lat)
81+
lon = torch.rad2deg(lon)
82+
83+
return lat.numpy(), lon.numpy()
84+
85+
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
86+
"""
87+
Compute the lat,lon from the projected x,y.
88+
"""
89+
# ying-yang is its own inverse
90+
return self.project(x, y)
91+
92+
93+
def Yang(nlat, nlon, delta):
94+
ying = Ying(nlat, nlon, delta)
95+
return projections.Grid(YangProjection(), ying.lat, ying.lon)
96+

0 commit comments

Comments
 (0)