
Commit 8a5eed3

CtrlAltDeliriousPseudomanifold authored and committed
Animate the persistent homology diagram for alpha complex
Visualises how the birth and death features of the persistent homology diagram evolve over the iterations of the training loop.
1 parent e39a734 commit 8a5eed3

File tree

1 file changed
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
"""Example demonstrating the computation of alpha complexes.

This simple example demonstrates how to use alpha complexes to change
the appearance of a point cloud, following the `TopologyLayer
<https://github.com/bruel-gabrielsson/TopologyLayer>`_ package.

This example is still a **work in progress**.
"""

from torch_topological.nn import AlphaComplex
from torch_topological.nn import SummaryStatisticLoss

from torch_topological.utils import SelectByDimension

import numpy as np
import matplotlib.pyplot as plt

import torch

if __name__ == '__main__':
    np.random.seed(42)
    data = np.random.rand(100, 2)

    alpha_complex = AlphaComplex()

    loss_fn = SummaryStatisticLoss(
        summary_statistic='polynomial_function',
        p=2,
        q=0
    )

    X = torch.nn.Parameter(torch.as_tensor(data), requires_grad=True)
    opt = torch.optim.Adam([X], lr=1e-2)

    plt.ion()
    fig, ax = plt.subplots()

    for i in range(100):
        # We are only interested in working with persistence diagrams of
        # dimension 1.
        selector = SelectByDimension(1)

        # Let's think step by step; apparently, AIs like that! So let's
        # first get the persistence information of our complex. We pass
        # it through the selector to remove diagrams we do not need.
        pers_info = alpha_complex(X)
        pers_info = selector(pers_info)

        data = pers_info[0].diagram
        data = data.detach().cpu().numpy()
        ax.clear()
        plt.xlabel("Birth")
        plt.ylabel("Death")
        ax.plot(data[:, 0], data[:, 1], 'ro', [0, 0.8], [0, 0.8], 'k-')
        plt.pause(0.2)

        # Evaluate the loss; notice that we want to *maximise* it in
        # order to improve the holes in the data.
        loss = -loss_fn(pers_info)

        opt.zero_grad()
        loss.backward()
        opt.step()
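
A note on the quantity being maximised: with summary_statistic='polynomial_function', p=2, and q=0, the loss plausibly reduces to the sum of squared persistence values of the selected dimension-1 features, so the optimiser pushes birth/death pairs away from the diagonal and thereby accentuates holes. Below is a minimal sketch of that statistic, assuming the form "sum of (death - birth)**p * ((birth + death)/2)**q"; the exact expression evaluated by SummaryStatisticLoss may differ, and polynomial_summary is a hypothetical helper, not part of torch_topological.

import torch

def polynomial_summary(diagram, p=2, q=0):
    # `diagram` is an (n, 2) tensor of (birth, death) pairs.
    # Assumed form of the 'polynomial_function' statistic:
    # sum of (death - birth)**p * (0.5 * (birth + death))**q.
    birth, death = diagram[:, 0], diagram[:, 1]
    persistence = death - birth
    midpoint = 0.5 * (birth + death)
    return torch.sum(persistence**p * midpoint**q)

With q=0 the midpoint factor is constant, so negating the loss in the training loop amounts to gradient ascent on total squared persistence.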

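The example displays its animation interactively via plt.ion() and plt.pause(), so nothing is written to disk. A possible companion sketch, assuming each iteration's detached diagram is collected into a list, that renders the sequence as a GIF; save_diagram_animation and the output filename are illustrative additions, not part of the commit.

from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.pyplot as plt

def save_diagram_animation(diagrams, path='alpha_diagrams.gif'):
    # `diagrams` is a list of (n, 2) NumPy arrays of (birth, death) pairs,
    # for instance one array per iteration of the training loop above.
    fig, ax = plt.subplots()

    def draw(frame):
        ax.clear()
        ax.set_xlabel('Birth')
        ax.set_ylabel('Death')
        ax.plot([0, 0.8], [0, 0.8], 'k-')
        ax.plot(diagrams[frame][:, 0], diagrams[frame][:, 1], 'ro')

    anim = FuncAnimation(fig, draw, frames=len(diagrams), interval=200)
    anim.save(path, writer=PillowWriter(fps=5))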