@@ -95,7 +95,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
95
95
for i in range (ceil (ps .shape [0 ] / batch_size )):
96
96
x = ps [batch_size * i :batch_size * (i + 1 ), 0 ]
97
97
y = ps [batch_size * i :batch_size * (i + 1 ), 1 ]
98
- xn = nf [0 ](x ); yn = nf [1 ](y )
98
+ xn , yn = nf [0 ](x ), nf [1 ](y )
99
99
100
100
bs = x .shape [0 ] # may be different from batch size for the last batch
101
101
@@ -133,12 +133,16 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
133
133
134
134
# Compute interpolation weights. We use the mean function of a GP regressor.
135
135
mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
136
- grid_points = torch .cat ([m .unsqueeze (- 1 ) for m in mesh ], dim = - 1 ) # e.g. 100x100x2 a 2D grid with 100 points in each dim
136
+ grid_points = torch .cat (
137
+ [m .unsqueeze (- 1 ) for m in mesh ], dim = - 1
138
+ ) # e.g. 100x100x2 a 2D grid with 100 points in each dim
137
139
K = self ._kernel (grid_points .view (- 1 , len (self .grid ))) # (num_grid)x(num_grid)
138
140
if interpolation_noise > 0.0 :
139
141
K += interpolation_noise * torch .eye (K .shape [0 ])
140
142
num_grid = K .shape [0 ]
141
- self ._alpha = torch .linalg .solve (K , self .__metric__ .view (num_grid , - 1 )) # (num_grid)x(d²) or (num_grid)x(d)
143
+ self ._alpha = torch .linalg .solve (
144
+ K , self .__metric__ .view (num_grid , - 1 )
145
+ ) # (num_grid)x(d²) or (num_grid)x(d)
142
146
except :
143
147
import warnings
144
148
warnings .warn ("It appears that your model does not implement a metric." )
@@ -187,7 +191,7 @@ def _grid_dist2(self, p):
187
191
"""
188
192
189
193
dist2 = torch .zeros (p .shape [0 ], self .G .number_of_nodes ())
190
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' ) # XXX: IT MUST BE POSSIBLE TO AVOID THIS GRID CONSTRUCTION
194
+ mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
191
195
for mesh_dim , dim in zip (mesh , range (len (self .grid ))):
192
196
dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 ))** 2
193
197
return dist2
0 commit comments