@@ -276,21 +276,30 @@ def connecting_geodesic(self, p1, p2, curve=None):
276
276
curve input.
277
277
"""
278
278
device = p1 .device
279
- idx1 = self ._grid_point (p1 )
280
- idx2 = self ._grid_point (p2 )
281
- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
282
- weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
283
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
284
- raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
285
- coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
286
- t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
279
+ if p1 .ndim == 1 :
280
+ p1 = p1 .unsqueeze (0 ) # 1xD
281
+ if p2 .ndim == 1 :
282
+ p2 = p2 .unsqueeze (0 ) # 1xD
283
+ B = p1 .shape [0 ]
284
+ if p1 .shape != p2 .shape :
285
+ raise NameError ('shape mismatch' )
287
286
288
287
if curve is None :
289
288
curve = CubicSpline (p1 , p2 )
290
289
else :
291
290
curve .begin = p1
292
291
curve .end = p2
293
292
294
- curve .fit (t , coordinates )
293
+ for b in range (B ):
294
+ idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
295
+ idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
296
+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
297
+ weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
298
+ mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
299
+ raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
300
+ coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
301
+ t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
302
+
303
+ curve [b ].fit (t , coordinates )
295
304
296
305
return curve , True
0 commit comments