@@ -18,7 +18,6 @@ def __init__(self):
18
18
self ._diagonal_metric = False
19
19
self ._alpha = torch .Tensor ()
20
20
21
-
22
21
def fit (self , model , grid , use_diagonals = True , batch_size = 4 , interpolation_noise = 0.0 ):
23
22
"""
24
23
Discretize a manifold to a given grid.
@@ -30,18 +29,18 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
30
29
the manifold will be discretized. For example,
31
30
grid = [torch.linspace(-3, 3, 50), torch.linspace(-3, 3, 50)]
32
31
will discretize a two-dimensional manifold on a 50x50 grid.
33
-
32
+
34
33
use_diagonals:
35
34
If True, diagonal edges are included in the graph, otherwise
36
35
they are excluded.
37
36
Default: True.
38
-
37
+
39
38
batch_size: Number of edge-lengths that are computed in parallel. The larger
40
39
value you pick here, the faster the discretization will be.
41
40
However, memory usage increases with this number, so a good
42
41
choice is model and hardware specific.
43
42
Default: 4.
44
-
43
+
45
44
interpolation_noise:
46
45
On fitting, the manifold metric is evalated on the provided grid.
47
46
The `metric` function then performs interpolation of this metric,
@@ -60,53 +59,53 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
60
59
61
60
# Add nodes to graph
62
61
xsize , ysize = len (grid [0 ]), len (grid [1 ])
63
- node_idx = lambda x , y : x * ysize + y
64
- self .G .add_nodes_from (range (xsize * ysize ))
62
+ node_idx = lambda x , y : x * ysize + y
63
+ self .G .add_nodes_from (range (xsize * ysize ))
65
64
66
65
point_set = torch .cartesian_prod (
67
- torch .linspace (0 , xsize - 1 , xsize , dtype = torch .long ),
68
- torch .linspace (0 , ysize - 1 , ysize , dtype = torch .long )
66
+ torch .linspace (0 , xsize - 1 , xsize , dtype = torch .long ),
67
+ torch .linspace (0 , ysize - 1 , ysize , dtype = torch .long )
69
68
) # (big)x2
70
69
71
- point_sets = [ ] # these will be [N, 2] matrices of index points
72
- neighbour_funcs = [ ] # these will be functions for getting the neighbour index
70
+ point_sets = [] # these will be [N, 2] matrices of index points
71
+ neighbour_funcs = [] # these will be functions for getting the neighbour index
73
72
74
73
# add sets
75
74
point_sets .append (point_set [point_set [:, 0 ] > 0 ]) # x > 0
76
- neighbour_funcs .append ([lambda x : x - 1 , lambda y : y ])
75
+ neighbour_funcs .append ([lambda x : x - 1 , lambda y : y ])
77
76
78
77
point_sets .append (point_set [point_set [:, 1 ] > 0 ]) # y > 0
79
- neighbour_funcs .append ([lambda x : x , lambda y : y - 1 ])
78
+ neighbour_funcs .append ([lambda x : x , lambda y : y - 1 ])
80
79
81
- point_sets .append (point_set [point_set [:, 0 ] < xsize - 1 ]) # x < xsize-1
82
- neighbour_funcs .append ([lambda x : x + 1 , lambda y : y ])
80
+ point_sets .append (point_set [point_set [:, 0 ] < xsize - 1 ]) # x < xsize-1
81
+ neighbour_funcs .append ([lambda x : x + 1 , lambda y : y ])
82
+
83
+ point_sets .append (point_set [point_set [:, 1 ] < ysize - 1 ]) # y < ysize-1
84
+ neighbour_funcs .append ([lambda x : x , lambda y : y + 1 ])
83
85
84
- point_sets .append (point_set [point_set [:, 1 ] < ysize - 1 ]) # y < ysize-1
85
- neighbour_funcs .append ([lambda x : x , lambda y : y + 1 ])
86
-
87
86
if use_diagonals :
88
- point_sets .append (point_set [torch .logical_and (point_set [:,0 ] > 0 , point_set [:,1 ] > 0 )])
89
- neighbour_funcs .append ([lambda x : x - 1 , lambda y : y - 1 ])
87
+ point_sets .append (point_set [torch .logical_and (point_set [:, 0 ] > 0 , point_set [:, 1 ] > 0 )])
88
+ neighbour_funcs .append ([lambda x : x - 1 , lambda y : y - 1 ])
89
+
90
+ point_sets .append (point_set [torch .logical_and (point_set [:, 0 ] < xsize - 1 , point_set [:, 1 ] > 0 )])
91
+ neighbour_funcs .append ([lambda x : x + 1 , lambda y : y - 1 ])
90
92
91
- point_sets .append (point_set [torch .logical_and (point_set [:,0 ] < xsize - 1 , point_set [:,1 ] > 0 )])
92
- neighbour_funcs .append ([lambda x : x + 1 , lambda y : y - 1 ])
93
-
94
93
t = torch .linspace (0 , 1 , 2 )
95
94
for ps , nf in zip (point_sets , neighbour_funcs ):
96
- for i in range (ceil (ps .shape [0 ] / batch_size )):
97
- x = ps [batch_size * i :batch_size * ( i + 1 ), 0 ]
98
- y = ps [batch_size * i :batch_size * ( i + 1 ), 1 ]
95
+ for i in range (ceil (ps .shape [0 ] / batch_size )):
96
+ x = ps [batch_size * i :batch_size * ( i + 1 ), 0 ]
97
+ y = ps [batch_size * i :batch_size * ( i + 1 ), 1 ]
99
98
xn = nf [0 ](x ); yn = nf [1 ](y )
100
-
99
+
101
100
bs = x .shape [0 ] # may be different from batch size for the last batch
102
101
103
102
line = CubicSpline (begin = torch .zeros (bs , dim ), end = torch .ones (bs , dim ), num_nodes = 2 )
104
103
line .begin = torch .cat ([grid [0 ][x ].view (- 1 , 1 ), grid [1 ][y ].view (- 1 , 1 )], dim = 1 ) # (bs)x2
105
104
line .end = torch .cat ([grid [0 ][xn ].view (- 1 , 1 ), grid [1 ][yn ].view (- 1 , 1 )], dim = 1 ) # (bs)x2
106
105
107
- #if external_curve_length_function:
108
- # weight = external_curve_length_function(model, line(t))
109
- #else:
106
+ # if external_curve_length_function:
107
+ # weight = external_curve_length_function(model, line(t))
108
+ # else:
110
109
with torch .no_grad ():
111
110
weight = model .curve_length (line (t ))
112
111
@@ -123,23 +122,23 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
123
122
for x in range (xsize ):
124
123
for y in range (ysize ):
125
124
p = torch .tensor ([self .grid [0 ][x ], self .grid [1 ][y ]])
126
- Mlist .append (model .metric (p )) # 1x(d)x(d) or 1x(d)
127
- M = torch .cat (Mlist , dim = 0 ) # (big)x(d)x(d) or (big)x(d)
125
+ Mlist .append (model .metric (p )) # 1x(d)x(d) or 1x(d)
126
+ M = torch .cat (Mlist , dim = 0 ) # (big)x(d)x(d) or (big)x(d)
128
127
self ._diagonal_metric = M .dim () == 2
129
128
d = M .shape [- 1 ]
130
129
if self ._diagonal_metric :
131
- self .__metric__ = M .view ([* self .grid_size , d ]) # e.g. (xsize)x(ysize)x(d)
130
+ self .__metric__ = M .view ([* self .grid_size , d ]) # e.g. (xsize)x(ysize)x(d)
132
131
else :
133
- self .__metric__ = M .view ([* self .grid_size , d , d ]) # e.g. (xsize)x(ysize)x(d)x(d)
134
-
132
+ self .__metric__ = M .view ([* self .grid_size , d , d ]) # e.g. (xsize)x(ysize)x(d)x(d)
133
+
135
134
# Compute interpolation weights. We use the mean function of a GP regressor.
136
135
mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
137
- grid_points = torch .cat ([m .unsqueeze (- 1 ) for m in mesh ], dim = - 1 ) # e.g. 100x100x2 a 2D grid with 100 points in each dim
138
- K = self ._kernel (grid_points .view (- 1 , len (self .grid ))) # (num_grid)x(num_grid)
136
+ grid_points = torch .cat ([m .unsqueeze (- 1 ) for m in mesh ], dim = - 1 ) # e.g. 100x100x2 a 2D grid with 100 points in each dim
137
+ K = self ._kernel (grid_points .view (- 1 , len (self .grid ))) # (num_grid)x(num_grid)
139
138
if interpolation_noise > 0.0 :
140
139
K += interpolation_noise * torch .eye (K .shape [0 ])
141
140
num_grid = K .shape [0 ]
142
- self ._alpha = torch .linalg .solve (K , self .__metric__ .view (num_grid , - 1 )) # (num_grid)x(d²) or (num_grid)x(d)
141
+ self ._alpha = torch .linalg .solve (K , self .__metric__ .view (num_grid , - 1 )) # (num_grid)x(d²) or (num_grid)x(d)
143
142
except :
144
143
import warnings
145
144
warnings .warn ("It appears that your model does not implement a metric." )
@@ -169,10 +168,10 @@ def metric(self, points):
169
168
a (d)x(d) diagonal matrix.
170
169
"""
171
170
# XXX: We should also support returning the derivative of the metric! (for ODEs; see local_PCA)
172
- K = self ._kernel (points ) # Nx(num_grid)
173
- M = K .mm (self ._alpha ) # Nx(d²) or Nx(d)
171
+ K = self ._kernel (points ) # Nx(num_grid)
172
+ M = K .mm (self ._alpha ) # Nx(d²) or Nx(d)
174
173
if not self ._diagonal_metric :
175
- d = len (grid )
174
+ d = len (self . grid )
176
175
M = M .view (- 1 , d , d )
177
176
return M
178
177
@@ -198,16 +197,16 @@ def _kernel(self, p):
198
197
199
198
Input:
200
199
p: a torch Tensor corresponding to a point on the manifold.
201
-
200
+
202
201
Output:
203
202
val: a torch Tensor with the kernel values.
204
203
"""
205
- lengthscales = [(g [1 ]- g [0 ])** 2 for g in self .grid ]
204
+ lengthscales = [(g [1 ] - g [0 ])** 2 for g in self .grid ]
206
205
207
206
dist2 = torch .zeros (p .shape [0 ], self .G .number_of_nodes ())
208
207
mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
209
208
for mesh_dim , dim in zip (mesh , range (len (self .grid ))):
210
- dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 ))** 2 / lengthscales [dim ]
209
+ dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 ))** 2 / lengthscales [dim ]
211
210
212
211
return torch .exp (- dist2 )
213
212
@@ -216,7 +215,7 @@ def _grid_point(self, p):
216
215
217
216
Input:
218
217
p: a torch Tensor corresponding to a latent point.
219
-
218
+
220
219
Output:
221
220
idx: an integer correponding to the node index of
222
221
the nearest point on the grid.
@@ -230,42 +229,42 @@ def shortest_path(self, p1, p2):
230
229
p1: a torch Tensor corresponding to one latent point.
231
230
232
231
p2: a torch Tensor corresponding to another latent point.
233
-
232
+
234
233
Outputs:
235
234
curve: a DiscreteCurve forming the shortest path from p1 to p2.
236
235
237
236
dist: a scalar indicating the length of the shortest curve.
238
237
"""
239
238
idx1 = self ._grid_point (p1 )
240
239
idx2 = self ._grid_point (p2 )
241
- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
242
- #coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
240
+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
241
+ # coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
243
242
mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
244
243
raw_coordinates = [m .flatten ()[path ].view (1 , - 1 ) for m in mesh ]
245
- coordinates = torch .cat (raw_coordinates , dim = 0 ) # (dim)xN
244
+ coordinates = torch .cat (raw_coordinates , dim = 0 ) # (dim)xN
246
245
N = len (path )
247
246
curve = DiscreteCurve (begin = coordinates [:, 0 ], end = coordinates [:, - 1 ], num_nodes = N )
248
247
with torch .no_grad ():
249
248
curve .parameters [:, :] = coordinates [:, 1 :- 1 ].t ()
250
249
dist = 0
251
- for i in range (N - 1 ):
252
- dist += self .G .edges [path [i ], path [i + 1 ]]['weight' ]
250
+ for i in range (N - 1 ):
251
+ dist += self .G .edges [path [i ], path [i + 1 ]]['weight' ]
253
252
return curve , dist
254
253
255
- def connecting_geodesic (self , p1 , p2 , curve = None ):
254
+ def connecting_geodesic (self , p1 , p2 , curve = None ):
256
255
"""Compute the shortest path on the discretized manifold and fit
257
256
a smooth curve to the resulting discrete curve.
258
257
259
258
Inputs:
260
259
p1: a torch Tensor corresponding to one latent point.
261
260
262
261
p2: a torch Tensor corresponding to another latent point.
263
-
262
+
264
263
Optional input:
265
264
curve: a curve that should be fitted to the discrete graph
266
265
geodesic. By default this is None and a CubicSpline
267
266
with default paramaters will be constructed.
268
-
267
+
269
268
Outputs:
270
269
curve: a smooth curve forming the shortest path from p1 to p2.
271
270
By default the curve is a CubicSpline with its default
@@ -275,13 +274,13 @@ def connecting_geodesic(self, p1, p2, curve=None):
275
274
device = p1 .device
276
275
idx1 = self ._grid_point (p1 )
277
276
idx2 = self ._grid_point (p2 )
278
- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
279
- weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path )- 1 )]
277
+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
278
+ weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
280
279
mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
281
280
raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
282
- coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
281
+ coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
283
282
t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
284
-
283
+
285
284
if curve is None :
286
285
curve = CubicSpline (p1 , p2 )
287
286
else :
0 commit comments