@@ -43,65 +43,65 @@ def __init__(self, healpix_level: int, seed: int, masker: Masker):
4343 self .num_healpix_cells_source = 12 * 4 ** self .hl_source
4444 self .num_healpix_cells_target = 12 * 4 ** self .hl_target
4545
46- verts00 , verts00_Rs = healpix_verts_rots (self .hl_source , 0.0 , 0.0 )
47- verts10 , verts10_Rs = healpix_verts_rots (self .hl_source , 1.0 , 0.0 )
48- verts11 , verts11_Rs = healpix_verts_rots (self .hl_source , 1.0 , 1.0 )
49- verts01 , verts01_Rs = healpix_verts_rots (self .hl_source , 0.0 , 1.0 )
50- vertsmm , vertsmm_Rs = healpix_verts_rots (self .hl_source , 0.5 , 0.5 )
46+ verts00 , verts00_rots = healpix_verts_rots (self .hl_source , 0.0 , 0.0 )
47+ verts10 , verts10_rots = healpix_verts_rots (self .hl_source , 1.0 , 0.0 )
48+ verts11 , verts11_rots = healpix_verts_rots (self .hl_source , 1.0 , 1.0 )
49+ verts01 , verts01_rots = healpix_verts_rots (self .hl_source , 0.0 , 1.0 )
50+ vertsmm , vertsmm_rots = healpix_verts_rots (self .hl_source , 0.5 , 0.5 )
5151 self .hpy_verts = [
5252 verts00 .to (torch .float32 ),
5353 verts10 .to (torch .float32 ),
5454 verts11 .to (torch .float32 ),
5555 verts01 .to (torch .float32 ),
5656 vertsmm .to (torch .float32 ),
5757 ]
58- self .hpy_verts_Rs_source = [
59- verts00_Rs .to (torch .float32 ),
60- verts10_Rs .to (torch .float32 ),
61- verts11_Rs .to (torch .float32 ),
62- verts01_Rs .to (torch .float32 ),
63- vertsmm_Rs .to (torch .float32 ),
58+ self .hpy_verts_rots_source = [
59+ verts00_rots .to (torch .float32 ),
60+ verts10_rots .to (torch .float32 ),
61+ verts11_rots .to (torch .float32 ),
62+ verts01_rots .to (torch .float32 ),
63+ vertsmm_rots .to (torch .float32 ),
6464 ]
6565
66- verts00 , verts00_Rs = healpix_verts_rots (self .hl_target , 0.0 , 0.0 )
67- verts10 , verts10_Rs = healpix_verts_rots (self .hl_target , 1.0 , 0.0 )
68- verts11 , verts11_Rs = healpix_verts_rots (self .hl_target , 1.0 , 1.0 )
69- verts01 , verts01_Rs = healpix_verts_rots (self .hl_target , 0.0 , 1.0 )
70- vertsmm , vertsmm_Rs = healpix_verts_rots (self .hl_target , 0.5 , 0.5 )
66+ verts00 , verts00_rots = healpix_verts_rots (self .hl_target , 0.0 , 0.0 )
67+ verts10 , verts10_rots = healpix_verts_rots (self .hl_target , 1.0 , 0.0 )
68+ verts11 , verts11_rots = healpix_verts_rots (self .hl_target , 1.0 , 1.0 )
69+ verts01 , verts01_rots = healpix_verts_rots (self .hl_target , 0.0 , 1.0 )
70+ vertsmm , vertsmm_rots = healpix_verts_rots (self .hl_target , 0.5 , 0.5 )
7171 self .hpy_verts = [
7272 verts00 .to (torch .float32 ),
7373 verts10 .to (torch .float32 ),
7474 verts11 .to (torch .float32 ),
7575 verts01 .to (torch .float32 ),
7676 vertsmm .to (torch .float32 ),
7777 ]
78- self .hpy_verts_Rs_target = [
79- verts00_Rs .to (torch .float32 ),
80- verts10_Rs .to (torch .float32 ),
81- verts11_Rs .to (torch .float32 ),
82- verts01_Rs .to (torch .float32 ),
83- vertsmm_Rs .to (torch .float32 ),
78+ self .hpy_verts_rots_target = [
79+ verts00_rots .to (torch .float32 ),
80+ verts10_rots .to (torch .float32 ),
81+ verts11_rots .to (torch .float32 ),
82+ verts01_rots .to (torch .float32 ),
83+ vertsmm_rots .to (torch .float32 ),
8484 ]
8585
8686 self .verts_local = []
8787 verts = torch .stack ([verts10 , verts11 , verts01 , vertsmm ])
88- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts00_Rs , verts .transpose (0 , 1 )))
88+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts00_rots , verts .transpose (0 , 1 )))
8989 self .verts_local .append (temp .flatten (1 , 2 ))
9090
9191 verts = torch .stack ([verts00 , verts11 , verts01 , vertsmm ])
92- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts10_Rs , verts .transpose (0 , 1 )))
92+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts10_rots , verts .transpose (0 , 1 )))
9393 self .verts_local .append (temp .flatten (1 , 2 ))
9494
9595 verts = torch .stack ([verts00 , verts10 , verts01 , vertsmm ])
96- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts11_Rs , verts .transpose (0 , 1 )))
96+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts11_rots , verts .transpose (0 , 1 )))
9797 self .verts_local .append (temp .flatten (1 , 2 ))
9898
9999 verts = torch .stack ([verts00 , verts11 , verts10 , vertsmm ])
100- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts01_Rs , verts .transpose (0 , 1 )))
100+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts01_rots , verts .transpose (0 , 1 )))
101101 self .verts_local .append (temp .flatten (1 , 2 ))
102102
103103 verts = torch .stack ([verts00 , verts10 , verts11 , verts01 ])
104- temp = ref - torch .stack (locs_to_cell_coords_ctrs (vertsmm_Rs , verts .transpose (0 , 1 )))
104+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (vertsmm_rots , verts .transpose (0 , 1 )))
105105 self .verts_local .append (temp .flatten (1 , 2 ))
106106
107107 self .hpy_verts_local_target = torch .stack (self .verts_local ).transpose (0 , 1 )
@@ -171,7 +171,7 @@ def batchify_source(
171171 time_win = time_win ,
172172 token_size = token_size ,
173173 hl = self .hl_source ,
174- hpy_verts_Rs = self .hpy_verts_Rs_source [- 1 ],
174+ hpy_verts_rots = self .hpy_verts_rots_source [- 1 ],
175175 n_coords = normalizer .normalize_coords ,
176176 n_geoinfos = normalizer .normalize_geoinfos ,
177177 n_data = normalizer .normalize_source_channels ,
@@ -257,7 +257,7 @@ def id(arg):
257257 time_win = time_win ,
258258 token_size = token_size ,
259259 hl = self .hl_source ,
260- hpy_verts_Rs = self .hpy_verts_Rs_source [- 1 ],
260+ hpy_verts_rots = self .hpy_verts_rots_source [- 1 ],
261261 n_coords = id ,
262262 n_geoinfos = normalizer .normalize_geoinfos ,
263263 n_data = normalizer .normalize_target_channels ,
@@ -311,7 +311,7 @@ def id(arg):
311311 target_coords ,
312312 target_geoinfos ,
313313 target_times ,
314- self .hpy_verts_Rs_target ,
314+ self .hpy_verts_rots_target ,
315315 self .hpy_verts_local_target ,
316316 self .hpy_nctrs_target ,
317317 )
0 commit comments