14
14
import warnings
15
15
16
16
from collections import defaultdict
17
- from collections .abc import Iterable , Sequence
17
+ from collections .abc import Callable , Iterable , Sequence
18
+ from enum import Enum
18
19
from os import path
20
+ from typing import Any
19
21
20
22
from pytensor import function
21
23
from pytensor .graph import Apply
@@ -41,6 +43,119 @@ def fast_eval(var):
41
43
return function ([], var , mode = "FAST_COMPILE" )()
42
44
43
45
46
+ class NodeType (str , Enum ):
47
+ """Enum for the types of nodes in the graph."""
48
+
49
+ POTENTIAL = "Potential"
50
+ FREE_RV = "Free Random Variable"
51
+ OBSERVED_RV = "Observed Random Variable"
52
+ DETERMINISTIC = "Deterministic"
53
+ DATA = "Data"
54
+
55
+
56
+ GraphvizNodeKwargs = dict [str , Any ]
57
+ NodeFormatter = Callable [[TensorVariable ], GraphvizNodeKwargs ]
58
+
59
+
60
+ def default_potential (var : TensorVariable ) -> GraphvizNodeKwargs :
61
+ """Default data for potential in the graph."""
62
+ return {
63
+ "shape" : "octagon" ,
64
+ "style" : "filled" ,
65
+ "label" : f"{ var .name } \n ~\n Potential" ,
66
+ }
67
+
68
+
69
+ def random_variable_symbol (var : TensorVariable ) -> str :
70
+ """Get the symbol of the random variable."""
71
+ symbol = var .owner .op .__class__ .__name__
72
+
73
+ if symbol .endswith ("RV" ):
74
+ symbol = symbol [:- 2 ]
75
+
76
+ return symbol
77
+
78
+
79
+ def default_free_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
80
+ """Default data for free RV in the graph."""
81
+ symbol = random_variable_symbol (var )
82
+
83
+ return {
84
+ "shape" : "ellipse" ,
85
+ "style" : None ,
86
+ "label" : f"{ var .name } \n ~\n { symbol } " ,
87
+ }
88
+
89
+
90
+ def default_observed_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
91
+ """Default data for observed RV in the graph."""
92
+ symbol = random_variable_symbol (var )
93
+
94
+ return {
95
+ "shape" : "ellipse" ,
96
+ "style" : "filled" ,
97
+ "label" : f"{ var .name } \n ~\n { symbol } " ,
98
+ }
99
+
100
+
101
+ def default_deterministic (var : TensorVariable ) -> GraphvizNodeKwargs :
102
+ """Default data for the deterministic in the graph."""
103
+ return {
104
+ "shape" : "box" ,
105
+ "style" : None ,
106
+ "label" : f"{ var .name } \n ~\n Deterministic" ,
107
+ }
108
+
109
+
110
+ def default_data (var : TensorVariable ) -> GraphvizNodeKwargs :
111
+ """Default data for the data in the graph."""
112
+ return {
113
+ "shape" : "box" ,
114
+ "style" : "rounded, filled" ,
115
+ "label" : f"{ var .name } \n ~\n Data" ,
116
+ }
117
+
118
+
119
+ def get_node_type (var_name : VarName , model ) -> NodeType :
120
+ """Return the node type of the variable in the model."""
121
+ v = model [var_name ]
122
+
123
+ if v in model .deterministics :
124
+ return NodeType .DETERMINISTIC
125
+ elif v in model .free_RVs :
126
+ return NodeType .FREE_RV
127
+ elif v in model .observed_RVs :
128
+ return NodeType .OBSERVED_RV
129
+ elif v in model .data_vars :
130
+ return NodeType .DATA
131
+ else :
132
+ return NodeType .POTENTIAL
133
+
134
+
135
+ NodeTypeFormatterMapping = dict [NodeType , NodeFormatter ]
136
+
137
+ DEFAULT_NODE_FORMATTERS : NodeTypeFormatterMapping = {
138
+ NodeType .POTENTIAL : default_potential ,
139
+ NodeType .FREE_RV : default_free_rv ,
140
+ NodeType .OBSERVED_RV : default_observed_rv ,
141
+ NodeType .DETERMINISTIC : default_deterministic ,
142
+ NodeType .DATA : default_data ,
143
+ }
144
+
145
+
146
+ def update_node_formatters (node_formatters : NodeTypeFormatterMapping ) -> NodeTypeFormatterMapping :
147
+ node_formatters = {** DEFAULT_NODE_FORMATTERS , ** node_formatters }
148
+
149
+ unknown_keys = set (node_formatters .keys ()) - set (NodeType )
150
+ if unknown_keys :
151
+ raise ValueError (
152
+ f"Node formatters must be of type NodeType. Found: { list (unknown_keys )} ."
153
+ f" Please use one of { [node_type .value for node_type in NodeType ]} ."
154
+ )
155
+
156
+ return node_formatters
157
+
158
+
44
159
class ModelGraph :
45
160
def __init__ (self , model ):
46
161
self .model = model
@@ -148,42 +263,23 @@ def make_compute_graph(
148
263
149
264
return input_map
150
265
151
- def _make_node (self , var_name , graph , * , nx = False , cluster = False , formatting : str = "plain" ):
266
+ def _make_node (
267
+ self ,
268
+ var_name ,
269
+ graph ,
270
+ * ,
271
+ node_formatters : NodeTypeFormatterMapping ,
272
+ nx = False ,
273
+ cluster = False ,
274
+ formatting : str = "plain" ,
275
+ ):
152
276
"""Attaches the given variable to a graphviz or networkx Digraph"""
153
277
v = self .model [var_name ]
154
278
155
- shape = None
156
- style = None
157
- label = str (v )
158
-
159
- if v in self .model .potentials :
160
- shape = "octagon"
161
- style = "filled"
162
- label = f"{ var_name } \n ~\n Potential"
163
- elif v in self .model .basic_RVs :
164
- shape = "ellipse"
165
- if v in self .model .observed_RVs :
166
- style = "filled"
167
- else :
168
- style = None
169
- symbol = v .owner .op .__class__ .__name__
170
- if symbol .endswith ("RV" ):
171
- symbol = symbol [:- 2 ]
172
- label = f"{ var_name } \n ~\n { symbol } "
173
- elif v in self .model .deterministics :
174
- shape = "box"
175
- style = None
176
- label = f"{ var_name } \n ~\n Deterministic"
177
- else :
178
- shape = "box"
179
- style = "rounded, filled"
180
- label = f"{ var_name } \n ~\n Data"
181
-
182
- kwargs = {
183
- "shape" : shape ,
184
- "style" : style ,
185
- "label" : label ,
186
- }
279
+ node_type = get_node_type (var_name , self .model )
280
+ node_formatter = node_formatters [node_type ]
281
+
282
+ kwargs = node_formatter (v )
187
283
188
284
if cluster :
189
285
kwargs ["cluster" ] = cluster
@@ -240,6 +336,7 @@ def make_graph(
240
336
save = None ,
241
337
figsize = None ,
242
338
dpi = 300 ,
339
+ node_formatters : NodeTypeFormatterMapping | None = None ,
243
340
):
244
341
"""Make graphviz Digraph of PyMC model
245
342
@@ -255,18 +352,26 @@ def make_graph(
255
352
"The easiest way to install all of this is by running\n \n "
256
353
"\t conda install -c conda-forge python-graphviz"
257
354
)
355
+
356
+ node_formatters = node_formatters or {}
357
+ node_formatters = update_node_formatters (node_formatters )
358
+
258
359
graph = graphviz .Digraph (self .model .name )
259
360
for plate_label , all_var_names in self .get_plates (var_names ).items ():
260
361
if plate_label :
261
362
# must be preceded by 'cluster' to get a box around it
262
363
with graph .subgraph (name = "cluster" + plate_label ) as sub :
263
364
for var_name in all_var_names :
264
- self ._make_node (var_name , sub , formatting = formatting )
365
+ self ._make_node (
366
+ var_name , sub , formatting = formatting , node_formatters = node_formatters
367
+ )
265
368
# plate label goes bottom right
266
369
sub .attr (label = plate_label , labeljust = "r" , labelloc = "b" , style = "rounded" )
267
370
else :
268
371
for var_name in all_var_names :
269
- self ._make_node (var_name , graph , formatting = formatting )
372
+ self ._make_node (
373
+ var_name , graph , formatting = formatting , node_formatters = node_formatters
374
+ )
270
375
271
376
for child , parents in self .make_compute_graph (var_names = var_names ).items ():
272
377
# parents is a set of rv names that precede child rv nodes
@@ -287,7 +392,12 @@ def make_graph(
287
392
288
393
return graph
289
394
290
- def make_networkx (self , var_names : Iterable [VarName ] | None = None , formatting : str = "plain" ):
395
+ def make_networkx (
396
+ self ,
397
+ var_names : Iterable [VarName ] | None = None ,
398
+ formatting : str = "plain" ,
399
+ node_formatters : NodeTypeFormatterMapping | None = None ,
400
+ ):
291
401
"""Make networkx Digraph of PyMC model
292
402
293
403
Returns
@@ -302,6 +412,10 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
302
412
"The easiest way to install all of this is by running\n \n "
303
413
"\t conda install networkx"
304
414
)
415
+
416
+ node_formatters = node_formatters or {}
417
+ node_formatters = update_node_formatters (node_formatters )
418
+
305
419
graphnetwork = networkx .DiGraph (name = self .model .name )
306
420
for plate_label , all_var_names in self .get_plates (var_names ).items ():
307
421
if plate_label :
@@ -314,6 +428,7 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
314
428
var_name ,
315
429
subgraphnetwork ,
316
430
nx = True ,
431
+ node_formatters = node_formatters ,
317
432
cluster = "cluster" + plate_label ,
318
433
formatting = formatting ,
319
434
)
@@ -332,7 +447,13 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
332
447
graphnetwork .graph ["name" ] = self .model .name
333
448
else :
334
449
for var_name in all_var_names :
335
- self ._make_node (var_name , graphnetwork , nx = True , formatting = formatting )
450
+ self ._make_node (
451
+ var_name ,
452
+ graphnetwork ,
453
+ nx = True ,
454
+ formatting = formatting ,
455
+ node_formatters = node_formatters ,
456
+ )
336
457
337
458
for child , parents in self .make_compute_graph (var_names = var_names ).items ():
338
459
# parents is a set of rv names that precede child rv nodes
@@ -346,6 +467,7 @@ def model_to_networkx(
346
467
* ,
347
468
var_names : Iterable [VarName ] | None = None ,
348
469
formatting : str = "plain" ,
470
+ node_formatters : NodeTypeFormatterMapping | None = None ,
349
471
):
350
472
"""Produce a networkx Digraph from a PyMC model.
351
473
@@ -367,6 +489,10 @@ def model_to_networkx(
367
489
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
368
490
formatting : str, optional
369
491
one of { "plain" }
492
+ node_formatters : dict, optional
493
+ A dictionary mapping node types to functions that return a dictionary of node attributes.
494
+ Check out the networkx documentation for more information
495
+ how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html
370
496
371
497
Examples
372
498
--------
@@ -392,6 +518,17 @@ def model_to_networkx(
392
518
obs = Normal("obs", theta, sigma=sigma, observed=y)
393
519
394
520
model_to_networkx(schools)
521
+
522
+ Add custom attributes to Free Random Variables and Observed Random Variables nodes.
523
+
524
+ .. code-block:: python
525
+
526
+ node_formatters = {
527
+ "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
528
+ "Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
529
+ }
530
+ model_to_networkx(schools, node_formatters=node_formatters)
531
+
395
532
"""
396
533
if "plain" not in formatting :
397
534
raise ValueError (f"Unsupported formatting for graph nodes: '{ formatting } '. See docstring." )
@@ -403,7 +540,9 @@ def model_to_networkx(
403
540
stacklevel = 2 ,
404
541
)
405
542
model = pm .modelcontext (model )
406
- return ModelGraph (model ).make_networkx (var_names = var_names , formatting = formatting )
543
+ return ModelGraph (model ).make_networkx (
544
+ var_names = var_names , formatting = formatting , node_formatters = node_formatters
545
+ )
407
546
408
547
409
548
def model_to_graphviz (
@@ -414,6 +553,7 @@ def model_to_graphviz(
414
553
save : str | None = None ,
415
554
figsize : tuple [int , int ] | None = None ,
416
555
dpi : int = 300 ,
556
+ node_formatters : NodeTypeFormatterMapping | None = None ,
417
557
):
418
558
"""Produce a graphviz Digraph from a PyMC model.
419
559
@@ -441,6 +581,10 @@ def model_to_graphviz(
441
581
the size of the saved figure.
442
582
dpi : int, optional
443
583
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
584
+ node_formatters : dict, optional
585
+ A dictionary mapping node types to functions that return a dictionary of node attributes.
586
+ Check out graphviz documentation for more information on available
587
+ attributes. https://graphviz.org/docs/nodes/
444
588
445
589
Examples
446
590
--------
@@ -475,6 +619,16 @@ def model_to_graphviz(
475
619
476
620
# creates the file `schools.pdf`
477
621
model_to_graphviz(schools).render("schools")
622
+
623
+ Display Free Random Variables and Observed Random Variables nodes with custom formatting.
624
+
625
+ .. code-block:: python
626
+
627
+ node_formatters = {
628
+ "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
629
+ "Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
630
+ }
631
+ model_to_graphviz(schools, node_formatters=node_formatters)
478
632
"""
479
633
if "plain" not in formatting :
480
634
raise ValueError (f"Unsupported formatting for graph nodes: '{ formatting } '. See docstring." )
@@ -491,4 +645,5 @@ def model_to_graphviz(
491
645
save = save ,
492
646
figsize = figsize ,
493
647
dpi = dpi ,
648
+ node_formatters = node_formatters ,
494
649
)
0 commit comments