Skip to content

Commit 7bb80eb

Browse files
javak87Javad Kasravi
andauthored
vectorize s2tor3 (ecmwf#745)
* vectorize s2tor3 * ruff code --------- Co-authored-by: Javad Kasravi <j.kasravi@fz-juelich.de>
1 parent 0b201b3 commit 7bb80eb

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/weathergen/datasets/utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,16 @@ def s2tor3(lats, lons):
5959
Note: mathematics convention with lats in [0,pi] and lons in [0,2pi] is used
6060
(which is not problematic for lons but for lats care is required)
6161
"""
62-
x = torch.sin(lats) * torch.cos(lons)
63-
y = torch.sin(lats) * torch.sin(lons)
64-
z = torch.cos(lats)
65-
out = torch.stack([x, y, z])
66-
return out.permute([*list(np.arange(len(out.shape))[:-1] + 1), 0])
62+
sin_lats = torch.sin(lats)
63+
cos_lats = torch.cos(lats)
64+
65+
# Calculate the x, y, and z coordinates using vectorized operations.
66+
x = sin_lats * torch.cos(lons)
67+
y = sin_lats * torch.sin(lons)
68+
z = cos_lats
69+
70+
# Stack the x, y, and z tensors along the last dimension.
71+
return torch.stack([x, y, z], dim=-1)
6772

6873

6974
####################################################################################################

0 commit comments

Comments
 (0)