Skip to content

Commit 82eae9a

Browse files
Allow customizing style of model_graph nodes (#7302)
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1 parent e42d35a commit 82eae9a

File tree

2 files changed

+246
-39
lines changed

2 files changed

+246
-39
lines changed

pymc/model_graph.py

Lines changed: 194 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
import warnings
1515

1616
from collections import defaultdict
17-
from collections.abc import Iterable, Sequence
17+
from collections.abc import Callable, Iterable, Sequence
18+
from enum import Enum
1819
from os import path
20+
from typing import Any
1921

2022
from pytensor import function
2123
from pytensor.graph import Apply
@@ -41,6 +43,119 @@ def fast_eval(var):
4143
return function([], var, mode="FAST_COMPILE")()
4244

4345

46+
class NodeType(str, Enum):
    """Enum for the types of nodes in the graph.

    Members also subclass ``str``; each member's value is the human-readable
    name of the node type.
    """

    # Fallback type returned by `get_node_type` when no other category matches.
    POTENTIAL = "Potential"
    # Random variable found in `model.free_RVs`.
    FREE_RV = "Free Random Variable"
    # Random variable found in `model.observed_RVs`.
    OBSERVED_RV = "Observed Random Variable"
    # Variable found in `model.deterministics`.
    DETERMINISTIC = "Deterministic"
    # Variable found in `model.data_vars`.
    DATA = "Data"
54+
55+
56+
# Keyword arguments describing how a single node is drawn (e.g. shape, style, label).
GraphvizNodeKwargs = dict[str, Any]
# A formatter receives a model variable and returns the kwargs for its node.
NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs]
58+
59+
60+
def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
    """Return the default node attributes for a potential.

    The node is a filled octagon labelled with the variable name followed by
    the word "Potential".
    """
    node_kwargs = {"shape": "octagon", "style": "filled"}
    node_kwargs["label"] = f"{var.name}\n~\nPotential"
    return node_kwargs
67+
68+
69+
def random_variable_symbol(var: TensorVariable) -> str:
    """Return the display symbol for a random variable.

    The symbol is the class name of the op that produced ``var``
    (``var.owner.op``), with a trailing "RV" removed when present.
    """
    op_class_name = var.owner.op.__class__.__name__
    # Only a trailing "RV" is dropped; names without that suffix pass through unchanged.
    return op_class_name.removesuffix("RV")
77+
78+
79+
def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
    """Return the default node attributes for a free random variable.

    The node is an ellipse labelled with the variable name and the symbol
    returned by ``random_variable_symbol``. ``style`` is ``None``, in contrast
    to the "filled" style used for observed random variables.
    """
    distribution = random_variable_symbol(var)
    label = f"{var.name}\n~\n{distribution}"
    return {"shape": "ellipse", "style": None, "label": label}
88+
89+
90+
def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
    """Return the default node attributes for an observed random variable.

    The node is a filled ellipse labelled with the variable name and the
    symbol returned by ``random_variable_symbol``.
    """
    distribution = random_variable_symbol(var)
    label = f"{var.name}\n~\n{distribution}"
    return {"shape": "ellipse", "style": "filled", "label": label}
99+
100+
101+
def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
    """Return the default node attributes for a deterministic variable.

    The node is a box with ``style`` set to ``None``, labelled with the
    variable name followed by the word "Deterministic".
    """
    label = f"{var.name}\n~\nDeterministic"
    node_kwargs = {"shape": "box", "style": None}
    node_kwargs["label"] = label
    return node_kwargs
108+
109+
110+
def default_data(var: TensorVariable) -> GraphvizNodeKwargs:
    """Return the default node attributes for a data variable.

    The node is a rounded, filled box labelled with the variable name
    followed by the word "Data".
    """
    label = f"{var.name}\n~\nData"
    return {"shape": "box", "style": "rounded, filled", "label": label}
117+
118+
119+
def get_node_type(var_name: VarName, model) -> NodeType:
    """Classify a model variable as one of the ``NodeType`` categories.

    The variable is looked up as ``model[var_name]`` and tested, in order,
    against the model's deterministics, free RVs, observed RVs and data
    variables. The first collection that contains it decides the result.
    """
    variable = model[var_name]

    if variable in model.deterministics:
        return NodeType.DETERMINISTIC
    if variable in model.free_RVs:
        return NodeType.FREE_RV
    if variable in model.observed_RVs:
        return NodeType.OBSERVED_RV
    if variable in model.data_vars:
        return NodeType.DATA

    # No explicit membership test is made for potentials: anything that is
    # not in one of the collections above is reported as a potential.
    return NodeType.POTENTIAL
133+
134+
135+
# Mapping from a node type to the formatter used to style nodes of that type.
NodeTypeFormatterMapping = dict[NodeType, NodeFormatter]

