Skip to content

Commit f87ac0f

Browse files
Pickle support. (And Dynamic Inheritance Concept)
1 parent e9e90a9 commit f87ac0f

File tree

4 files changed

+203
-17
lines changed

4 files changed

+203
-17
lines changed

matplotview/_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import functools
2+
3+
def _fix_super_reference(function):
4+
"""
5+
Private utility decorator. Allows a function to be transferred to another
6+
class by dynamically updating the local __class__ attribute of the function
7+
when called. This allows for use of zero argument super in all methods
8+
9+
Parameters:
10+
-----------
11+
function
12+
The function to wrap, to allow for the dynamic
13+
"""
14+
@functools.wraps(function)
15+
def run_function(self, *args, **kwargs):
16+
try:
17+
cls_idx = function.__code__.co_freevars.index('__class__')
18+
old_value = function.__closure__[cls_idx].cell_contents
19+
function.__closure__[cls_idx].cell_contents = type(self)
20+
res = function(self, *args, **kwargs)
21+
function.__closure__[cls_idx].cell_contents = old_value
22+
return res
23+
except (AttributeError, ValueError):
24+
return function(self, *args, **kwargs)
25+
26+
return run_function
27+
28+
29+
class _WrappingType(type):
30+
def __new__(mcs, *args, **kwargs):
31+
res = super().__new__(mcs, *args, **kwargs)
32+
33+
res.__base_wrapping__ = getattr(
34+
res, "__base_wrapping__", res.__bases__[0]
35+
)
36+
res.__instances__ = getattr(res, "__instances__", {})
37+
38+
return res
39+
40+
def __getitem__(cls, the_type):
41+
if(cls.__instances__ is None):
42+
raise TypeError("Already instantiated wrapper!")
43+
44+
if(the_type == cls.__base_wrapping__):
45+
return cls
46+
47+
if(issubclass(super().__class__, _WrappingType)):
48+
return cls._gen_type(super()[the_type])
49+
50+
if(not issubclass(the_type, cls.__base_wrapping__)):
51+
raise TypeError(
52+
f"The extension type {the_type} must be a subclass of "
53+
f"{cls.__base_wrapping__}"
54+
)
55+
56+
return cls._gen_type(the_type)
57+
58+
def _gen_type(cls, the_type):
59+
if(the_type not in cls.__instances__):
60+
cls.__instances__[the_type] = _WrappingType(
61+
f"{cls.__name__}[{the_type.__name__}]",
62+
(the_type,),
63+
{"__instances__": None}
64+
)
65+
cls._copy_attrs_to(cls.__instances__[the_type])
66+
67+
return cls.__instances__[the_type]
68+
69+
def _copy_attrs_to(cls, other):
70+
dont_copy = {"__dict__", "__weakref__", "__instances__"}
71+
72+
for k, v in cls.__dict__.items():
73+
if(k not in dont_copy):
74+
setattr(
75+
other,
76+
k,
77+
_fix_super_reference(v) if(hasattr(v, "__code__")) else v
78+
)
79+
80+
other.__instances__ = None
81+
82+
def __iter__(cls):
83+
return NotImplemented

matplotview/_view_axes.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import itertools
23
from typing import Type, List, Optional, Callable, Any
34
from matplotlib.axes import Axes
@@ -77,7 +78,12 @@ def do_3d_projection(self) -> float:
7778

7879
return res
7980

81+
def _view_from_pickle(builder, args):
82+
res = builder(*args)
83+
res.__class__ = view_wrapper(type(res))
84+
return res
8085

86+
@functools.lru_cache(None)
8187
def view_wrapper(axes_class: Type[Axes]) -> Type[Axes]:
8288
"""
8389
Construct a ViewAxes, which subclasses, or wraps a specific Axes subclass.
@@ -95,15 +101,12 @@ def view_wrapper(axes_class: Type[Axes]) -> Type[Axes]:
95101
The view axes wrapper for a given axes class, capable of display
96102
other axes contents...
97103
"""
98-
99104
@docstring.interpd
100-
class ViewAxesImpl(axes_class):
105+
class View(axes_class):
101106
"""
102107
An axes which automatically displays elements of another axes. Does not
103108
require Artists to be plotted twice.
104109
"""
105-
__module__ = axes_class.__module__
106-
107110
def __init__(
108111
self,
109112
axes_to_view: Axes,
@@ -113,7 +116,7 @@ def __init__(
113116
filter_function: Optional[Callable[[Artist], bool]] = None,
114117
**kwargs
115118
):
116-
f"""
119+
"""
117120
Construct a new view axes.
118121
119122
Parameters
@@ -170,10 +173,10 @@ def _init_vars(
170173
):
171174
if(render_depth < 1):
172175
raise ValueError(f"Render depth of {render_depth} is invalid.")
173-
if(filter_function is None):
174-
filter_function = lambda a: True
175-
if(not callable(filter_function)):
176-
raise ValueError(f"The filter function must be a callable!")
176+
if(filter_function is not None and not callable(filter_function)):
177+
raise ValueError(
178+
f"The filter function must be a callable or None!"
179+
)
177180

178181
self.__view_axes = axes_to_view
179182
# The current render depth is stored in the figure, so the number
@@ -212,13 +215,14 @@ def get_children(self) -> List[Artist]:
212215
for a in itertools.chain(
213216
self.__view_axes._children,
214217
self.__view_axes.child_axes
215-
) if(self.__filter_function(a))
218+
) if(self.__filter_function is None
219+
or self.__filter_function(a))
216220
])
217221

