Backends
--------
Once we have created a circuit, we can start using it. KLay relies on a backend to perform inference; currently, the PyTorch and Jax backends are implemented.
.. tab:: PyTorch

   We can turn the circuit into a PyTorch module as follows.
   .. code-block:: Python

      module = module.to("cuda:0")
.. tab:: Jax

   We can turn the circuit into a Jax function as follows.

   .. code-block:: Python
When using the :code:`real` semiring, product and sum nodes simply compute the normal product and sum operations.
KLay does not introduce a batch dimension by default, so use :code:`vmap` to perform batched inference.
.. tab:: PyTorch

   .. code-block:: Python

      module = torch.vmap(module)

.. tab:: Jax

   .. code-block:: Python

      func = jax.vmap(func)
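To illustrate how :code:`vmap` adds the batch dimension, here is a minimal sketch using a toy stand-in for the circuit (the :code:`toy_circuit` function below is hypothetical and not part of KLay; a real KLay module is vmapped the same way):

```python
import torch

# Hypothetical stand-in for a KLay module: any function mapping a 1-D
# tensor of input weights to a scalar output.
def toy_circuit(x):
    return x.prod() + x.sum()

batched = torch.vmap(toy_circuit)   # adds a leading batch dimension
batch = torch.rand(32, 4)           # 32 samples, 4 input literals
out = batched(batch)                # shape: (32,)
```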
To achieve the best runtime performance, it is advisable to use JIT compilation.
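A minimal sketch of JIT compilation with the Jax backend, assuming a hypothetical :code:`circuit` stand-in rather than a real KLay-generated function; the compiled function is traced on the first call and reused afterwards:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a KLay circuit function.
def circuit(x):
    return jnp.sum(x)

fast = jax.jit(jax.vmap(circuit))   # compile the batched function once
out = fast(jnp.zeros((8, 4)))       # later calls reuse the compiled code
```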
KLay also supports `probabilistic circuits <https://starai.cs.ucla.edu/papers/ProbCirc20.pdf>`_, which have weights associated with the edges of sum nodes.
The input to the circuit should be a tensor whose size is the number of input literals.
Note that when using the :code:`log` semiring, the inputs are log-probabilities, while in the :code:`real` or :code:`mpe` semiring the inputs should be probabilities.
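For instance, converting probabilities into the log-probabilities expected by the :code:`log` semiring is a one-liner (a generic PyTorch sketch, not a KLay-specific API):

```python
import torch

probs = torch.tensor([0.9, 0.25, 0.5])  # inputs for the real / mpe semirings
log_probs = probs.log()                 # inputs for the log semiring
```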
If you are using a probabilistic circuit, you will likely have input distributions producing these (log-)probabilities before the circuit.
.. tab:: PyTorch

   .. code-block:: Python

      inputs = torch.tensor([...])
      outputs = module(inputs)

.. tab:: Jax

   .. code-block:: Python

      inputs = jnp.array([...])
      outputs = func(inputs)
Gradients are computed in the usual fashion.
.. tab:: PyTorch

   .. code-block:: Python

      outputs = module(inputs)
      outputs.backward()

.. tab:: Jax

   .. code-block:: Python

      grad_func = jax.jit(jax.grad(func))
      grad_func(inputs)
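To show the Jax path end to end, here is a minimal sketch with a hypothetical :code:`circuit` stand-in; :code:`jax.grad` differentiates it exactly as it would a KLay-generated function:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a circuit: f(x) = x0 * x1
def circuit(x):
    return jnp.prod(x)

grad_func = jax.jit(jax.grad(circuit))
g = grad_func(jnp.array([2.0, 3.0]))  # gradient of x0*x1 is (x1, x0)
```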
For example, in the :code:`real` semiring: if :code:`x` is the weight of literal :code:`l`,
then :code:`1 - x` is the weight of the negative literal :code:`-l`.
To use other weights, you must provide a separate tensor containing a weight for each negative literal.
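The default complement weighting can be sketched as follows (generic PyTorch, not a KLay API; the tensor names are illustrative):

```python
import torch

pos_weights = torch.tensor([0.9, 0.2, 0.7])  # weights of positive literals
neg_weights = 1.0 - pos_weights              # default weights of negative literals
```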