|
| 1 | +"""Example demonstrating the computation of alpha complexes. |
| 2 | +
|
| 3 | +This simple example demonstrates how to use alpha complexes to change |
| 4 | +the appearance of a point cloud, following the `TopologyLayer |
| 5 | +<https://github.com/bruel-gabrielsson/TopologyLayer>`_ package. |
| 6 | +
|
| 7 | +This example is still a **work in progress**. |
| 8 | +""" |
| 9 | + |
| 10 | +from torch_topological.nn import AlphaComplex |
| 11 | +from torch_topological.nn import SummaryStatisticLoss |
| 12 | + |
| 13 | +from torch_topological.utils import SelectByDimension |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +import matplotlib.pyplot as plt |
| 17 | + |
| 18 | +import torch |
| 19 | + |
| 20 | +if __name__ == '__main__': |
| 21 | + np.random.seed(42) |
| 22 | + data = np.random.rand(100, 2) |
| 23 | + |
| 24 | + alpha_complex = AlphaComplex() |
| 25 | + |
| 26 | + loss_fn = SummaryStatisticLoss( |
| 27 | + summary_statistic='polynomial_function', |
| 28 | + p=2, |
| 29 | + q=0 |
| 30 | + ) |
| 31 | + |
| 32 | + X = torch.nn.Parameter(torch.as_tensor(data), requires_grad=True) |
| 33 | + opt = torch.optim.Adam([X], lr=1e-2) |
| 34 | + |
| 35 | + plt.ion() |
| 36 | + fig, ax = plt.subplots() |
| 37 | + |
| 38 | + for i in range(100): |
| 39 | + # We are only interested in working with persistence diagrams of |
| 40 | + # dimension 1. |
| 41 | + selector = SelectByDimension(1) |
| 42 | + |
| 43 | + # Let's think step by step; apparently, AIs like that! So let's |
| 44 | + # first get the persistence information of our complex. We pass |
| 45 | + # it through the selector to remove diagrams we do not need. |
| 46 | + pers_info = alpha_complex(X) |
| 47 | + pers_info = selector(pers_info) |
| 48 | + |
| 49 | + # Evaluate the loss; notice that we want to *maximise* it in |
| 50 | + # order to improve the holes in the data. |
| 51 | + loss = -loss_fn(pers_info) |
| 52 | + |
| 53 | + opt.zero_grad() |
| 54 | + loss.backward() |
| 55 | + opt.step() |
| 56 | + |
| 57 | + data = X.detach().numpy() |
| 58 | + ax.clear() |
| 59 | + ax.scatter(data[:, 0], data[:, 1]) |
| 60 | + plt.pause(0.2) |
0 commit comments