Skip to content

Commit 0054996

Browse files
committed
Do not detach curve end-points as it is not needed
1 parent d25eebc commit 0054996

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

stochman/discretized_manifold.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -285,20 +285,21 @@ def connecting_geodesic(self, p1, p2, curve=None):
285285
raise NameError('shape mismatch')
286286

287287
if curve is None:
288-
curve = CubicSpline(p1.detach(), p2.detach())
288+
curve = CubicSpline(p1, p2)
289289
else:
290290
curve.begin = p1
291291
curve.end = p2
292292

293293
for b in range(B):
294-
idx1 = self._grid_point(p1[b].unsqueeze(0))
295-
idx2 = self._grid_point(p2[b].unsqueeze(0))
296-
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
297-
weights = [self.G.edges[path[k], path[k + 1]]['weight'] for k in range(len(path) - 1)]
298-
mesh = torch.meshgrid(*self.grid, indexing='ij')
299-
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
300-
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
301-
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
294+
with torch.no_grad():
295+
idx1 = self._grid_point(p1[b].unsqueeze(0))
296+
idx2 = self._grid_point(p2[b].unsqueeze(0))
297+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
298+
weights = [self.G.edges[path[k], path[k + 1]]['weight'] for k in range(len(path) - 1)]
299+
mesh = torch.meshgrid(*self.grid, indexing='ij')
300+
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
301+
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
302+
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
302303

303304
curve[b].fit(t, coordinates)
304305

0 commit comments

Comments
 (0)