Skip to content

Commit 2a417a0

Browse files
committed
polished docs a little
1 parent 5e87d6b commit 2a417a0

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

src/circuit.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ nb::class_<NodePtr>(m, "NodePtr")
479479

480480
nb::class_<Circuit>(m, "Circuit", "Circuits are the main class added by KLay, and require no arguments to construct.\n\n:code:`circuit = klay.Circuit()` ")
481481
.def(nb::init<>())
482-
.def("add_sdd_from_file", &Circuit::add_sdd_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add an SDD circuit from file.\n\n:param filename:\n\tPath to the :code:`.sdd` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
482+
.def("add_sdd_from_file", &Circuit::add_sdd_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add a sentential decision diagram (SDD) from file.\n\n:param filename:\n\tPath to the :code:`.sdd` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
483483
.def("add_d4_from_file", &Circuit::add_d4_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add an NNF circuit in the D4 format from file.\n\n:param filename:\n\tPath to the :code:`.nnf` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
484484
.def("_get_indices", &Circuit::get_indices)
485485
.def("nb_nodes", &Circuit::nb_nodes, "Number of nodes in the circuit.")
@@ -489,8 +489,8 @@ nb::class_<Circuit>(m, "Circuit", "Circuits are the main class added by KLay, an
489489
.def("literal_node", &Circuit::literal_node, "Adds a literal node to the circuit, and returns a pointer to this node.", "literal"_a)
490490
.def("or_node", &Circuit::or_node, "children"_a, "Adds an :code:`or` node to the circuit, and returns a pointer to this node.")
491491
.def("and_node", &Circuit::and_node, "children"_a, "Adds an :code:`and` node to the circuit, and returns a pointer to this node.")
492-
.def("set_root", &Circuit::set_root, "root"_a, "Marks a node pointer as root. The order in which nodes are set as root determines the order of the output tensor. Only use this when manually constructing a circuit, when loading in a NNF/SDD its root is automatically set as root.")
493-
.def("remove_unused_nodes", &Circuit::remove_unused_nodes, "Removes unused non-root nodes from the circuit.\nWarning: this invalidates any :code:`NodePtr` referring to an unused node (i.e., a node not connected to a root node).");
492+
.def("set_root", &Circuit::set_root, "root"_a, "Marks a node pointer as root. The order in which nodes are set as root determines the order of the output tensor.\n .. note:: Only use this when manually constructing a circuit, when loading in a NNF/SDD its root is automatically set as root.\n")
493+
.def("remove_unused_nodes", &Circuit::remove_unused_nodes, "Removes unused nodes from the circuit. Root nodes are always considered used.\n .. warning:: Invalidates any :code:`NodePtr` referring to an unused node (i.e., a node not connected to a root node).\n");
494494

495495
m.def("to_dot_file", &to_dot_file, "circuit"_a, "filename"_a, "Write the given circuit as dot format to a file");
496496
}

src/klay/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool =
1010
Convert the circuit into a PyTorch module.
1111
1212
:param semiring:
13-
The semiring in which the circuit should be evaluated. Supported options are ("log", "real", "mpe", "godel").
13+
The semiring in which the circuit should be evaluated. Supported options are :code:`"log"`, :code:`"real"`, :code:`"mpe"`, or :code:`"godel"`.
1414
:param probabilistic:
15-
If true, construct a probabilistic circuit instead of an arithmetic circuit.
15+
If enabled, construct a probabilistic circuit instead of an arithmetic circuit.
1616
This means the inputs to a sum node are multiplied by a probability, and
1717
we can interpret sum nodes as latent Categorical variables.
1818
"""
@@ -26,28 +26,33 @@ def to_jax_function(self: Circuit, semiring: str = "log"):
2626
Convert the circuit into a Jax function.
2727
2828
:param semiring:
29-
The semiring in which the circuit should be evaluated. Supported options are ("log", "real", "mpe", "godel").
29+
The semiring in which the circuit should be evaluated. Supported options are :code:`"log"`, :code:`"real"`, :code:`"mpe"`, or :code:`"godel"`.
3030
"""
3131
from .jax import create_knowledge_layer
3232
indices = self._get_indices()
3333
return create_knowledge_layer(*indices, semiring=semiring)
3434

3535

36-
def add_sdd(self: Circuit, sdd: "SddNode", true_lits: Sequence[int] = (), false_lits: Sequence[int] = ()):
36+
def add_sdd(self: Circuit, sdd: "SddNode", true_lits: Sequence[int] = (), false_lits: Sequence[int] = ()) -> NodePtr:
3737
"""
3838
Add an SDD to the Circuit.
3939
40+
:param sdd:
41+
PySDD `SDDNode`_ to be added.
4042
:param true_lits:
4143
List of literals that are always true and should get propagated away.
4244
:param false_lits:
4345
List of literals that are always false and should get propagated away.
46+
47+
.. _SDDNode: https://pysdd.readthedocs.io/en/latest/classes/SddNode.html
4448
"""
4549
import os
4650
from pathlib import Path
4751

4852
sdd.save(bytes(Path("tmp.sdd")))
49-
self.add_sdd_from_file("tmp.sdd", true_lits, false_lits)
53+
root = self.add_sdd_from_file("tmp.sdd", true_lits, false_lits)
5054
os.remove("tmp.sdd")
55+
return root
5156

5257

5358
Circuit.to_torch_module = to_torch_module

0 commit comments

Comments
 (0)