@@ -13,13 +13,13 @@ figwidth, figheight = mpl.rcParams["figure.figsize"]
1313```
1414
1515``` python
16- fig, ax = plt.subplots()
16+ fig, ax = plt.subplots(figsize = (figwidth, figheight) )
1717
1818vspace = 1
1919hspace = 2.1
2020kwargs = {
2121 " ha" : " center" ,
22- " va" : " center" ,
22+ " va" : " center" ,
2323 " fontsize" : " small" ,
2424}
2525fkwargs = kwargs | {
@@ -64,7 +64,7 @@ texts = {
6464 " mdn-compressed-samples" : (2 * hspace, - 7 * vspace, " MDN-compressed samples\n $\\ tilde\\ theta\\ in\\ mathbb{R} ^p\\ sim \\ tilde f\\ left(\\ theta\\ mid t(y)\\ right)$" , vkwargs),
6565 " estimator" : (0 , - 6 * vspace, " mixture density\n network $h:\\ mathbb{R} ^q\\ rightarrow\\ mathcal{F} $" , fkwargs),
6666 " estimate" : (0 , - 7 * vspace, " density estimate\n $\\ hat f\\ left(\\ theta\\ mid t(z)\\ right)\\ in \\ mathcal{F} $" , vkwargs),
67- " loss" : (- hspace / 2 , - 8 * vspace, " NLP loss" , fkwargs),
67+ " loss" : (- hspace / 2 , - 8 * vspace, " NLP loss" , fkwargs),
6868}
6969
7070elements = {}
@@ -80,12 +80,12 @@ if True:
8080 ax.set_xlim(xmin, xmax)
8181 ax.set_ylim(ymin, ymax)
8282 ax.set_aspect(" equal" )
83-
83+
8484 # Then adjust based on the actual extent of the box containing the text.
8585 fig.tight_layout()
8686 fig.draw_without_rendering()
8787 transform = ax.transData.inverted()
88- extents = np.asarray([transform.transform(element.get_bbox_patch().get_window_extent())
88+ extents = np.asarray([transform.transform(element.get_bbox_patch().get_window_extent())
8989 for element in elements.values()])
9090 xmin, ymin = extents.min(axis = 0 )[0 ]
9191 xmax, ymax = extents.max(axis = 0 )[1 ]
@@ -105,12 +105,12 @@ connections = [
105105 [(" params" , 6 ), (" simulator" , 12 )],
106106 [(" simulator" , 6 ), (" simulated_data" , 12 )],
107107 [
108- (" simulated_data" , 3 ),
108+ (" simulated_data" , 3 ),
109109 (get_anchor(elements[" compressor" ], 11 ).x, get_anchor(elements[" simulated_data" ], 3 ).y),
110110 (" compressor" , 11 ),
111111 ],
112112 [
113- (" compressor" , 7 ),
113+ (" compressor" , 7 ),
114114 (get_anchor(elements[" compressor" ], 7 ).x, get_anchor(elements[" simulated_summaries" ], 2.75 ).y),
115115 (" simulated_summaries" , 2.75 ),
116116 ],
@@ -119,7 +119,7 @@ connections = [
119119 (get_anchor(elements[" compressor" ], 6 ).x, get_anchor(elements[" simulated_summaries" ], 3.25 ).y),
120120 (get_anchor(elements[" compressor" ], 6 ).x, get_anchor(elements[" abc" ], 9 ).y),
121121 (" abc" , 9 ),
122-
122+
123123 ],
124124 [(" simulated_summaries" , 6 ), (" estimator" , 12 )],
125125 [(" estimator" , 6 ), (" estimate" , 12 )],
@@ -136,12 +136,12 @@ connections = [
136136
137137 # Observed.
138138 [
139- (" observed_data" , 9 ),
139+ (" observed_data" , 9 ),
140140 (get_anchor(elements[" compressor" ], 1 ).x, get_anchor(elements[" observed_data" ], 9 ).y),
141141 (" compressor" , 1 ),
142142 ],
143143 [
144- (" compressor" , 5 ),
144+ (" compressor" , 5 ),
145145 (get_anchor(elements[" compressor" ], 5 ).x, get_anchor(elements[" observed_summaries" ], 9 ).y),
146146 (" observed_summaries" , 9 ),
147147 ],
@@ -198,7 +198,7 @@ for (row, col), cell in celld.items():
198198 x = cell.get_x()
199199 y = cell.get_y() - (row + 1 ) * (height - original_height)
200200 phantom = mpl.patches.Rectangle(
201- (x, y), cell.get_width(), cell.get_height(), transform = ax.transAxes,
201+ (x, y), cell.get_width(), cell.get_height(), transform = ax.transAxes,
202202 facecolor = color, zorder = - 1 )
203203 ax.add_patch(phantom)
204204
0 commit comments