Skip to content

Commit 8108f65

Browse files
committed
flake8 improvements
1 parent 647b6f2 commit 8108f65

File tree

1 file changed

+56
-57
lines changed

1 file changed

+56
-57
lines changed

stochman/discretized_manifold.py

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def __init__(self):
1818
self._diagonal_metric = False
1919
self._alpha = torch.Tensor()
2020

21-
2221
def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0):
2322
"""
2423
Discretize a manifold to a given grid.
@@ -30,18 +29,18 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
3029
the manifold will be discretized. For example,
3130
grid = [torch.linspace(-3, 3, 50), torch.linspace(-3, 3, 50)]
3231
will discretize a two-dimensional manifold on a 50x50 grid.
33-
32+
3433
use_diagonals:
3534
If True, diagonal edges are included in the graph, otherwise
3635
they are excluded.
3736
Default: True.
38-
37+
3938
batch_size: Number of edge-lengths that are computed in parallel. The larger
4039
value you pick here, the faster the discretization will be.
4140
However, memory usage increases with this number, so a good
4241
choice is model and hardware specific.
4342
Default: 4.
44-
43+
4544
interpolation_noise:
4645
On fitting, the manifold metric is evalated on the provided grid.
4746
The `metric` function then performs interpolation of this metric,
@@ -60,53 +59,53 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
6059

6160
# Add nodes to graph
6261
xsize, ysize = len(grid[0]), len(grid[1])
63-
node_idx = lambda x, y: x*ysize + y
64-
self.G.add_nodes_from(range(xsize*ysize))
62+
node_idx = lambda x, y: x * ysize + y
63+
self.G.add_nodes_from(range(xsize * ysize))
6564

6665
point_set = torch.cartesian_prod(
67-
torch.linspace(0, xsize-1, xsize, dtype=torch.long),
68-
torch.linspace(0, ysize-1, ysize, dtype=torch.long)
66+
torch.linspace(0, xsize - 1, xsize, dtype=torch.long),
67+
torch.linspace(0, ysize - 1, ysize, dtype=torch.long)
6968
) # (big)x2
7069

71-
point_sets = [ ] # these will be [N, 2] matrices of index points
72-
neighbour_funcs = [ ] # these will be functions for getting the neighbour index
70+
point_sets = [] # these will be [N, 2] matrices of index points
71+
neighbour_funcs = [] # these will be functions for getting the neighbour index
7372

7473
# add sets
7574
point_sets.append(point_set[point_set[:, 0] > 0]) # x > 0
76-
neighbour_funcs.append([lambda x: x-1, lambda y: y])
75+
neighbour_funcs.append([lambda x: x - 1, lambda y: y])
7776

7877
point_sets.append(point_set[point_set[:, 1] > 0]) # y > 0
79-
neighbour_funcs.append([lambda x: x, lambda y: y-1])
78+
neighbour_funcs.append([lambda x: x, lambda y: y - 1])
8079

81-
point_sets.append(point_set[point_set[:, 0] < xsize-1]) # x < xsize-1
82-
neighbour_funcs.append([lambda x: x+1, lambda y: y])
80+
point_sets.append(point_set[point_set[:, 0] < xsize - 1]) # x < xsize-1
81+
neighbour_funcs.append([lambda x: x + 1, lambda y: y])
82+
83+
point_sets.append(point_set[point_set[:, 1] < ysize - 1]) # y < ysize-1
84+
neighbour_funcs.append([lambda x: x, lambda y: y + 1])
8385

84-
point_sets.append(point_set[point_set[:, 1] < ysize-1]) # y < ysize-1
85-
neighbour_funcs.append([lambda x: x, lambda y: y+1])
86-
8786
if use_diagonals:
88-
point_sets.append(point_set[torch.logical_and(point_set[:,0] > 0, point_set[:,1] > 0)])
89-
neighbour_funcs.append([lambda x: x-1, lambda y: y-1])
87+
point_sets.append(point_set[torch.logical_and(point_set[:, 0] > 0, point_set[:, 1] > 0)])
88+
neighbour_funcs.append([lambda x: x - 1, lambda y: y - 1])
89+
90+
point_sets.append(point_set[torch.logical_and(point_set[:, 0] < xsize - 1, point_set[:, 1] > 0)])
91+
neighbour_funcs.append([lambda x: x + 1, lambda y: y - 1])
9092

91-
point_sets.append(point_set[torch.logical_and(point_set[:,0] < xsize-1, point_set[:,1] > 0)])
92-
neighbour_funcs.append([lambda x: x+1, lambda y: y-1])
93-
9493
t = torch.linspace(0, 1, 2)
9594
for ps, nf in zip(point_sets, neighbour_funcs):
96-
for i in range(ceil(ps.shape[0] / batch_size)):
97-
x = ps[batch_size*i:batch_size*(i+1), 0]
98-
y = ps[batch_size*i:batch_size*(i+1), 1]
95+
for i in range(ceil(ps.shape[0] / batch_size)):
96+
x = ps[batch_size * i:batch_size * (i + 1), 0]
97+
y = ps[batch_size * i:batch_size * (i + 1), 1]
9998
xn = nf[0](x); yn = nf[1](y)
100-
99+
101100
bs = x.shape[0] # may be different from batch size for the last batch
102101

103102
line = CubicSpline(begin=torch.zeros(bs, dim), end=torch.ones(bs, dim), num_nodes=2)
104103
line.begin = torch.cat([grid[0][x].view(-1, 1), grid[1][y].view(-1, 1)], dim=1) # (bs)x2
105104
line.end = torch.cat([grid[0][xn].view(-1, 1), grid[1][yn].view(-1, 1)], dim=1) # (bs)x2
106105

107-
#if external_curve_length_function:
108-
# weight = external_curve_length_function(model, line(t))
109-
#else:
106+
# if external_curve_length_function:
107+
# weight = external_curve_length_function(model, line(t))
108+
# else:
110109
with torch.no_grad():
111110
weight = model.curve_length(line(t))
112111

@@ -123,23 +122,23 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
123122
for x in range(xsize):
124123
for y in range(ysize):
125124
p = torch.tensor([self.grid[0][x], self.grid[1][y]])
126-
Mlist.append(model.metric(p)) # 1x(d)x(d) or 1x(d)
127-
M = torch.cat(Mlist, dim=0) # (big)x(d)x(d) or (big)x(d)
125+
Mlist.append(model.metric(p)) # 1x(d)x(d) or 1x(d)
126+
M = torch.cat(Mlist, dim=0) # (big)x(d)x(d) or (big)x(d)
128127
self._diagonal_metric = M.dim() == 2
129128
d = M.shape[-1]
130129
if self._diagonal_metric:
131-
self.__metric__ = M.view([*self.grid_size, d]) # e.g. (xsize)x(ysize)x(d)
130+
self.__metric__ = M.view([*self.grid_size, d]) # e.g. (xsize)x(ysize)x(d)
132131
else:
133-
self.__metric__ = M.view([*self.grid_size, d, d]) # e.g. (xsize)x(ysize)x(d)x(d)
134-
132+
self.__metric__ = M.view([*self.grid_size, d, d]) # e.g. (xsize)x(ysize)x(d)x(d)
133+
135134
# Compute interpolation weights. We use the mean function of a GP regressor.
136135
mesh = torch.meshgrid(*self.grid, indexing='ij')
137-
grid_points = torch.cat([m.unsqueeze(-1) for m in mesh], dim=-1) # e.g. 100x100x2 a 2D grid with 100 points in each dim
138-
K = self._kernel(grid_points.view(-1, len(self.grid))) # (num_grid)x(num_grid)
136+
grid_points = torch.cat([m.unsqueeze(-1) for m in mesh], dim=-1) # e.g. 100x100x2 a 2D grid with 100 points in each dim
137+
K = self._kernel(grid_points.view(-1, len(self.grid))) # (num_grid)x(num_grid)
139138
if interpolation_noise > 0.0:
140139
K += interpolation_noise * torch.eye(K.shape[0])
141140
num_grid = K.shape[0]
142-
self._alpha = torch.linalg.solve(K, self.__metric__.view(num_grid, -1)) # (num_grid)x(d²) or (num_grid)x(d)
141+
self._alpha = torch.linalg.solve(K, self.__metric__.view(num_grid, -1)) # (num_grid)x(d²) or (num_grid)x(d)
143142
except:
144143
import warnings
145144
warnings.warn("It appears that your model does not implement a metric.")
@@ -169,10 +168,10 @@ def metric(self, points):
169168
a (d)x(d) diagonal matrix.
170169
"""
171170
# XXX: We should also support returning the derivative of the metric! (for ODEs; see local_PCA)
172-
K = self._kernel(points) # Nx(num_grid)
173-
M = K.mm(self._alpha) # Nx(d²) or Nx(d)
171+
K = self._kernel(points) # Nx(num_grid)
172+
M = K.mm(self._alpha) # Nx(d²) or Nx(d)
174173
if not self._diagonal_metric:
175-
d = len(grid)
174+
d = len(self.grid)
176175
M = M.view(-1, d, d)
177176
return M
178177

