Skip to content

Commit 427c310

Browse files
committed
Add ying yang grid
This PR adds a new projection based grid, the ying yang grid. It restructures some of the lambert conformal logic a bit, so Simon should take a look.
1 parent c9cb58d commit 427c310

11 files changed

+403
-75
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,14 @@ test_grid_visualize.png
116116
*.png
117117
*.jpg
118118
*.jpeg
119+
*.gif
119120
public/
120121

121122
a.out
122123
*.o
124+
125+
# editor backup files
126+
# helix
127+
\#*\#
128+
# emacs
129+
*~

earth2grid/_regrid.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
from typing import Dict, Sequence
1717

1818
import einops
19-
import netCDF4 as nc
2019
import torch
20+
21+
try:
22+
import netCDF4 as nc
23+
except ImportError:
24+
nc = None
25+
2126
from scipy import spatial
2227

2328
from earth2grid.spatial import ang2vec, haversine_distance
@@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
5964
class TempestRegridder(torch.nn.Module):
6065
def __init__(self, file_path):
6166
super().__init__()
67+
if nc is None:
68+
raise ImportError("netCDF4 not imported. Please install for this feature.")
69+
6270
dataset = nc.Dataset(file_path)
6371
self.lat = dataset["latc_b"][:]
6472
self.lon = dataset["lonc_b"][:]

earth2grid/latlon.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,18 @@
2525

2626

2727
class LatLonGrid(base.Grid):
28-
def __init__(self, lat: list[float], lon: list[float]):
28+
def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
2929
"""
3030
Args:
3131
lat: center of lat cells
3232
lon: center of lon cells
33+
cylinder: if true, then lon is considered a periodic coordinate
34+
on cylinder so that interpolation wraps around the edge.
35+
Otherwise, it is assumed to be a finite plane.
3336
"""
3437
self._lat = lat
3538
self._lon = lon
39+
self.cylinder = cylinder
3640

3741
@property
3842
def lat(self):
@@ -48,7 +52,7 @@ def shape(self):
4852

4953
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
5054
"""Get regridder to the specified lat and lon points"""
51-
return _RegridFromLatLon(self, lat, lon)
55+
return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)
5256

5357
def _lonb(self):
5458
edges = (self.lon[1:] + self.lon[:-1]) / 2
@@ -78,15 +82,22 @@ def to_pyvista(self):
7882
class _RegridFromLatLon(torch.nn.Module):
7983
"""Regrid from lat-lon to unstructured grid with bilinear interpolation"""
8084

81-
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
85+
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
86+
"""
87+
Args:
88+
cylinder: if True than lon is assumed to be periodic
89+
"""
8290
super().__init__()
91+
self.cylinder = cylinder
8392

8493
lat, lon = np.broadcast_arrays(lat, lon)
8594
self.shape = lat.shape
8695

8796
# TODO add device switching logic (maybe use torch registers for this
8897
# info)
89-
long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
98+
long = src.lon.ravel()
99+
if self.cylinder:
100+
long = np.concatenate([long, [360]], axis=-1)
90101
long_t = torch.from_numpy(long)
91102

92103
# flip the order latg since bilinear only works with increasing coordinate values
@@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
104115
# pad z in lon direction
105116
# only works for a global grid
106117
# TODO generalize this to local grids and add options for padding
107-
x = torch.cat([x, x[..., 0:1]], axis=-1)
118+
if self.cylinder:
119+
x = torch.cat([x, x[..., 0:1]], axis=-1)
108120
out = self._bilinear(x)
109121
return out.view(out.shape[:-1] + self.shape)
110122

earth2grid/lcc.py

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import numpy as np
16-
import torch
1716

18-
from earth2grid import base
19-
from earth2grid._regrid import BilinearInterpolator
17+
from earth2grid import projections
2018

2119
try:
2220
import pyvista as pv
@@ -31,7 +29,10 @@
3129
]
3230

3331

