Commit e39a734

CtrlAltDeliriousPseudomanifold authored and committed
Animate the training loop for the alpha complex example
Creates an animation using matplotlib to visualise how the point cloud forms the required topology over the course of the training loop.
1 parent 9b22d9e commit e39a734

File tree

1 file changed: +60 −0 lines
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
"""Example demonstrating the computation of alpha complexes.

This simple example demonstrates how to use alpha complexes to change
the appearance of a point cloud, following the `TopologyLayer
<https://github.com/bruel-gabrielsson/TopologyLayer>`_ package.

This example is still a **work in progress**.
"""

from torch_topological.nn import AlphaComplex
from torch_topological.nn import SummaryStatisticLoss

from torch_topological.utils import SelectByDimension

import numpy as np
import matplotlib.pyplot as plt

import torch

if __name__ == '__main__':
    np.random.seed(42)
    data = np.random.rand(100, 2)

    alpha_complex = AlphaComplex()

    loss_fn = SummaryStatisticLoss(
        summary_statistic='polynomial_function',
        p=2,
        q=0
    )

    X = torch.nn.Parameter(torch.as_tensor(data), requires_grad=True)
    opt = torch.optim.Adam([X], lr=1e-2)

    plt.ion()
    fig, ax = plt.subplots()

    for i in range(100):
        # We are only interested in working with persistence diagrams of
        # dimension 1.
        selector = SelectByDimension(1)

        # Let's think step by step; apparently, AIs like that! So let's
        # first get the persistence information of our complex. We pass
        # it through the selector to remove diagrams we do not need.
        pers_info = alpha_complex(X)
        pers_info = selector(pers_info)

        # Evaluate the loss; notice that we want to *maximise* it in
        # order to improve the holes in the data.
        loss = -loss_fn(pers_info)

        opt.zero_grad()
        loss.backward()
        opt.step()

        data = X.detach().numpy()
        ax.clear()
        ax.scatter(data[:, 0], data[:, 1])
        plt.pause(0.2)
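A possible extension, not part of this commit: instead of showing the figure interactively, the same loop can be recorded straight to a GIF with matplotlib's PillowWriter. The sketch below is illustrative only; the output filename alpha_complex_training.gif, the frame rate, and the dpi value are arbitrary choices, and everything else mirrors the example above.

from matplotlib.animation import PillowWriter

from torch_topological.nn import AlphaComplex
from torch_topological.nn import SummaryStatisticLoss
from torch_topological.utils import SelectByDimension

import matplotlib.pyplot as plt
import numpy as np
import torch

if __name__ == '__main__':
    np.random.seed(42)

    X = torch.nn.Parameter(
        torch.as_tensor(np.random.rand(100, 2)), requires_grad=True
    )
    opt = torch.optim.Adam([X], lr=1e-2)

    alpha_complex = AlphaComplex()
    selector = SelectByDimension(1)
    loss_fn = SummaryStatisticLoss(
        summary_statistic='polynomial_function', p=2, q=0
    )

    fig, ax = plt.subplots()
    writer = PillowWriter(fps=5)

    # `saving()` opens the output file; `grab_frame()` appends the current
    # state of the figure after every optimisation step.
    with writer.saving(fig, 'alpha_complex_training.gif', dpi=100):
        for i in range(100):
            pers_info = selector(alpha_complex(X))
            loss = -loss_fn(pers_info)

            opt.zero_grad()
            loss.backward()
            opt.step()

            data = X.detach().numpy()
            ax.clear()
            ax.scatter(data[:, 0], data[:, 1])
            writer.grab_frame()

Recording frames this way also works in headless setups, where plt.ion() and plt.pause() have no visible effect.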