@@ -198,16 +197,16 @@ def _kernel(self, p):
198197
199198
Input:
200199
p: a torch Tensor corresponding to a point on the manifold.
201-
200+
202201
Output:
203202
val: a torch Tensor with the kernel values.
204203
"""
205-
lengthscales = [(g[1]-g[0])**2 for g in self.grid]
204+
lengthscales = [(g[1] - g[0])**2 for g in self.grid]
206205

207206
dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
208207
mesh = torch.meshgrid(*self.grid, indexing='ij')
209208
for mesh_dim, dim in zip(mesh, range(len(self.grid))):
210-
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2/lengthscales[dim]
209+
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2 / lengthscales[dim]
211210

212211
return torch.exp(-dist2)
213212

@@ -216,7 +215,7 @@ def _grid_point(self, p):
216215
217216
Input:
218217
p: a torch Tensor corresponding to a latent point.
219-
218+
220219
Output:
221220
idx: an integer correponding to the node index of
222221
the nearest point on the grid.
@@ -230,42 +229,42 @@ def shortest_path(self, p1, p2):
230229
p1: a torch Tensor corresponding to one latent point.
231230
232231
p2: a torch Tensor corresponding to another latent point.
233-
232+
234233
Outputs:
235234
curve: a DiscreteCurve forming the shortest path from p1 to p2.
236235
237236
dist: a scalar indicating the length of the shortest curve.
238237
"""
239238
idx1 = self._grid_point(p1)
240239
idx2 = self._grid_point(p2)
241-
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
242-
#coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
240+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
241+
# coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
243242
mesh = torch.meshgrid(*self.grid, indexing='ij')
244243
raw_coordinates = [m.flatten()[path].view(1, -1) for m in mesh]
245-
coordinates = torch.cat(raw_coordinates, dim=0) # (dim)xN
244+
coordinates = torch.cat(raw_coordinates, dim=0) # (dim)xN
246245
N = len(path)
247246
curve = DiscreteCurve(begin=coordinates[:, 0], end=coordinates[:, -1], num_nodes=N)
248247
with torch.no_grad():
249248
curve.parameters[:, :] = coordinates[:, 1:-1].t()
250249
dist = 0
251-
for i in range(N-1):
252-
dist += self.G.edges[path[i], path[i+1]]['weight']
250+
for i in range(N - 1):
251+
dist += self.G.edges[path[i], path[i + 1]]['weight']
253252
return curve, dist
254253