34-
class LambertConformalConicProjection:
32+
LambertConformalConicGrid = projections.Grid
33+
34+
35+
class LambertConformalConicProjection(projections.Projection):
3536
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
3637
"""
3738
@@ -108,69 +109,6 @@ def inverse_project(self, x, y):
108109
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)
109110

110111

111-
class LambertConformalConicGrid(base.Grid):
112-
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
113-
def __init__(self, projection: LambertConformalConicProjection, x, y):
114-
"""
115-
Args:
116-
projection: LambertConformalConicProjection object
117-
x: range of x values
118-
y: range of y values
119-
120-
"""
121-
self.projection = projection
122-
123-
self.x = np.array(x)
124-
self.y = np.array(y)
125-
126-
@property
127-
def lat_lon(self):
128-
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
129-
return self.projection.inverse_project(mesh_x, mesh_y)
130-
131-
@property
132-
def lat(self):
133-
return self.lat_lon[0]
134-
135-
@property
136-
def lon(self):
137-
return self.lat_lon[1]
138-
139-
@property
140-
def shape(self):
141-
return (len(self.y), len(self.x))
142-
143-
def __getitem__(self, idxs):
144-
yidxs, xidxs = idxs
145-
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])
146-
147-
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
148-
"""Get regridder to the specified lat and lon points"""
149-
150-
x, y = self.projection.project(lat, lon)
151-
152-
return BilinearInterpolator(
153-
x_coords=torch.from_numpy(self.x),
154-
y_coords=torch.from_numpy(self.y),
155-
x_query=torch.from_numpy(x),
156-
y_query=torch.from_numpy(y),
157-
)
158-
159-
def visualize(self, data):
160-
raise NotImplementedError()
161-
162-
def to_pyvista(self):
163-
if pv is None:
164-
raise ImportError("Need to install pyvista")
165-
166-
lat, lon = self.lat_lon
167-
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
168-
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
169-
z = np.sin(np.deg2rad(lat))
170-
grid = pv.StructuredGrid(x, y, z)
171-
return grid
172-
173-
174112
def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
175113
# coordinates of point in top-left corner
176114
lat0 = 21.138123
@@ -183,7 +121,7 @@ def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
183121
x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
184122
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]
185123

186-
return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
124+
return projections.Grid(HRRR_CONUS_PROJECTION, x, y)
187125

188126

189127
# Grid used by HRRR CONUS (Continental US) data

earth2grid/projections.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import abc
16+
17+
import numpy as np
18+
import torch
19+
20+
from earth2grid import base
21+
from earth2grid._regrid import BilinearInterpolator
22+
23+
try:
24+
import pyvista as pv
25+
except ImportError:
26+
pv = None
27+
28+
29+
class Projection(abc.ABC):
30+
@abc.abstractmethod
31+
def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
32+
"""
33+
Compute the projected x,y from lat,lon.
34+
"""
35+
pass
36+
37+
@abc.abstractmethod
38+
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
39+
"""
40+
Compute the lat,lon from the projected x,y.
41+
"""
42+
pass
43+
44+
45+
class Grid(base.Grid):
46+
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
47+
def __init__(self, projection: Projection, x, y):
48+
"""
49+
Args:
50+
x: range of x values
51+
y: range of y values
52+
53+
"""
54+
self.projection = projection
55+
56+
self.x = np.array(x)
57+
self.y = np.array(y)
58+
59+
@property
60+
def lat_lon(self):
61+
mesh_x, mesh_y = np.meshgrid(self.x, self.y, indexing='ij')
62+
return self.projection.inverse_project(mesh_x, mesh_y)
63+
64+
@property
65+
def lat(self):
66+
return self.lat_lon[0]
67+
68+
@property
69+
def lon(self):
70+
return self.lat_lon[1]
71+
72+
@property
73+
def shape(self):
74+
return (len(self.x), len(self.y))
75+
76+
def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
77+
"""Get regridder to the specified lat and lon points"""
78+
79+
x, y = self.projection.project(lat, lon)
80+
81+
return BilinearInterpolator(
82+
x_coords=torch.from_numpy(self.x),
83+
y_coords=torch.from_numpy(self.y),
84+
x_query=torch.from_numpy(x),
85+
y_query=torch.from_numpy(y),
86+
)
87+
88+
def visualize(self, data):
89+
raise NotImplementedError()
90+
91+
def to_pyvista(self):
92+
if pv is None:
93+
raise ImportError("Need to install pyvista")
94+
95+
lat, lon = self.lat_lon
96+
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
97+
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
98+
z = np.sin(np.deg2rad(lat))
99+
grid = pv.StructuredGrid(x, y, z)
100+
return grid

earth2grid/spatial.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,10 @@ def ang2vec(lon, lat):
4444
y = torch.cos(lat) * torch.sin(lon)
4545
z = torch.sin(lat)
4646
return (x, y, z)
47+
48+
49+
def vec2ang(x, y, z):
50+
"""convert lon,lat in radians to cartesian coordinates"""
51+
lat = torch.asin(z)
52+
lon = torch.atan2(y, x)
53+
return lon, lat

0 commit comments

Comments
 (0)