@@ -19,15 +19,46 @@ def test_or_node():
1919    assert  m (weights ) ==  0.4  +  (1  -  0.8 )
2020
2121
22- def  test_probabilistic_or_node ():
22+ def  test_disjoin_conjoin ():
2323    c  =  klay .Circuit ()
24-     l1 , l2  =  c .literal_node (1 ), c .literal_node (- 2 )
25-     c .set_root (c .or_node ([l1 , l2 ]))
24+     l1 , l2 , l3  =  c .literal_node (1 ), c .literal_node (- 2 ), c .literal_node (3 )
25+     or_node1  =  c .disjoin ([l1 , l2 , l2 ])
26+     or_node2  =  c .disjoin ([l3 , or_node1 , l3 , or_node1 ])
27+     c .set_root (or_node2 )
28+ 
29+     m  =  c .to_torch_module (semiring = 'real' )
30+     weights  =  torch .tensor ([0.4 , 0.8 , 0.5 ])
31+     expected_result  =  torch .tensor (0.4  +  (1  -  0.8 ) +  0.5 )
32+     assert  torch .allclose (m (weights ), expected_result )
33+ 
34+ 
35+ def  test_probabilistic ():
36+     c  =  klay .Circuit ()
37+     l1 , l2 , l3  =  c .literal_node (1 ), c .literal_node (- 2 ), c .literal_node (3 )
38+     or_node1  =  c .or_node ([l1 , l2 ])
39+     or_node2  =  c .or_node ([l2 , l3 ])
40+     and_node  =  c .and_node ([or_node1 , or_node2 ])
41+     c .set_root (and_node )
2642
2743    m  =  c .to_torch_module (semiring = 'real' , probabilistic = True )
2844    m .layers [1 ].weights .data .zero_ ()
29-     weights  =  torch .tensor ([0.4 , 0.8 ])
30-     assert  m (weights ) ==  0.5  *  0.4  +  0.5  *  (1  -  0.8 )
45+     weights  =  torch .tensor ([0.4 , 0.8 , 0.5 ])
46+     expected_result  =  torch .tensor ((0.4 / 2  +  0.2 / 2 ) *  (0.2 / 2  +  0.5 / 2 ))
47+     assert  torch .allclose (m (weights ), expected_result )
48+ 
49+ def  test_log_probabilistic ():
50+     c  =  klay .Circuit ()
51+     l1 , l2 , l3  =  c .literal_node (1 ), c .literal_node (- 2 ), c .literal_node (3 )
52+     or_node1  =  c .or_node ([l1 , l2 ])
53+     or_node2  =  c .or_node ([l2 , l3 ])
54+     and_node  =  c .and_node ([or_node1 , or_node2 ])
55+     c .set_root (and_node )
56+ 
57+     m  =  c .to_torch_module (semiring = 'log' , probabilistic = True )
58+     m .layers [1 ].weights .data .zero_ ()
59+     weights  =  torch .tensor ([0.4 , 0.8 , 0.5 ])
60+     expected_result  =  torch .tensor ((0.4 / 2  +  0.2 / 2 ) *  (0.2 / 2  +  0.5 / 2 ))
61+     assert  torch .allclose (m (weights .log ()).exp (), expected_result )
3162
3263
3364def  test_multi_rooted ():
0 commit comments