You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+52-8Lines changed: 52 additions & 8 deletions
Original file line number
Diff line number
Diff line change
@@ -46,17 +46,68 @@ print(J.shape) # jacobian between input and output: torch.size([100, 5, 10])
46
46
47
47
### `stochman.manifold`: Interface for working with Riemannian manifolds
48
48
49
+
A manifold can be constructed simply by specifying its metric. The example below shows a toy example where the metric grows with the distance to the origin.
50
+
51
+
```python
52
+
import torch
53
+
from stochman.manifold import Manifold
54
+
55
+
classMyManifold(Manifold):
56
+
defmetric(self, c, return_deriv=False):
57
+
N, D = c.shape # N is number of points where we evaluate the metric; D is the manifold dimension
G = (1+ sq_dist_to_origin).unsqueeze(-1) * torch.eye(D).repeat(N, 1, 1) # NxDxD
60
+
return G
61
+
62
+
model = MyManifold()
63
+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
64
+
c, _ = model.connecting_geodesic(p0, p1) # geodesic between two random points
65
+
```
66
+
67
+
If you manifold is embedded (e.g. an autoencoder) then you only have to provide a function for realizing the embedding (i.e. a decoder) and StochMan takes care of the rest (you, however, have to learn the autoencoder yourself).
68
+
69
+
```python
70
+
import torch
71
+
from stochman.manifold import EmbeddedManifold
72
+
73
+
classAutoencoder(EmbeddedManifold):
74
+
defembed(self, c, jacobian=False):
75
+
returnself.decode(c)
76
+
77
+
model = Autoencoder()
78
+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
79
+
c, _ = model.connecting_geodesic(p0, p1) # geodesic between two random points
80
+
```
49
81
50
82
### `stochman.geodesic`: computing geodesics made easy!
51
83
84
+
Geodesics are energy-minimizing curves, and StochMan computes them as such. You can use the high-level `Manifold` interface or the more explicit one:
52
85
53
-
### `stochman.curves`: Simple curve objects
86
+
```python
87
+
import torch
88
+
from stochman.geodesic import geodesic_minimizing_energy
89
+
from stochman.curves import CubicSpline
54
90
91
+
model = MyManifold()
92
+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
93
+
curve = CubicSpline(p0, p1)
94
+
geodesic_minimizing_energy(curve, model)
95
+
```
55
96
97
+
### `stochman.curves`: Simple curve objects
56
98
99
+
We often want to manipulate curves when computing geodesics. StochMan provides an implementation of cubic splines and discrete curves, both with the end-points fixed.
57
100
101
+
```python
102
+
import torch
103
+
from stochman.curves import CubicSpline
58
104
105
+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
106
+
curve = CubicSpline(p0, p1)
59
107
108
+
t = torch.linspace(0, 1, 50)
109
+
ct = curve(t) # 50x2
110
+
```
60
111
61
112
## Licence
62
113
@@ -73,10 +124,3 @@ If you want to cite the framework feel free to use this (but only if you loved i
0 commit comments