Skip to content

Commit 4ed392a

Browse files
tjhuntershmh40
andauthored
[906] Bug fix in tokenizer (ecmwf#907)
* changes * changes * changes * changes * changes * changes * changes * changes * cleanups * changes * comments --------- Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com>
1 parent f73067d commit 4ed392a

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

src/weathergen/datasets/tokenizer.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from weathergen.datasets.utils import (
1717
healpix_verts_rots,
18-
locs_to_cell_coords_ctrs,
1918
r3tos2,
2019
)
2120

@@ -76,26 +75,26 @@ def __init__(self, healpix_level: int):
7675
vertsmm_rots.to(torch.float32),
7776
]
7877

79-
self.verts_local = []
80-
verts = torch.stack([verts10, verts11, verts01, vertsmm])
81-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts00_rots, verts.transpose(0, 1)))
82-
self.verts_local.append(temp.flatten(1, 2))
83-
84-
verts = torch.stack([verts00, verts11, verts01, vertsmm])
85-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts10_rots, verts.transpose(0, 1)))
86-
self.verts_local.append(temp.flatten(1, 2))
87-
88-
verts = torch.stack([verts00, verts10, verts01, vertsmm])
89-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts11_rots, verts.transpose(0, 1)))
90-
self.verts_local.append(temp.flatten(1, 2))
91-
92-
verts = torch.stack([verts00, verts11, verts10, vertsmm])
93-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts01_rots, verts.transpose(0, 1)))
94-
self.verts_local.append(temp.flatten(1, 2))
78+
transforms = [
79+
([verts10, verts11, verts01, vertsmm], verts00_rots),
80+
([verts00, verts11, verts01, vertsmm], verts10_rots),
81+
([verts00, verts10, verts01, vertsmm], verts11_rots),
82+
([verts00, verts11, verts10, vertsmm], verts01_rots),
83+
([verts00, verts10, verts11, verts01], vertsmm_rots),
84+
]
9585

96-
verts = torch.stack([verts00, verts10, verts11, verts01])
97-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(vertsmm_rots, verts.transpose(0, 1)))
98-
self.verts_local.append(temp.flatten(1, 2))
86+
self.verts_local = []
87+
for _verts, rot in transforms:
88+
# Compute local coordinates
89+
verts = torch.stack(_verts)
90+
# shape: <healpix, 4, 3>
91+
verts = verts.transpose(0, 1)
92+
# Batch multiplication by the 3x3 rotation matrices.
93+
# shape: <healpix, 3, 3> @ <healpix, 4, 3> -> <healpix, 4, 3>
94+
# Needs to transpose first to <healpix, 3, 4> then transpose back.
95+
t1 = torch.bmm(rot, verts.transpose(-1, -2)).transpose(-2, -1)
96+
t2 = ref - t1
97+
self.verts_local.append(t2.flatten(1, 2))
9998

10099
self.hpy_verts_local_target = torch.stack(self.verts_local).transpose(0, 1)
101100

0 commit comments

Comments
 (0)