255-
def connecting_geodesic(self, p1, p2, curve=None):
254+
def connecting_geodesic(self, p1, p2, curve=None):
256255
"""Compute the shortest path on the discretized manifold and fit
257256
a smooth curve to the resulting discrete curve.
258257
259258
Inputs:
260259
p1: a torch Tensor corresponding to one latent point.
261260
262261
p2: a torch Tensor corresponding to another latent point.
263-
262+
264263
Optional input:
265264
curve: a curve that should be fitted to the discrete graph
266265
geodesic. By default this is None and a CubicSpline
267266
with default paramaters will be constructed.
268-
267+
269268
Outputs:
270269
curve: a smooth curve forming the shortest path from p1 to p2.
271270
By default the curve is a CubicSpline with its default
@@ -275,13 +274,13 @@ def connecting_geodesic(self, p1, p2, curve=None):
275274
device = p1.device
276275
idx1 = self._grid_point(p1)
277276
idx2 = self._grid_point(p2)
278-
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
279-
weights = [self.G.edges[path[k], path[k+1]]['weight'] for k in range(len(path)-1)]
277+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
278+
weights = [self.G.edges[path[k], path[k + 1]]['weight'] for k in range(len(path) - 1)]
280279
mesh = torch.meshgrid(*self.grid, indexing='ij')
281280
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
282-
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
281+
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
283282
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
284-
283+
285284
if curve is None:
286285
curve = CubicSpline(p1, p2)
287286
else:

0 commit comments

Comments
 (0)