# Formatter used for each node type unless the caller supplies a replacement;
# `update_node_formatters` merges user-provided entries on top of this mapping.
DEFAULT_NODE_FORMATTERS: NodeTypeFormatterMapping = {
    NodeType.POTENTIAL: default_potential,
    NodeType.FREE_RV: default_free_rv,
    NodeType.OBSERVED_RV: default_observed_rv,
    NodeType.DETERMINISTIC: default_deterministic,
    NodeType.DATA: default_data,
}
144+
145+
146+
def update_node_formatters(
    node_formatters: NodeTypeFormatterMapping | None,
) -> NodeTypeFormatterMapping:
    """Merge user-supplied node formatters on top of the defaults.

    Parameters
    ----------
    node_formatters : dict or None
        Mapping from node type to formatter. Keys may be ``NodeType`` members
        or their string values, since ``NodeType`` subclasses ``str``.
        ``None`` or an empty mapping selects the defaults for every node type.

    Returns
    -------
    dict
        A new mapping with one formatter per ``NodeType``; entries from
        ``node_formatters`` replace the matching default entries. Neither the
        input mapping nor ``DEFAULT_NODE_FORMATTERS`` is modified.

    Raises
    ------
    ValueError
        If ``node_formatters`` contains a key that is not a ``NodeType``.
    """
    overrides = node_formatters or {}

    # Every default key is a NodeType, so checking the overrides alone is
    # equivalent to checking the merged mapping.
    unknown_keys = set(overrides) - set(NodeType)
    if unknown_keys:
        # Sort by string form so the message is reproducible across runs
        # (set iteration order depends on hashing).
        raise ValueError(
            f"Node formatters must be of type NodeType. Found: {sorted(unknown_keys, key=str)}."
            f" Please use one of {[node_type.value for node_type in NodeType]}."
        )

    return {**DEFAULT_NODE_FORMATTERS, **overrides}
157+
158+
44159
class ModelGraph:
45160
def __init__(self, model):
46161
self.model = model
@@ -148,42 +263,23 @@ def make_compute_graph(
148263

149264
return input_map
150265

151-
def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: str = "plain"):
266+
def _make_node(
267+
self,
268+
var_name,
269+
graph,
270+
*,
271+
node_formatters: NodeTypeFormatterMapping,
272+
nx=False,
273+
cluster=False,
274+
formatting: str = "plain",
275+
):
152276
"""Attaches the given variable to a graphviz or networkx Digraph"""
153277
v = self.model[var_name]
154278

155-
shape = None
156-
style = None
157-
label = str(v)
158-
159-
if v in self.model.potentials:
160-
shape = "octagon"
161-
style = "filled"
162-
label = f"{var_name}\n~\nPotential"
163-
elif v in self.model.basic_RVs:
164-
shape = "ellipse"
165-
if v in self.model.observed_RVs:
166-
style = "filled"
167-
else:
168-
style = None
169-
symbol = v.owner.op.__class__.__name__
170-
if symbol.endswith("RV"):
171-
symbol = symbol[:-2]
172-
label = f"{var_name}\n~\n{symbol}"
173-
elif v in self.model.deterministics:
174-
shape = "box"
175-
style = None
176-
label = f"{var_name}\n~\nDeterministic"
177-
else:
178-
shape = "box"
179-
style = "rounded, filled"
180-
label = f"{var_name}\n~\nData"
181-
182-
kwargs = {
183-
"shape": shape,
184-
"style": style,
185-
"label": label,
186-
}
279+
node_type = get_node_type(var_name, self.model)
280+
node_formatter = node_formatters[node_type]
281+
282+
kwargs = node_formatter(v)
187283

