Skip to content

Commit 9efff97

Browse files
committed
pc sampling
1 parent dfe7e6c commit 9efff97

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

src/klay/backends/torch_backend.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from functools import reduce
23

34
import torch
45
from torch import nn
@@ -38,6 +39,7 @@ class KnowledgeModule(nn.Module):
3839
def __init__(self, pointers, csrs, semiring='real', probabilistic=False):
3940
super(KnowledgeModule, self).__init__()
4041
layers = []
42+
self.probabilistic = probabilistic
4143
sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring, probabilistic)
4244
for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
4345
ptrs = torch.as_tensor(ptrs)
@@ -61,27 +63,41 @@ def sparsity(self, nb_vars: int) -> float:
6163
dense_params = sum(layer_widths[i] * layer_widths[i+1] for i in range(len(layer_widths) - 1))
6264
return sparse_params / dense_params
6365

66+
def sample_pc(self):
67+
assert self.probabilistic
68+
y = torch.tensor([1])
69+
for layer in reversed(self.layers):
70+
y = layer.sample_pc(y)
71+
return y[2::2]
72+
6473

6574
class KnowledgeLayer(nn.Module):
6675
def __init__(self, ptrs, csr):
6776
super().__init__()
6877
self.register_buffer('ptrs', ptrs)
6978
self.register_buffer('csr', csr)
7079
self.out_shape = (self.csr[-1].item() + 1,)
80+
self.in_shape = (self.ptrs.max() + 1,)
81+
82+
def _scatter_forward(self, x: torch.Tensor, reduce: str):
83+
output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
84+
output = torch.scatter_reduce(output, 0, index=self.csr, src=x, reduce=reduce, include_self=False)
85+
return output
7186

72-
def _scatter_reduce(self, src: torch.Tensor, reduce: str):
73-
output = torch.empty(self.out_shape, dtype=src.dtype, device=src.device)
74-
output = torch.scatter_reduce(output, 0, index=self.csr, src=src, reduce=reduce, include_self=False)
87+
def _scatter_backward(self, x: torch.Tensor, reduce: str):
88+
output = torch.empty(self.in_shape, dtype=x.dtype, device=x.device)
89+
output = torch.scatter_reduce(output, 0, index=self.ptrs, src=x, reduce=reduce, include_self=False)
7590
return output
7691

92+
7793
def _safe_exp(self, x: torch.Tensor):
7894
with torch.no_grad():
79-
max_output = self._scatter_reduce(x, "amax")
95+
max_output = self._scatter_forward(x, "amax")
8096
x = x - max_output[self.csr]
8197
x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
8298
return torch.exp(x), max_output
8399

84-
def _logsumexp_scatter_reduce(self, x: torch.Tensor, epsilon: float):
100+
def _logsumexp_scatter(self, x: torch.Tensor, epsilon: float):
85101
x, max_output = self._safe_exp(x)
86102
output = torch.full(self.out_shape, epsilon, dtype=x.dtype, device=x.device)
87103
output = torch.scatter_add(output, 0, index=self.csr, src=x)
@@ -98,49 +114,60 @@ def __init__(self, ptrs, csr):
98114

99115
class SumLayer(KnowledgeLayer):
100116
def forward(self, x):
101-
return self._scatter_reduce(x[self.ptrs], "sum")
117+
return self._scatter_forward(x[self.ptrs], "sum")
118+
119+
def sample_pc(self, y):
120+
return self._scatter_backward(y[self.csr], "sum") > 0
102121

103122

104123
class ProdLayer(KnowledgeLayer):
105124
def forward(self, x):
106-
return self._scatter_reduce(x[self.ptrs], "prod")
125+
return self._scatter_forward(x[self.ptrs], "prod")
107126

108127

109128
class MinLayer(KnowledgeLayer):
110129
def forward(self, x):
111-
return self._scatter_reduce(x[self.ptrs], "amin")
130+
return self._scatter_forward(x[self.ptrs], "amin")
112131

113132

114133
class MaxLayer(KnowledgeLayer):
115134
def forward(self, x):
116-
return self._scatter_reduce(x[self.ptrs], "amax")
135+
return self._scatter_forward(x[self.ptrs], "amax")
117136

118137

119138
class LogSumLayer(KnowledgeLayer):
120139
def forward(self, x, epsilon=10e-16):
121-
return self._logsumexp_scatter_reduce(x[self.ptrs], epsilon)
140+
return self._logsumexp_scatter(x[self.ptrs], epsilon)
122141

123142

124143
class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
125144
def forward(self, x):
126145
x = self.get_edge_weights() * x[self.ptrs]
127-
return self._scatter_reduce(x, "sum")
146+
return self._scatter_forward(x, "sum")
128147

129148
def get_edge_weights(self):
130149
exp_weights, _ = self._safe_exp(self.weights)
131-
norm = self._scatter_reduce(exp_weights, "sum")
150+
norm = self._scatter_forward(exp_weights, "sum")
132151
return exp_weights / norm[self.csr]
133152

134153

135154
class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
136155
def forward(self, x, epsilon=10e-16):
137156
x = self.get_edge_weights(epsilon) + x[self.ptrs]
138-
return self._logsumexp_scatter_reduce(x, epsilon)
157+
return self._logsumexp_scatter(x, epsilon)
139158

140159
def get_edge_weights(self, epsilon):
141-
norm = self._logsumexp_scatter_reduce(self.weights, epsilon)
160+
norm = self._logsumexp_scatter(self.weights, epsilon)
142161
return self.weights - norm[self.csr]
143162

163+
def sample_pc(self, y, epsilon=10e-16):
164+
weights = self.get_edge_weights(epsilon)
165+
gumbels = weights - (-torch.rand_like(weights).log()).log()
166+
samples = self._scatter_forward(gumbels, "amax")
167+
samples = samples[self.csr] == gumbels
168+
samples &= y[self.csr].to(torch.bool)
169+
return self._scatter_backward(samples, "sum") > 0
170+
144171

145172
def get_semiring(name: str, probabilistic: bool):
146173
"""

0 commit comments

Comments
 (0)