File tree Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Original file line number Diff line number Diff line change @@ -107,4 +107,24 @@ Gradients are computed in the usual fashion.
107107    .. code-tab :: Python Jax 
108108
109109        grad_func = jax.jit(jax.grad(func))
110-         grad_func(inputs)
110+         grad_func(inputs)
111+ 
112+ The :code: `inputs ` tensor must contain a weight for each positive literal.
113+ The weights of the negative literals follow from those.
114+ For example for the :code: `reals ` semiring: if :code: `x ` is the weight of literal :code: `l `,
115+ then :code: `1 - x ` is the weight of the negative literal :code: `-l `.
116+ To use other weights, you must provide a separate tensor containing a weight for each negative literal.
117+ 
118+ .. tabs ::
119+ 
120+     .. code-tab :: Python PyTorch 
121+ 
122+         inputs = torch.tensor([...])
123+         neg_inputs = torch.tensor([...])  # assumed 1-inputs otherwise
124+         outputs = module(inputs, neg_inputs)
125+ 
126+     .. code-tab :: Python Jax 
127+ 
128+         inputs = jnp.array([...])
129+         neg_inputs = jnp.array([...])  # assumed 1-inputs otherwise
130+         outputs = func(inputs, neg_inputs)
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments