Skip to content

Commit c068605

Browse files
committed
Fix flake8 formating
1 parent 8108f65 commit c068605

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

stochman/discretized_manifold.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
9595
for i in range(ceil(ps.shape[0] / batch_size)):
9696
x = ps[batch_size * i:batch_size * (i + 1), 0]
9797
y = ps[batch_size * i:batch_size * (i + 1), 1]
98-
xn = nf[0](x); yn = nf[1](y)
98+
xn, yn = nf[0](x), nf[1](y)
9999

100100
bs = x.shape[0] # may be different from batch size for the last batch
101101

@@ -133,12 +133,16 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
133133

134134
# Compute interpolation weights. We use the mean function of a GP regressor.
135135
mesh = torch.meshgrid(*self.grid, indexing='ij')
136-
grid_points = torch.cat([m.unsqueeze(-1) for m in mesh], dim=-1) # e.g. 100x100x2 a 2D grid with 100 points in each dim
136+
grid_points = torch.cat(
137+
[m.unsqueeze(-1) for m in mesh], dim=-1
138+
) # e.g. 100x100x2 a 2D grid with 100 points in each dim
137139
K = self._kernel(grid_points.view(-1, len(self.grid))) # (num_grid)x(num_grid)
138140
if interpolation_noise > 0.0:
139141
K += interpolation_noise * torch.eye(K.shape[0])
140142
num_grid = K.shape[0]
141-
self._alpha = torch.linalg.solve(K, self.__metric__.view(num_grid, -1)) # (num_grid)x(d²) or (num_grid)x(d)
143+
self._alpha = torch.linalg.solve(
144+
K, self.__metric__.view(num_grid, -1)
145+
) # (num_grid)x(d²) or (num_grid)x(d)
142146
except:
143147
import warnings
144148
warnings.warn("It appears that your model does not implement a metric.")
@@ -187,7 +191,7 @@ def _grid_dist2(self, p):
187191
"""
188192

189193
dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
190-
mesh = torch.meshgrid(*self.grid, indexing='ij') # XXX: IT MUST BE POSSIBLE TO AVOID THIS GRID CONSTRUCTION
194+
mesh = torch.meshgrid(*self.grid, indexing='ij')
191195
for mesh_dim, dim in zip(mesh, range(len(self.grid))):
192196
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2
193197
return dist2

0 commit comments

Comments
 (0)