 import math
+from functools import reduce
 
 import torch
 from torch import nn
@@ -38,6 +39,7 @@ class KnowledgeModule(nn.Module):
     def __init__(self, pointers, csrs, semiring='real', probabilistic=False):
         super(KnowledgeModule, self).__init__()
         layers = []
+        self.probabilistic = probabilistic
         sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring, probabilistic)
         for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
             ptrs = torch.as_tensor(ptrs)
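
For orientation, a minimal sketch of the edge encoding these layers consume (hypothetical data, not taken from the repository): ptrs gathers one input value per edge and csr assigns each edge to its output node, so x[ptrs] followed by a scatter over csr evaluates a whole layer at once. The use of csr[-1] as the largest output index later in the diff suggests csr is sorted, CSR-style.

import torch

# Hypothetical two-output layer: node 0 aggregates inputs 0 and 1, node 1 aggregates inputs 1 and 2.
ptrs = torch.tensor([0, 1, 1, 2])  # source input of each edge
csr = torch.tensor([0, 0, 1, 1])   # destination output node of each edge (sorted)

x = torch.tensor([2.0, 3.0, 5.0])  # input node values
edges = x[ptrs]                    # one value per edge: [2., 3., 3., 5.]
out = torch.zeros(int(csr[-1]) + 1).scatter_reduce(0, csr, edges, reduce="sum", include_self=False)
print(out)                         # tensor([5., 8.])
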
@@ -61,27 +63,41 @@ def sparsity(self, nb_vars: int) -> float:
         dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
         return sparse_params / dense_params
 
+    def sample_pc(self):
+        # Draw one assignment from the probabilistic circuit by propagating the
+        # selected nodes from the output layer back down through the layers.
+        assert self.probabilistic
+        y = torch.tensor([1])
+        for layer in reversed(self.layers):
+            y = layer.sample_pc(y)
+        return y[2::2]
+
 
 class KnowledgeLayer(nn.Module):
     def __init__(self, ptrs, csr):
         super().__init__()
         self.register_buffer('ptrs', ptrs)
         self.register_buffer('csr', csr)
         self.out_shape = (self.csr[-1].item() + 1,)
+        self.in_shape = (self.ptrs.max().item() + 1,)
+
+    def _scatter_forward(self, x: torch.Tensor, reduce: str):
+        # Reduce the per-edge values in x into their output nodes (indexed by csr).
+        output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
+        output = torch.scatter_reduce(output, 0, index=self.csr, src=x, reduce=reduce, include_self=False)
+        return output
 
-    def _scatter_reduce(self, src: torch.Tensor, reduce: str):
-        output = torch.empty(self.out_shape, dtype=src.dtype, device=src.device)
-        output = torch.scatter_reduce(output, 0, index=self.csr, src=src, reduce=reduce, include_self=False)
+    def _scatter_backward(self, x: torch.Tensor, reduce: str):
+        # Reduce the per-edge values in x back into their input nodes (indexed by ptrs).
+        output = torch.empty(self.in_shape, dtype=x.dtype, device=x.device)
+        output = torch.scatter_reduce(output, 0, index=self.ptrs, src=x, reduce=reduce, include_self=False)
         return output
 
+
     def _safe_exp(self, x: torch.Tensor):
         with torch.no_grad():
-            max_output = self._scatter_reduce(x, "amax")
+            max_output = self._scatter_forward(x, "amax")
         x = x - max_output[self.csr]
         x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
         return torch.exp(x), max_output
 
-    def _logsumexp_scatter_reduce(self, x: torch.Tensor, epsilon: float):
+    def _logsumexp_scatter(self, x: torch.Tensor, epsilon: float):
         x, max_output = self._safe_exp(x)
         output = torch.full(self.out_shape, epsilon, dtype=x.dtype, device=x.device)
         output = torch.scatter_add(output, 0, index=self.csr, src=x)
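
The _safe_exp / _logsumexp_scatter pair above is a segment-wise, numerically stable log-sum-exp: subtract each output node's maximum, exponentiate, scatter-add into the node, then (in the lines this hunk cuts off) presumably take the log and add the maximum back. A self-contained sketch of that pattern, independent of the classes in this file:

import torch

def logsumexp_by_group(x, group, n_groups, epsilon=1e-16):
    # Per-group maximum, subtracted for numerical stability.
    group_max = torch.full((n_groups,), float('-inf')).scatter_reduce(
        0, group, x, reduce="amax", include_self=False)
    shifted = torch.exp(x - group_max[group])
    summed = torch.full((n_groups,), epsilon).scatter_add(0, group, shifted)
    return summed.log() + group_max

x = torch.tensor([0.1, 0.9, 100.0, 100.5])
group = torch.tensor([0, 0, 1, 1])
print(logsumexp_by_group(x, group, 2))
print(torch.logsumexp(x[:2], 0), torch.logsumexp(x[2:], 0))  # same values, computed per group
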
@@ -98,49 +114,60 @@ def __init__(self, ptrs, csr):
 
 class SumLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_reduce(x[self.ptrs], "sum")
+        return self._scatter_forward(x[self.ptrs], "sum")
+
+    def sample_pc(self, y):
+        # An input is kept if at least one selected output node has an edge to it.
+        return self._scatter_backward(y[self.csr], "sum") > 0
 
 
 class ProdLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_reduce(x[self.ptrs], "prod")
+        return self._scatter_forward(x[self.ptrs], "prod")
 
 
 class MinLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_reduce(x[self.ptrs], "amin")
+        return self._scatter_forward(x[self.ptrs], "amin")
 
 
 class MaxLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_reduce(x[self.ptrs], "amax")
+        return self._scatter_forward(x[self.ptrs], "amax")
 
 
 class LogSumLayer(KnowledgeLayer):
     def forward(self, x, epsilon=10e-16):
-        return self._logsumexp_scatter_reduce(x[self.ptrs], epsilon)
+        return self._logsumexp_scatter(x[self.ptrs], epsilon)
 
 
 class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x):
         x = self.get_edge_weights() * x[self.ptrs]
-        return self._scatter_reduce(x, "sum")
+        return self._scatter_forward(x, "sum")
 
     def get_edge_weights(self):
         exp_weights, _ = self._safe_exp(self.weights)
-        norm = self._scatter_reduce(exp_weights, "sum")
+        norm = self._scatter_forward(exp_weights, "sum")
         return exp_weights / norm[self.csr]
 
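
get_edge_weights above is a per-node softmax: each output node's incoming edge parameters are exponentiated and normalized to sum to one within its csr group. A quick sanity check, reusing the hypothetical ptrs/csr encoding sketched earlier (any constant shift, global or per group, cancels in the ratio):

import torch

csr = torch.tensor([0, 0, 1, 1])              # as in the earlier sketch
weights = torch.randn(4)                      # one unconstrained parameter per edge
exp_w = torch.exp(weights - weights.max())    # a global shift is enough for this check
norm = torch.zeros(2).scatter_add(0, csr, exp_w)
edge_w = exp_w / norm[csr]
print(torch.zeros(2).scatter_add(0, csr, edge_w))  # tensor([1., 1.]): each node's edges sum to 1
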
 
 class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x, epsilon=10e-16):
         x = self.get_edge_weights(epsilon) + x[self.ptrs]
-        return self._logsumexp_scatter_reduce(x, epsilon)
+        return self._logsumexp_scatter(x, epsilon)
 
     def get_edge_weights(self, epsilon):
-        norm = self._logsumexp_scatter_reduce(self.weights, epsilon)
+        norm = self._logsumexp_scatter(self.weights, epsilon)
         return self.weights - norm[self.csr]
 
+    def sample_pc(self, y, epsilon=10e-16):
+        # Gumbel-max trick: perturb each edge's log-weight with Gumbel noise and keep
+        # the arg-max within every csr group, restricted to the selected output nodes y.
+        weights = self.get_edge_weights(epsilon)
+        gumbels = weights - (-torch.rand_like(weights).log()).log()
+        samples = self._scatter_forward(gumbels, "amax")
+        samples = samples[self.csr] == gumbels
+        samples &= y[self.csr].to(torch.bool)
+        return self._scatter_backward(samples, "sum") > 0
+
 
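
The sample_pc above relies on the Gumbel-max trick: weights - (-rand.log()).log() adds standard Gumbel noise to each edge's log-weight, and the arg-max within a csr group then selects one incoming edge per node with probability equal to that edge's weight; masking by y keeps only edges under already-selected output nodes. A standalone sketch (hypothetical sizes, not the repository's API):

import torch

log_w = torch.log(torch.tensor([0.2, 0.8, 0.5, 0.5]))  # normalized log-weights per group
group = torch.tensor([0, 0, 1, 1])

def gumbel_max_per_group(log_w, group, n_groups):
    gumbels = log_w - (-torch.rand_like(log_w).log()).log()  # log_w + Gumbel(0, 1) noise
    group_max = torch.full((n_groups,), float('-inf')).scatter_reduce(
        0, group, gumbels, reduce="amax", include_self=False)
    return gumbels == group_max[group]  # boolean mask with one winning edge per group

# Winning frequencies approach the edge weights (about 0.2/0.8 and 0.5/0.5).
freq = sum(gumbel_max_per_group(log_w, group, 2).float() for _ in range(10000)) / 10000
print(freq)
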
 def get_semiring(name: str, probabilistic: bool):
     """