Skip to content

Commit dd66696

Browse files
committed
Support batching in a simple way
1 parent e6e4f62 commit dd66696

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

stochman/discretized_manifold.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,21 +276,30 @@ def connecting_geodesic(self, p1, p2, curve=None):
276276
curve input.
277277
"""
278278
device = p1.device
279-
idx1 = self._grid_point(p1)
280-
idx2 = self._grid_point(p2)
281-
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
282-
weights = [self.G.edges[path[k], path[k + 1]]['weight'] for k in range(len(path) - 1)]
283-
mesh = torch.meshgrid(*self.grid, indexing='ij')
284-
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
285-
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
286-
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
279+
if p1.ndim == 1:
280+
p1 = p1.unsqueeze(0) # 1xD
281+
if p2.ndim == 1:
282+
p2 = p2.unsqueeze(0) # 1xD
283+
B = p1.shape[0]
284+
if p1.shape != p2.shape:
285+
raise NameError('shape mismatch')
287286

288287
if curve is None:
289288
curve = CubicSpline(p1, p2)
290289
else:
291290
curve.begin = p1
292291
curve.end = p2
293292

294-
curve.fit(t, coordinates)
293+
for b in range(B):
294+
idx1 = self._grid_point(p1[b].unsqueeze(0))
295+
idx2 = self._grid_point(p2[b].unsqueeze(0))
296+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
297+
weights = [self.G.edges[path[k], path[k + 1]]['weight'] for k in range(len(path) - 1)]
298+
mesh = torch.meshgrid(*self.grid, indexing='ij')
299+
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
300+
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
301+
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
302+
303+
curve[b].fit(t, coordinates)
295304

296305
return curve, True

0 commit comments

Comments
 (0)