Skip to content

Commit c6e1c9c

Browse files
committed
Update example to use the right data and to also show-case the discretized manifold
1 parent d3b4d5a commit c6e1c9c

File tree

1 file changed

+34
-14
lines changed

1 file changed

+34
-14
lines changed

examples/local_pca_mnist.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,19 @@
99

1010
def get_subset_mnist(n: int = 1000):
1111
dataset = MNIST(root="", download=True)
12-
N = dataset.data.shape[0]
12+
data = dataset.data[dataset.targets==1]
13+
N = data.shape[0]
1314
idx = np.random.choice(np.arange(N), size=n)
14-
return dataset.data[idx], dataset.targets[idx]
15+
return data[idx]
1516

1617

1718
# Read data
18-
data, targets = get_subset_mnist(n=1000)
19-
data = data.reshape(data.shape[0], -1)
19+
data = get_subset_mnist(n=1000)
20+
data = data.reshape(data.shape[0], -1).to(torch.float)
21+
cov = torch.cov(data.t())
22+
values, vectors = torch.linalg.eigh(cov)
23+
proj = vectors[:, -2:] / values[-2:].sqrt().unsqueeze(0)
24+
data = data @ proj
2025
N, D = data.shape
2126

2227
# Parameters for metric
@@ -27,13 +32,14 @@ def get_subset_mnist(n: int = 1000):
2732
M = stochman.manifold.LocalVarMetric(data=data, sigma=sigma, rho=rho)
2833

2934
# Plot metric and data
30-
ran = torch.linspace(-2.5, 2.5, 100)
31-
X, Y = torch.meshgrid([ran, ran])
35+
plt.figure()
36+
ran = torch.linspace(-3.0, 3.0, 100)
37+
X, Y = torch.meshgrid([ran, ran], indexing='ij')
3238
XY = torch.stack((X.flatten(), Y.flatten()), dim=1) # 10000x2
3339
gridM = M.metric(XY) # 10000x2
34-
Mim = gridM.sum(dim=1).reshape((100, 100)).detach().numpy().T
40+
Mim = gridM.sum(dim=1).reshape((100, 100)).detach().t()
3541
plt.imshow(Mim, extent=(ran[0], ran[-1], ran[0], ran[-1]), origin="lower")
36-
plt.plot(data[:, 0].numpy(), data[:, 1].numpy(), "w.", markersize=1)
42+
plt.plot(data[:, 0], data[:, 1], "w.", markersize=1)
3743

3844
# Compute geodesics in parallel
3945
p0 = data[torch.randint(high=N, size=[10], dtype=torch.long)] # 10xD
@@ -42,16 +48,30 @@ def get_subset_mnist(n: int = 1000):
4248
C.plot()
4349
C.constant_speed(M)
4450
C.plot()
45-
plt.show()
4651

47-
# Compute shooting geodesic as a sanity check
48-
p0 = data[0] # 1xD
49-
p1 = data[1] # 1xD
50-
C, success = M.connecting_geodesic(p0, p1)
51-
C.plot()
52+
# Construct discretized manifold
53+
DM = stochman.discretized_manifold.DiscretizedManifold()
54+
DM.fit(M, [ran, ran], batch_size=100)
55+
56+
# Compute discretized geodesics
57+
plt.figure()
58+
ran2 = torch.linspace(-3.0, 3.0, 133)
59+
X2, Y2 = torch.meshgrid([ran2, ran2], indexing='ij')
60+
XY2 = torch.stack((X2.flatten(), Y2.flatten()), dim=1) # 10000x2
61+
DMim = DM.metric(XY2).log().sum(dim=1).view(133, 133).t()
62+
plt.imshow(DMim, extent=(ran[0], ran[-1], ran[0], ran[-1]), origin="lower")
63+
plt.plot(data[:, 0], data[:, 1], "w.", markersize=1)
64+
for k in range(10):
65+
p0 = data[torch.randint(high=N, size=[1], dtype=torch.long)] # 1xD
66+
p1 = data[torch.randint(high=N, size=[1], dtype=torch.long)] # 1xD
67+
C = DM.connecting_geodesic(p0, p1)
68+
C.plot()
5269

5370
# p = C.begin
5471
# with torch.no_grad():
5572
# v = C.deriv(torch.zeros(1))
5673
# c, dc = shooting_geodesic(M, p, v, t=torch.linspace(0, 1, 100))
5774
# plt.plot(c[:,0,0], c[:,1, 0], 'o')
75+
76+
77+
plt.show()

0 commit comments

Comments
 (0)