9
9
10
10
def get_subset_mnist (n : int = 1000 ):
11
11
dataset = MNIST (root = "" , download = True )
12
- N = dataset .data .shape [0 ]
12
+ data = dataset .data [dataset .targets == 1 ]
13
+ N = data .shape [0 ]
13
14
idx = np .random .choice (np .arange (N ), size = n )
14
- return dataset . data [ idx ], dataset . targets [idx ]
15
+ return data [idx ]
15
16
16
17
17
18
# Read data
18
- data , targets = get_subset_mnist (n = 1000 )
19
- data = data .reshape (data .shape [0 ], - 1 )
19
+ data = get_subset_mnist (n = 1000 )
20
+ data = data .reshape (data .shape [0 ], - 1 ).to (torch .float )
21
+ cov = torch .cov (data .t ())
22
+ values , vectors = torch .linalg .eigh (cov )
23
+ proj = vectors [:, - 2 :] / values [- 2 :].sqrt ().unsqueeze (0 )
24
+ data = data @ proj
20
25
N , D = data .shape
21
26
22
27
# Parameters for metric
@@ -27,13 +32,14 @@ def get_subset_mnist(n: int = 1000):
27
32
M = stochman .manifold .LocalVarMetric (data = data , sigma = sigma , rho = rho )
28
33
29
34
# Plot metric and data
30
- ran = torch .linspace (- 2.5 , 2.5 , 100 )
31
- X , Y = torch .meshgrid ([ran , ran ])
35
+ plt .figure ()
36
+ ran = torch .linspace (- 3.0 , 3.0 , 100 )
37
+ X , Y = torch .meshgrid ([ran , ran ], indexing = 'ij' )
32
38
XY = torch .stack ((X .flatten (), Y .flatten ()), dim = 1 ) # 10000x2
33
39
gridM = M .metric (XY ) # 10000x2
34
- Mim = gridM .sum (dim = 1 ).reshape ((100 , 100 )).detach ().numpy (). T
40
+ Mim = gridM .sum (dim = 1 ).reshape ((100 , 100 )).detach ().t ()
35
41
plt .imshow (Mim , extent = (ran [0 ], ran [- 1 ], ran [0 ], ran [- 1 ]), origin = "lower" )
36
- plt .plot (data [:, 0 ]. numpy () , data [:, 1 ]. numpy () , "w." , markersize = 1 )
42
+ plt .plot (data [:, 0 ], data [:, 1 ], "w." , markersize = 1 )
37
43
38
44
# Compute geodesics in parallel
39
45
p0 = data [torch .randint (high = N , size = [10 ], dtype = torch .long )] # 10xD
@@ -42,16 +48,30 @@ def get_subset_mnist(n: int = 1000):
42
48
C .plot ()
43
49
C .constant_speed (M )
44
50
C .plot ()
45
- plt .show ()
46
51
47
- # Compute shooting geodesic as a sanity check
48
- p0 = data [0 ] # 1xD
49
- p1 = data [1 ] # 1xD
50
- C , success = M .connecting_geodesic (p0 , p1 )
51
- C .plot ()
52
+ # Construct discretized manifold
53
+ DM = stochman .discretized_manifold .DiscretizedManifold ()
54
+ DM .fit (M , [ran , ran ], batch_size = 100 )
55
+
56
+ # Compute discretized geodesics
57
+ plt .figure ()
58
+ ran2 = torch .linspace (- 3.0 , 3.0 , 133 )
59
+ X2 , Y2 = torch .meshgrid ([ran2 , ran2 ], indexing = 'ij' )
60
+ XY2 = torch .stack ((X2 .flatten (), Y2 .flatten ()), dim = 1 ) # 10000x2
61
+ DMim = DM .metric (XY2 ).log ().sum (dim = 1 ).view (133 , 133 ).t ()
62
+ plt .imshow (DMim , extent = (ran [0 ], ran [- 1 ], ran [0 ], ran [- 1 ]), origin = "lower" )
63
+ plt .plot (data [:, 0 ], data [:, 1 ], "w." , markersize = 1 )
64
+ for k in range (10 ):
65
+ p0 = data [torch .randint (high = N , size = [1 ], dtype = torch .long )] # 1xD
66
+ p1 = data [torch .randint (high = N , size = [1 ], dtype = torch .long )] # 1xD
67
+ C = DM .connecting_geodesic (p0 , p1 )
68
+ C .plot ()
52
69
53
70
# p = C.begin
54
71
# with torch.no_grad():
55
72
# v = C.deriv(torch.zeros(1))
56
73
# c, dc = shooting_geodesic(M, p, v, t=torch.linspace(0, 1, 100))
57
74
# plt.plot(c[:,0,0], c[:,1, 0], 'o')
75
+
76
+
77
+ plt .show ()
0 commit comments