218222
return init_list
219223
else:
220224
return super().get_children()
221-
225+
222226
def draw(self, renderer: RendererBase = None):
223227
# It is possible to have two axes which are views of each other
224228
# therefore we track the number of recursions and stop drawing
@@ -260,6 +264,26 @@ def set_linescaling(self, value: bool):
260264
"""
261265
self.__scale_lines = value
262266

267+
def __reduce__(self):
268+
builder, args = super().__reduce__()[:2]
269+
270+
if(type(self) in args):
271+
builder = super().__new__
272+
args = (type(self).__bases__[0],)
273+
274+
return (
275+
_view_from_pickle,
276+
(builder, args),
277+
self.__getstate__()
278+
)
279+
280+
def __getstate__(self):
281+
state = super().__getstate__()
282+
state["__renderer"] = None
283+
# We don't support pickling the filter...
284+
state["__filter_function"] = None
285+
return state
286+
263287
@classmethod
264288
def from_axes(
265289
cls,
@@ -276,10 +300,7 @@ def from_axes(
276300
)
277301
return axes
278302

279-
new_name = f"{ViewAxesImpl.__name__}[{axes_class.__name__}]"
280-
ViewAxesImpl.__name__ = ViewAxesImpl.__qualname__ = new_name
281-
282-
return ViewAxesImpl
283-
303+
View.__name__ = f"{View.__name__}[{axes_class.__name__}]"
304+
View.__qualname__ = f"{View.__qualname__}[{axes_class.__name__}]"
284305

285-
ViewAxes = view_wrapper(Axes)
306+
return View

matplotview/tests/test_view_obj.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import matplotlib.pyplot as plt
2+
import pickle
3+
from matplotview import view, view_wrapper, inset_zoom_axes
4+
import numpy as np
5+
6+
def to_image(figure):
7+
figure.canvas.draw()
8+
img = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
9+
return img.reshape(figure.canvas.get_width_height()[::-1] + (3,))
10+
11+
12+
def test_obj_comparison():
13+
from matplotlib.axes import Subplot, Axes
14+
15+
view_class1 = view_wrapper(Subplot)
16+
view_class2 = view_wrapper(Subplot)
17+
view_class3 = view_wrapper(Axes)
18+
19+
assert view_class1 is view_class2
20+
assert view_class1 == view_class2
21+
assert view_class2 != view_class3
22+
23+
24+
def test_subplot_view_pickle():
25+
np.random.seed(1)
26+
im_data = np.random.rand(30, 30)
27+
28+
# Test case...
29+
fig_test, (ax_test1, ax_test2) = plt.subplots(1, 2)
30+
31+
ax_test1.plot([i for i in range(10)], "r")
32+
ax_test1.add_patch(plt.Circle((3, 3), 1, ec="black", fc="blue"))
33+
ax_test1.text(10, 10, "Hello World!", size=14)
34+
ax_test1.imshow(im_data, origin="lower", cmap="Blues", alpha=0.5,
35+
interpolation="nearest")
36+
ax_test2 = view(ax_test2, ax_test1)
37+
ax_test2.set_aspect(ax_test1.get_aspect())
38+
ax_test2.set_xlim(ax_test1.get_xlim())
39+
ax_test2.set_ylim(ax_test1.get_ylim())
40+
41+
img_expected = to_image(fig_test)
42+
43+
saved_fig = pickle.dumps(fig_test)
44+
plt.clf()
45+
46+
fig_test = pickle.loads(saved_fig)
47+
img_result = to_image(fig_test)
48+
49+
assert np.all(img_expected == img_result)
50+
51+
52+
def test_zoom_plot_pickle():
53+
np.random.seed(1)
54+
plt.clf()
55+
im_data = np.random.rand(30, 30)
56+
57+
# Test Case...
58+
fig_test = plt.gcf()
59+
ax_test = fig_test.gca()
60+
ax_test.plot([i for i in range(10)], "r")
61+
ax_test.add_patch(plt.Circle((3, 3), 1, ec="black", fc="blue"))
62+
ax_test.imshow(im_data, origin="lower", cmap="Blues", alpha=0.5,
63+
interpolation="nearest")
64+
axins_test = inset_zoom_axes(ax_test, [0.5, 0.5, 0.48, 0.48])
65+
axins_test.set_linescaling(False)
66+
axins_test.set_xlim(1, 5)
67+
axins_test.set_ylim(1, 5)
68+
ax_test.indicate_inset_zoom(axins_test, edgecolor="black")
69+
70+
fig_test.savefig("before.png")
71+
img_expected = to_image(fig_test)
72+
73+
saved_fig = pickle.dumps(fig_test)
74+
plt.clf()
75+
76+
fig_test = pickle.loads(saved_fig)
77+
fig_test.savefig("after.png")
78+
img_result = to_image(fig_test)
79+
80+
assert np.all(img_expected == img_result)

matplotview/tests/test_inset_zoom.py renamed to matplotview/tests/test_view_rendering.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def test_3d_view(fig_test, fig_ref):
136136
ax2_ref.set_ylim(-10, 10)
137137
ax2_ref.set_zlim(-2, 2)
138138

139+
139140
@check_figures_equal()
140141
def test_polar_view(fig_test, fig_ref):
141142
r = np.arange(0, 2, 0.01)
@@ -156,6 +157,7 @@ def test_polar_view(fig_test, fig_ref):
156157
ax_r2.plot(theta, r)
157158
ax_r2.set_rmax(1)
158159

160+
159161
@check_figures_equal()
160162
def test_map_projection_view(fig_test, fig_ref):
161163
x = np.linspace(-2.5, 2.5, 20)

0 commit comments

Comments
 (0)