Commit 093cde0

nicer docs template
1 parent f754f1b commit 093cde0

4 files changed (+41 / -29 lines)


docs/circuit_eval.rst

Lines changed: 33 additions & 23 deletions
@@ -8,9 +8,7 @@ Backends

 Once we have created a circuit, we can start using it. KLay relies on a backend to perform the inference. Currently, the PyTorch and Jax backends are implemented.

-.. tabs::
-
-   .. group-tab:: PyTorch
+.. tab:: PyTorch

    We can turn the circuit into a PyTorch module as follows.

@@ -24,9 +22,9 @@ Once we have created a circuit, we can start using it. KLay relies on a backend

       module = module.to("cuda:0")

-   .. group-tab:: Jax
+.. tab:: Jax

-   We can turn the circuit into a jax function as follows.
+   We can turn the circuit into a Jax function as follows.

    .. code-block:: Python

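
For orientation, the calls this page documents compose as follows. This is a minimal sketch, not part of the commit: it assumes `circuit` is a klay `Circuit` that was already built (construction is covered elsewhere in the docs) and that the `semiring` keyword shown later in this diff also applies without `probabilistic=True`.

    from klay import Circuit  # Circuit is re-exported at the package root (see src/klay/__init__.py below)


    def make_backends(circuit: Circuit):
        # PyTorch backend: a module that evaluates the circuit, optionally moved to GPU
        module = circuit.to_torch_module(semiring="real")
        module = module.to("cuda:0")

        # Jax backend: a plain function over jax arrays
        func = circuit.to_jax_module(semiring="real")
        return module, func
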
@@ -39,38 +37,44 @@ product and sum nodes just compute the normal product and sum operations.

 KLay doesn't introduce a batch dimension by default. So use vmap to perform batched inference.

-.. tabs::
+.. tab:: PyTorch

-   .. code-tab:: Python PyTorch
+   .. code-block:: Python

       module = torch.vmap(module)

-   .. code-tab:: Python Jax
+.. tab:: Jax
+
+   .. code-block:: Python

       func = jax.vmap(func)

 To achieve best runtime performance, it is advisable to use JIT compilation.

-.. tabs::
+.. tab:: PyTorch

-   .. code-tab:: Python PyTorch
+   .. code-block:: Python

       module = torch.compile(module, mode="reduce-overhead")

-   .. code-tab:: Python Jax
+.. tab:: Jax
+
+   .. code-block:: Python

       func = jax.jit(func)


 Klay also supports `probabilistic circuits <https://starai.cs.ucla.edu/papers/ProbCirc20.pdf>`_, which have weights associated with the edges of sum nodes.

-.. tabs::
+.. tab:: PyTorch

-   .. code-tab:: Python PyTorch
+   .. code-block:: Python

       module2 = circuit.to_torch_module(semiring="real", probabilistic=True)

-   .. code-tab:: Python Jax
+.. tab:: Jax
+
+   .. code-block:: Python

       # Warning: not yet implemented!
       func2 = circuit.to_jax_module(semiring="real", probabilistic=True)
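
Composing the snippets in the hunk above, batching and compilation stack straightforwardly. A minimal sketch, assuming `module` and `func` are the backends created earlier and PyTorch 2.x for `torch.vmap`/`torch.compile`:

    import torch
    import jax


    def prepare_batched(module, func):
        # vmap adds the missing batch dimension, torch.compile cuts per-call overhead
        batched_module = torch.compile(torch.vmap(module), mode="reduce-overhead")
        # the Jax equivalent composes jax.vmap with jax.jit
        batched_func = jax.jit(jax.vmap(func))
        return batched_module, batched_func
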
@@ -83,28 +87,32 @@ The input to the circuit should be a tensor with as size the number of input lit
 Note that when using the :code:`log` semiring, the inputs are log-probabilities, while in the :code:`real` or :code:`mpe` semiring the inputs should be probabilities.
 In case you are using a probabilistic circuit, you should likely have some input distributions producing these (log-)probabilities prior to the circuit.

-.. tabs::
+.. tab:: PyTorch

-   .. code-tab:: Python PyTorch
+   .. code-block:: Python

       inputs = torch.tensor([...])
       outputs = module(inputs)

-   .. code-tab:: Python Jax
+.. tab:: Jax
+
+   .. code-block:: Python

       inputs = jnp.array([...])
       outputs = func(inputs)

 Gradients are computed in the usual fashion.

-.. tabs::
+.. tab:: PyTorch

-   .. code-tab:: Python PyTorch
+   .. code-block:: Python

       outputs = func(inputs)
       outputs.backward()

-   .. code-tab:: Python Jax
+.. tab:: Jax
+
+   .. code-block:: Python

       grad_func = jax.jit(jax.grad(func))
       grad_func(inputs)
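
Since the log-versus-real input convention above is easy to miss, here is a small illustrative sketch. The probability values and the `semiring="log"` argument are assumptions for illustration; only `semiring="real"` appears in this commit.

    import torch


    def evaluate_both(circuit, probs: torch.Tensor):
        # real semiring: probabilities go in directly
        out_real = circuit.to_torch_module(semiring="real")(probs)
        # log semiring: the same circuit, fed log-probabilities instead
        out_log = circuit.to_torch_module(semiring="log")(torch.log(probs))
        return out_real, out_log
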
@@ -115,15 +123,17 @@ For example for the :code:`real` semiring: if :code:`x` is the weight of literal
 then :code:`1 - x` is the weight of the negative literal :code:`-l`.
 To use other weights, you must provide a separate tensor containing a weight for each negative literal.

-.. tabs::
+.. tab:: PyTorch

-   .. code-tab:: Python PyTorch
+   .. code-block:: Python

       inputs = torch.tensor([...])
       neg_inputs = torch.tensor([...])  # assumed 1-inputs otherwise
       outputs = module(inputs, neg_inputs)

-   .. code-tab:: Python Jax
+.. tab:: Jax
+
+   .. code-block:: Python

       inputs = jnp.array([...])
       neg_inputs = jnp.array([...])  # assumed 1-inputs otherwise
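
To make the negative-literal weights concrete, a sketch with made-up values for three input literals, assuming `module` is the PyTorch backend from above:

    import torch

    inputs = torch.tensor([0.9, 0.2, 0.7])       # weights of the positive literals
    neg_inputs = torch.tensor([0.05, 0.8, 0.3])  # explicit weights for the negative literals
    outputs = module(inputs, neg_inputs)
    # omitting neg_inputs falls back to 1 - inputs, i.e. [0.1, 0.8, 0.3]
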

docs/conf.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 project = 'KLay'
 copyright = '2025, DTAI Research Group'
 author = 'Jaron Maene, Vincent Derkinderen, Pedro Zuidberg Dos Martires'
-release = '0.1'
+release = '0.0.2'

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -20,7 +20,7 @@
     "sphinx.ext.autodoc",
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
-    "sphinx_tabs.tabs"
+    "sphinx_inline_tabs"
 ]

 templates_path = ['_templates']
@@ -31,5 +31,5 @@
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

-html_theme = 'alabaster'
+html_theme = 'furo'
 html_static_path = ['_static']

docs/requirements.txt

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
 sphinx
-sphinx-tabs
-klaycircuits
+sphinx-inline-tabs
+klaycircuits
+furo

src/klay/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 # noinspection PyUnresolvedReferences
-from .nanobind_ext import Circuit
+from .nanobind_ext import Circuit, NodePtr
+NodePtr.__module__ = "klay"

 from collections.abc import Sequence
