Skip to content

Commit 937086b

Browse files
committed
add some tests
1 parent 397fd8d commit 937086b

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

tests/test_manual.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,46 @@ def test_or_node():
1919
assert m(weights) == 0.4 + (1 - 0.8)
2020

2121

22-
def test_probabilistic_or_node():
22+
def test_disjoin_conjoin():
2323
c = klay.Circuit()
24-
l1, l2 = c.literal_node(1), c.literal_node(-2)
25-
c.set_root(c.or_node([l1, l2]))
24+
l1, l2, l3 = c.literal_node(1), c.literal_node(-2), c.literal_node(3)
25+
or_node1 = c.disjoin([l1, l2, l2])
26+
or_node2 = c.disjoin([l3, or_node1, l3, or_node1])
27+
c.set_root(or_node2)
28+
29+
m = c.to_torch_module(semiring='real')
30+
weights = torch.tensor([0.4, 0.8, 0.5])
31+
expected_result = torch.tensor(0.4 + (1 - 0.8) + 0.5)
32+
assert torch.allclose(m(weights), expected_result)
33+
34+
35+
def test_probabilistic():
36+
c = klay.Circuit()
37+
l1, l2, l3 = c.literal_node(1), c.literal_node(-2), c.literal_node(3)
38+
or_node1 = c.or_node([l1, l2])
39+
or_node2 = c.or_node([l2, l3])
40+
and_node = c.and_node([or_node1, or_node2])
41+
c.set_root(and_node)
2642

2743
m = c.to_torch_module(semiring='real', probabilistic=True)
2844
m.layers[1].weights.data.zero_()
29-
weights = torch.tensor([0.4, 0.8])
30-
assert m(weights) == 0.5 * 0.4 + 0.5 * (1 - 0.8)
45+
weights = torch.tensor([0.4, 0.8, 0.5])
46+
expected_result = torch.tensor((0.4/2 + 0.2/2) * (0.2/2 + 0.5/2))
47+
assert torch.allclose(m(weights), expected_result)
48+
49+
def test_log_probabilistic():
50+
c = klay.Circuit()
51+
l1, l2, l3 = c.literal_node(1), c.literal_node(-2), c.literal_node(3)
52+
or_node1 = c.or_node([l1, l2])
53+
or_node2 = c.or_node([l2, l3])
54+
and_node = c.and_node([or_node1, or_node2])
55+
c.set_root(and_node)
56+
57+
m = c.to_torch_module(semiring='log', probabilistic=True)
58+
m.layers[1].weights.data.zero_()
59+
weights = torch.tensor([0.4, 0.8, 0.5])
60+
expected_result = torch.tensor((0.4/2 + 0.2/2) * (0.2/2 + 0.5/2))
61+
assert torch.allclose(m(weights.log()).exp(), expected_result)
3162

3263

3364
def test_multi_rooted():

0 commit comments

Comments
 (0)