Skip to content

Commit 6b43a9b

Browse files
committed
use the validation module in viz.py
1 parent ac8ed07 commit 6b43a9b

File tree

2 files changed

+8
-112
lines changed

2 files changed

+8
-112
lines changed

probscale/tests/test_probscale/test_viz.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -24,61 +24,6 @@ def setup_plot_data():
2424
return data
2525

2626

27-
class Test__check_ax_obj(object):
28-
@nt.raises(ValueError)
29-
def test_bad_value(self):
30-
viz._check_ax_obj('junk')
31-
32-
@cleanup
33-
def test_with_ax(self):
34-
fig, ax = plt.subplots()
35-
fig1, ax1 = viz._check_ax_obj(ax)
36-
nt.assert_true(isinstance(ax1, plt.Axes))
37-
nt.assert_true(isinstance(fig1, plt.Figure))
38-
nt.assert_true(ax1 is ax)
39-
nt.assert_true(fig1 is fig)
40-
41-
@cleanup
42-
def test_with_None(self):
43-
fig1, ax1 = viz._check_ax_obj(None)
44-
nt.assert_true(isinstance(ax1, plt.Axes))
45-
nt.assert_true(isinstance(fig1, plt.Figure))
46-
47-
48-
class Test__check_fit_arg(object):
49-
@nt.raises(ValueError)
50-
def test_bad_fitarg(self):
51-
viz._check_fit_arg('junk', 'fitprobs')
52-
53-
def test_x(self):
54-
nt.assert_equal('x', viz._check_fit_arg('x', 'fitprobs'))
55-
nt.assert_equal('x', viz._check_fit_arg('x', 'fitlogs'))
56-
57-
def test_y(self):
58-
nt.assert_equal('y', viz._check_fit_arg('y', 'fitprobs'))
59-
nt.assert_equal('y', viz._check_fit_arg('y', 'fitlogs'))
60-
61-
def test_both(self):
62-
nt.assert_equal('both', viz._check_fit_arg('both', 'fitprobs'))
63-
nt.assert_equal('both', viz._check_fit_arg('both', 'fitlogs'))
64-
65-
def test_None(self):
66-
nt.assert_true(viz._check_fit_arg(None, 'fitprobs') is None)
67-
nt.assert_true(viz._check_fit_arg(None, 'fitlogs') is None)
68-
69-
70-
class Test__check_ax_name(object):
71-
@nt.raises
72-
def test_bad_name(self):
73-
viz._check_fit_arg('junk', 'axname')
74-
75-
def test_x(self):
76-
nt.assert_equal('x', viz._check_fit_arg('x', 'axname'))
77-
78-
def test_y(self):
79-
nt.assert_equal('y', viz._check_fit_arg('y', 'axname'))
80-
81-
8227
class Test__fit_line(object):
8328
def setup(self):
8429
self.data = np.array([
@@ -271,17 +216,6 @@ def test_linlog(self):
271216
)
272217

273218

274-
class Test__check_ax_type(object):
275-
@nt.raises(ValueError)
276-
def test_bad_value(self):
277-
viz._check_ax_type("JUNK")
278-
279-
def test_upper(self):
280-
nt.assert_equal('pp', viz._check_ax_type('PP'))
281-
nt.assert_equal('qq', viz._check_ax_type('QQ'))
282-
nt.assert_equal('prob', viz._check_ax_type('ProB'))
283-
284-
285219
@image_comparison(baseline_images=['test_probplot_prob'], extensions=['png'])
286220
def test_probplot_prob():
287221
fig, ax = plt.subplots()

probscale/viz.py

Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,14 @@
33
from matplotlib import scale
44
from scipy import stats
55

6-
from .probscale import ProbScale, _minimal_norm
6+
from .probscale import ProbScale
7+
from .probscale import _minimal_norm
8+
from . import validate
79

810

911
scale.register_scale(ProbScale)
1012

1113

12-
def _check_ax_obj(ax):
13-
""" Checks if a value if an Axes. If None, a new one is created.
14-
15-
"""
16-
17-
if ax is None:
18-
fig, ax = pyplot.subplots()
19-
elif isinstance(ax, pyplot.Axes):
20-
fig = ax.figure
21-
else:
22-
msg = "`ax` must be a matplotlib Axes instance or None"
23-
raise ValueError(msg)
24-
25-
return fig, ax
26-
27-
28-
def _check_fit_arg(arg, argname):
29-
valid_args = ['x', 'y', 'both', None]
30-
if arg not in valid_args:
31-
msg = 'Invalid value for {} ({}). Must be on of {}.'
32-
raise ValueError(msg.format(argname, arg, valid_args))
33-
34-
return arg
35-
36-
37-
def _check_ax_name(axname, argname):
38-
valid_args = ['x', 'y']
39-
if axname.lower() not in valid_args:
40-
msg = 'Invalid value for {} ({}). Must be on of {}.'
41-
raise ValueError(msg.format(argname, arg, valid_args))
42-
43-
return axname.lower()
44-
45-
46-
def _check_ax_type(axtype):
47-
if axtype.lower() not in ['pp', 'qq', 'prob']:
48-
raise ValueError("invalid axtype: {}".format(axtype))
49-
return axtype.lower()
50-
51-
5214
def probplot(data, ax=None, axtype='prob', probax='x',
5315
otherscale='linear', xlabel=None, ylabel=None,
5416
bestfit=False, return_results=False,
@@ -99,15 +61,15 @@ def probplot(data, ax=None, axtype='prob', probax='x',
9961
"""
10062

10163
# check input values
102-
fig, ax = _check_ax_obj(ax)
103-
probax = _check_ax_name(probax, 'probax')
64+
fig, ax = validate.axes_object(ax)
65+
probax = validate.axis_name(probax, 'x')
10466

10567
# default values for plotting options
10668
scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
10769
line_kws = {} if line_kws is None else line_kws.copy()
10870

10971
# check axtype
110-
axtype = _check_ax_type(axtype)
72+
axtype = validate.axis_type(axtype)
11173

11274
# compute the plotting positions and sort the data
11375
qntls, datavals = stats.probplot(data, fit=False)
@@ -205,8 +167,8 @@ def _fit_line(x, y, xhat=None, fitprobs=None, fitlogs=None, dist=None):
205167
206168
"""
207169

208-
fitprobs = _check_fit_arg(fitprobs, "fitprobs")
209-
fitlogs = _check_fit_arg(fitlogs, "fitlogs")
170+
fitprobs = validate.fit_argument(fitprobs, "fitprobs")
171+
fitlogs = validate.fit_argument(fitlogs, "fitlogs")
210172

211173
# maybe set xhat to default values
212174
if xhat is None:

0 commit comments

Comments
 (0)