|
15 | 15 |
|
16 | 16 | from weathergen.datasets.utils import ( |
17 | 17 | healpix_verts_rots, |
18 | | - locs_to_cell_coords_ctrs, |
19 | 18 | r3tos2, |
20 | 19 | ) |
21 | 20 |
|
@@ -76,26 +75,26 @@ def __init__(self, healpix_level: int): |
76 | 75 | vertsmm_rots.to(torch.float32), |
77 | 76 | ] |
78 | 77 |
|
79 | | - self.verts_local = [] |
80 | | - verts = torch.stack([verts10, verts11, verts01, vertsmm]) |
81 | | - temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts00_rots, verts.transpose(0, 1))) |
82 | | - self.verts_local.append(temp.flatten(1, 2)) |
83 | | - |
84 | | - verts = torch.stack([verts00, verts11, verts01, vertsmm]) |
85 | | - temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts10_rots, verts.transpose(0, 1))) |
86 | | - self.verts_local.append(temp.flatten(1, 2)) |
87 | | - |
88 | | - verts = torch.stack([verts00, verts10, verts01, vertsmm]) |
89 | | - temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts11_rots, verts.transpose(0, 1))) |
90 | | - self.verts_local.append(temp.flatten(1, 2)) |
91 | | - |
92 | | - verts = torch.stack([verts00, verts11, verts10, vertsmm]) |
93 | | - temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts01_rots, verts.transpose(0, 1))) |
94 | | - self.verts_local.append(temp.flatten(1, 2)) |
| 78 | + transforms = [ |
| 79 | + ([verts10, verts11, verts01, vertsmm], verts00_rots), |
| 80 | + ([verts00, verts11, verts01, vertsmm], verts10_rots), |
| 81 | + ([verts00, verts10, verts01, vertsmm], verts11_rots), |
| 82 | + ([verts00, verts11, verts10, vertsmm], verts01_rots), |
| 83 | + ([verts00, verts10, verts11, verts01], vertsmm_rots), |
| 84 | + ] |
95 | 85 |
|
96 | | - verts = torch.stack([verts00, verts10, verts11, verts01]) |
97 | | - temp = ref - torch.stack(locs_to_cell_coords_ctrs(vertsmm_rots, verts.transpose(0, 1))) |
98 | | - self.verts_local.append(temp.flatten(1, 2)) |
| 86 | + self.verts_local = [] |
| 87 | + for _verts, rot in transforms: |
| 88 | + # Compute local coordinates |
| 89 | + verts = torch.stack(_verts) |
| 90 | + # shape: <healpix, 4, 3> |
| 91 | + verts = verts.transpose(0, 1) |
| 92 | + # Batch multiplication by the 3x3 rotation matrices. |
| 93 | + # shape: <healpix, 3, 3> @ <healpix, 4, 3> -> <healpix, 4, 3> |
| 94 | + # Needs to transpose first to <healpix, 3, 4> then transpose back. |
| 95 | + t1 = torch.bmm(rot, verts.transpose(-1, -2)).transpose(-2, -1) |
| 96 | + t2 = ref - t1 |
| 97 | + self.verts_local.append(t2.flatten(1, 2)) |
99 | 98 |
|
100 | 99 | self.hpy_verts_local_target = torch.stack(self.verts_local).transpose(0, 1) |
101 | 100 |
|
|
0 commit comments