Skip to content

Commit 0fa4037

Browse files
committed
docs: add description on negative lit weights.
1 parent 5460ee5 commit 0fa4037

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

docs/circuit_eval.rst

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,24 @@ Gradients are computed in the usual fashion.
107107
.. code-tab:: Python Jax
108108

109109
grad_func = jax.jit(jax.grad(func))
110-
grad_func(inputs)
110+
grad_func(inputs)
111+
112+
The :code:`inputs` tensor must contain a weight for each positive literal.
113+
The weights of the negative literals follow from those.
114+
For example for the :code:`reals` semiring: if :code:`x` is the weight of literal :code:`l`,
115+
then :code:`1 - x` is the weight of the negative literal :code:`-l`.
116+
To use other weights, you must provide a separate tensor containing a weight for each negative literal.
117+
118+
.. tabs::
119+
120+
.. code-tab:: Python PyTorch
121+
122+
inputs = torch.tensor([...])
123+
neg_inputs = torch.tensor([...]) # assumed 1-inputs otherwise
124+
outputs = module(inputs, neg_inputs)
125+
126+
.. code-tab:: Python Jax
127+
128+
inputs = jnp.array([...])
129+
neg_inputs = jnp.array([...]) # assumed 1-inputs otherwise
130+
outputs = func(inputs, neg_inputs)

0 commit comments

Comments
 (0)