Skip to content

Commit 05f4256

Browse files
authored
Update README.md
Add examples
1 parent 125ec22 commit 05f4256

File tree

1 file changed

+52
-8
lines changed

1 file changed

+52
-8
lines changed

README.md

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,68 @@ print(J.shape) # jacobian between input and output: torch.size([100, 5, 10])
4646

4747
### `stochman.manifold`: Interface for working with Riemannian manifolds
4848

49+
A manifold can be constructed simply by specifying its metric. The example below shows a toy example where the metric grows with the distance to the origin.
50+
51+
``` python
52+
import torch
53+
from stochman.manifold import Manifold
54+
55+
class MyManifold(Manifold):
56+
def metric(self, c, return_deriv=False):
57+
N, D = c.shape # N is number of points where we evaluate the metric; D is the manifold dimension
58+
sq_dist_to_origin = torch.sum(c**2, dim=1, keepdim=True) # Nx1
59+
G = (1 + sq_dist_to_origin).unsqueeze(-1) * torch.eye(D).repeat(N, 1, 1) # NxDxD
60+
return G
61+
62+
model = MyManifold()
63+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
64+
c, _ = model.connecting_geodesic(p0, p1) # geodesic between two random points
65+
```
66+
67+
If you manifold is embedded (e.g. an autoencoder) then you only have to provide a function for realizing the embedding (i.e. a decoder) and StochMan takes care of the rest (you, however, have to learn the autoencoder yourself).
68+
69+
``` python
70+
import torch
71+
from stochman.manifold import EmbeddedManifold
72+
73+
class Autoencoder(EmbeddedManifold):
74+
def embed(self, c, jacobian = False):
75+
return self.decode(c)
76+
77+
model = Autoencoder()
78+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
79+
c, _ = model.connecting_geodesic(p0, p1) # geodesic between two random points
80+
```
4981

5082
### `stochman.geodesic`: computing geodesics made easy!
5183

84+
Geodesics are energy-minimizing curves, and StochMan computes them as such. You can use the high-level `Manifold` interface or the more explicit one:
5285

53-
### `stochman.curves`: Simple curve objects
86+
``` python
87+
import torch
88+
from stochman.geodesic import geodesic_minimizing_energy
89+
from stochman.curves import CubicSpline
5490

91+
model = MyManifold()
92+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
93+
curve = CubicSpline(p0, p1)
94+
geodesic_minimizing_energy(curve, model)
95+
```
5596

97+
### `stochman.curves`: Simple curve objects
5698

99+
We often want to manipulate curves when computing geodesics. StochMan provides an implementation of cubic splines and discrete curves, both with the end-points fixed.
57100

101+
``` python
102+
import torch
103+
from stochman.curves import CubicSpline
58104

105+
p0, p1 = torch.randn(1, 2), torch.randn(1, 2)
106+
curve = CubicSpline(p0, p1)
59107

108+
t = torch.linspace(0, 1, 50)
109+
ct = curve(t) # 50x2
110+
```
60111

61112
## Licence
62113

@@ -73,10 +124,3 @@ If you want to cite the framework feel free to use this (but only if you loved i
73124
year={2021}
74125
}
75126
```
76-
77-
78-
79-
80-
81-
82-

0 commit comments

Comments
 (0)