-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathcubical_complex.py
107 lines (74 loc) · 2.54 KB
/
cubical_complex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Demo for calculating cubical complexes.
This example demonstrates how to perform topological operations on
a structured array, such as a grey-scale image.
"""
import numpy as np
import matplotlib.pyplot as plt
from torch_topological.nn import CubicalComplex
from torch_topological.nn import WassersteinDistance
from tqdm import tqdm
import torch
def sample_circles(n_cells, n_samples=1000, random_state=None):
    """Sample two nested circles and bin them into a 2D histogram.

    Parameters
    ----------
    n_cells : int
        Number of cells per axis of the 2D histogram, i.e. the
        'resolution' of the resulting image.

    n_samples : int
        Number of points used to create the nested-circle coordinates.

    random_state : int, numpy.random.RandomState, or None, optional
        Forwarded to :func:`sklearn.datasets.make_circles`. The default
        ``None`` draws from NumPy's global random state, so
        ``np.random.seed`` continues to control the output.

    Returns
    -------
    np.ndarray of shape ``(n_cells, n_cells)``
        Histogram counts shifted to zero mean and scaled so that the
        largest value is 1. If every cell holds the same count (e.g. for
        ``n_cells=1``), the centred array is all zeros and is returned
        as-is rather than being divided by zero.
    """
    # Imported locally, so sklearn is only needed when this function runs.
    from sklearn.datasets import make_circles

    X = make_circles(
        n_samples, shuffle=True, noise=0.01, random_state=random_state
    )[0]
    heatmap, *_ = np.histogram2d(X[:, 0], X[:, 1], bins=n_cells)

    heatmap -= heatmap.mean()

    # After centring, the maximum is >= 0 and is exactly 0 only for a
    # constant histogram; dividing by it would fill the image with NaN.
    peak = heatmap.max()
    if peak > 0:
        heatmap /= peak

    return heatmap
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device', device)

    # Seed NumPy's global RNG. It drives the Gaussian noise below and, since
    # `sample_circles` is called without an explicit random state, the
    # circle sampling inside it as well.
    np.random.seed(23)

    # Target image: binned nested circles. It stays on the CPU for now so
    # that `.numpy()` can be used when adding the noise.
    Y = torch.as_tensor(sample_circles(50), dtype=torch.float)

    # Source image: the target plus Gaussian noise. This is the tensor whose
    # pixel values get optimised.
    noise = np.random.normal(scale=0.20, size=Y.shape)
    X = torch.as_tensor(Y.numpy() + noise, dtype=torch.float, device=device)
    Y = Y.to(device)

    # `X` already lives on `device`, so wrap it directly. Calling `.to()` on a
    # Parameter that actually has to move returns a plain non-leaf tensor,
    # which the optimiser refuses to optimise.
    X = torch.nn.Parameter(X, requires_grad=True)

    # Snapshot of the starting image, detached so autograd does not track it.
    source = X.detach().clone()

    optimizer = torch.optim.Adam([X], lr=1e-3)
    loss_fn = WassersteinDistance(q=2)
    cubical_complex = CubicalComplex()

    # The target's persistence information is computed once, up front.
    # Only the first entry of the returned collection is used; presumably
    # this is the dimension-0 diagram -- TODO confirm against the
    # torch_topological `CubicalComplex` documentation.
    persistence_information_target = cubical_complex(Y)[0]

    n_iter = 500
    progress = tqdm(range(n_iter))

    for _ in progress:
        persistence_information = cubical_complex(X)[0]

        optimizer.zero_grad()
        loss = loss_fn(
            persistence_information,
            persistence_information_target,
        )
        loss.backward()
        optimizer.step()

        progress.set_postfix(loss=loss.item())

    # `.numpy()` only works on CPU tensors, so move everything off the
    # device first (otherwise this fails whenever CUDA is in use).
    source = source.cpu().numpy()
    target = Y.cpu().detach().numpy()
    result = X.detach().cpu().numpy()

    fig, ax = plt.subplots(ncols=3)

    ax[0].imshow(source)
    ax[0].set_title('Source')

    ax[1].imshow(target)
    ax[1].set_title('Target')

    ax[2].imshow(result)
    ax[2].set_title('Result')

    plt.show()