Skip to content

Commit ff8ec68

Browse files
committed
Use batching
1 parent dd66696 commit ff8ec68

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

examples/local_pca_mnist.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,10 @@ def get_subset_mnist(n: int = 1000):
6161
DMim = DM.metric(XY2).log().sum(dim=1).view(133, 133).t()
6262
plt.imshow(DMim, extent=(ran[0], ran[-1], ran[0], ran[-1]), origin="lower")
6363
plt.plot(data[:, 0], data[:, 1], "w.", markersize=1)
64-
for k in range(10):
65-
p0 = data[torch.randint(high=N, size=[1], dtype=torch.long)] # 1xD
66-
p1 = data[torch.randint(high=N, size=[1], dtype=torch.long)] # 1xD
67-
C, _ = DM.connecting_geodesic(p0, p1)
68-
C.plot()
64+
p0 = data[torch.randint(high=N, size=[10], dtype=torch.long)] # 10xD
65+
p1 = data[torch.randint(high=N, size=[10], dtype=torch.long)] # 10xD
66+
C, success = DM.connecting_geodesic(p0, p1)
67+
C.plot()
6968

7069
# p = C.begin
7170
# with torch.no_grad():

0 commit comments

Comments
 (0)