188284
if cluster:
189285
kwargs["cluster"] = cluster
@@ -240,6 +336,7 @@ def make_graph(
240336
save=None,
241337
figsize=None,
242338
dpi=300,
339+
node_formatters: NodeTypeFormatterMapping | None = None,
243340
):
244341
"""Make graphviz Digraph of PyMC model
245342
@@ -255,18 +352,26 @@ def make_graph(
255352
"The easiest way to install all of this is by running\n\n"
256353
"\tconda install -c conda-forge python-graphviz"
257354
)
355+
356+
node_formatters = node_formatters or {}
357+
node_formatters = update_node_formatters(node_formatters)
358+
258359
graph = graphviz.Digraph(self.model.name)
259360
for plate_label, all_var_names in self.get_plates(var_names).items():
260361
if plate_label:
261362
# must be preceded by 'cluster' to get a box around it
262363
with graph.subgraph(name="cluster" + plate_label) as sub:
263364
for var_name in all_var_names:
264-
self._make_node(var_name, sub, formatting=formatting)
365+
self._make_node(
366+
var_name, sub, formatting=formatting, node_formatters=node_formatters
367+
)
265368
# plate label goes bottom right
266369
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
267370
else:
268371
for var_name in all_var_names:
269-
self._make_node(var_name, graph, formatting=formatting)
372+
self._make_node(
373+
var_name, graph, formatting=formatting, node_formatters=node_formatters
374+
)
270375

271376
for child, parents in self.make_compute_graph(var_names=var_names).items():
272377
# parents is a set of rv names that precede child rv nodes
@@ -287,7 +392,12 @@ def make_graph(
287392

288393
return graph
289394

290-
def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: str = "plain"):
395+
def make_networkx(
396+
self,
397+
var_names: Iterable[VarName] | None = None,
398+
formatting: str = "plain",
399+
node_formatters: NodeTypeFormatterMapping | None = None,
400+
):
291401
"""Make networkx Digraph of PyMC model
292402
293403
Returns
@@ -302,6 +412,10 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
302412
"The easiest way to install all of this is by running\n\n"
303413
"\tconda install networkx"
304414
)
415+
416+
node_formatters = node_formatters or {}
417+
node_formatters = update_node_formatters(node_formatters)
418+
305419
graphnetwork = networkx.DiGraph(name=self.model.name)
306420
for plate_label, all_var_names in self.get_plates(var_names).items():
307421
if plate_label:
@@ -314,6 +428,7 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
314428
var_name,
315429
subgraphnetwork,
316430
nx=True,
431+
node_formatters=node_formatters,
317432
cluster="cluster" + plate_label,
318433
formatting=formatting,
319434
)
@@ -332,7 +447,13 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
332447
graphnetwork.graph["name"] = self.model.name
333448
else:
334449
for var_name in all_var_names:
335-
self._make_node(var_name, graphnetwork, nx=True, formatting=formatting)
450+
self._make_node(
451+
var_name,
452+
graphnetwork,
453+
nx=True,
454+
formatting=formatting,
455+
node_formatters=node_formatters,
456+
)
336457

337458
for child, parents in self.make_compute_graph(var_names=var_names).items():
338459
# parents is a set of rv names that precede child rv nodes
@@ -346,6 +467,7 @@ def model_to_networkx(
346467
*,
347468
var_names: Iterable[VarName] | None = None,
348469
formatting: str = "plain",
470+
node_formatters: NodeTypeFormatterMapping | None = None,
349471
):
350472
"""Produce a networkx Digraph from a PyMC model.
351473
@@ -367,6 +489,10 @@ def model_to_networkx(
367489
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
368490
formatting : str, optional
369491
one of { "plain" }
492+
node_formatters : dict, optional
493+
A dictionary mapping node types to functions that return a dictionary of node attributes.
494+
Check out the networkx documentation for more information
495+
on how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html
370496
371497
Examples
372498
--------
@@ -392,6 +518,17 @@ def model_to_networkx(
392518
obs = Normal("obs", theta, sigma=sigma, observed=y)
393519
394520
model_to_networkx(schools)
521+
522+
Add custom attributes to Free Random Variables and Observed Random Variables nodes.
523+
524+
.. code-block:: python
525+
526+
node_formatters = {
527+
"Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
528+
"Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
529+
}
530+
model_to_networkx(schools, node_formatters=node_formatters)
531+
395532
"""
396533
if "plain" not in formatting:
397534
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
@@ -403,7 +540,9 @@ def model_to_networkx(
403540
stacklevel=2,
404541
)
405542
model = pm.modelcontext(model)
406-
return ModelGraph(model).make_networkx(var_names=var_names, formatting=formatting)
543+
return ModelGraph(model).make_networkx(
544+
var_names=var_names, formatting=formatting, node_formatters=node_formatters
545+
)
407546

408547

409548
def model_to_graphviz(
@@ -414,6 +553,7 @@ def model_to_graphviz(
414553
save: str | None = None,
415554
figsize: tuple[int, int] | None = None,
416555
dpi: int = 300,
556+
node_formatters: NodeTypeFormatterMapping | None = None,
417557
):
418558
"""Produce a graphviz Digraph from a PyMC model.
419559
@@ -441,6 +581,10 @@ def model_to_graphviz(
441581
the size of the saved figure.
442582
dpi : int, optional
443583
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
584+
node_formatters : dict, optional
585+
A dictionary mapping node types to functions that return a dictionary of node attributes.
586+
Check out graphviz documentation for more information on available
587+
attributes. https://graphviz.org/docs/nodes/
444588
445589
Examples
446590
--------
@@ -475,6 +619,16 @@ def model_to_graphviz(
475619
476620
# creates the file `schools.pdf`
477621
model_to_graphviz(schools).render("schools")
622+
623+
Display Free Random Variables and Observed Random Variables nodes with custom formatting.
624+
625+
.. code-block:: python
626+
627+
node_formatters = {
628+
"Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
629+
"Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
630+
}
631+
model_to_graphviz(schools, node_formatters=node_formatters)
478632
"""
479633
if "plain" not in formatting:
480634
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
@@ -491,4 +645,5 @@ def model_to_graphviz(
491645
save=save,
492646
figsize=figsize,
493647
dpi=dpi,
648+
node_formatters=node_formatters,
494649
)

0 commit comments

Comments
 (0)