@@ -285,20 +285,21 @@ def connecting_geodesic(self, p1, p2, curve=None):
285
285
raise NameError ('shape mismatch' )
286
286
287
287
if curve is None :
288
- curve = CubicSpline (p1 . detach () , p2 . detach () )
288
+ curve = CubicSpline (p1 , p2 )
289
289
else :
290
290
curve .begin = p1
291
291
curve .end = p2
292
292
293
293
for b in range (B ):
294
- idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
295
- idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
296
- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
297
- weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
298
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
299
- raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
300
- coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
301
- t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
294
+ with torch .no_grad ():
295
+ idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
296
+ idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
297
+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
298
+ weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
299
+ mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
300
+ raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
301
+ coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
302
+ t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
302
303
303
304
curve [b ].fit (t , coordinates )
304
305
0 commit comments