Skip to content

Commit 61ce2e2

Browse files
committed
more robustness classes and better class names
1 parent 1479f53 commit 61ce2e2

File tree

2 files changed

+82
-35
lines changed

2 files changed

+82
-35
lines changed

probscale/probscale.py

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,29 @@
77
NullLocator,
88
Formatter,
99
NullFormatter,
10-
FuncFormatter
10+
FuncFormatter,
1111
)
1212

1313

14+
def _mask_non_prop(a):
15+
"""
16+
Return a Numpy array where all values outside ]0, 1[ are
17+
replaced with NaNs. If all values are inside ]0, 1[, the original
18+
array is returned.
19+
"""
20+
mask = (a <= 0.0) | (a >= 1.0)
21+
if mask.any():
22+
return np.where(mask, np.nan, a)
23+
return a
24+
25+
26+
def _clip_non_positives(a):
27+
a = np.array(a, float)
28+
a[a <= 0.0] = 1e-300
29+
a[a >= 1.0] = 1 - 1e-300
30+
return a
31+
32+
1433
class _minimal_norm(object):
1534
_A = -(8 * (np.pi - 3.0) / (3.0 * np.pi * (np.pi - 4.0)))
1635

@@ -44,7 +63,6 @@ def ppf(cls, q):
4463
Wikipedia: https://goo.gl/Rtxjme
4564
4665
"""
47-
4866
return np.sqrt(2) * cls._approx_inv_erf(2*q - 1)
4967

5068
@classmethod
@@ -57,7 +75,7 @@ def cdf(cls, x):
5775
return 0.5 * (1 + cls._approx_erf(x/np.sqrt(2)))
5876

5977

60-
class ProbFormatter(Formatter):
78+
class _FormatterMixin(Formatter):
6179
@classmethod
6280
def _sig_figs(cls, x, n, expthresh=5, forceint=False):
6381
""" Formats a number with the correct number of sig figs.
@@ -128,49 +146,67 @@ def _sig_figs(cls, x, n, expthresh=5, forceint=False):
128146
return out
129147

130148
def __call__(self, x, pos=None):
131-
if x < 10:
149+
if x < (10 / self.factor):
132150
out = self._sig_figs(x, 1)
133-
elif x <= 99:
151+
elif x <= (99 / self.factor):
134152
out = self._sig_figs(x, 2)
135153
else:
136-
order = np.ceil(np.round(np.abs(np.log10(100 - x)), 6))
137-
out = self._sig_figs(x, order + 2)
154+
order = np.ceil(np.round(np.abs(np.log10(self.top - x)), 6))
155+
out = self._sig_figs(x, order + self.offset)
138156

139157
return '{}'.format(out)
140158

141159

142-
class ProbTransform(Transform):
160+
class PctFormatter(_FormatterMixin):
161+
factor = 1.0
162+
offset = 2
163+
top = 100
164+
165+
166+
class ProbFormatter(_FormatterMixin):
167+
factor = 100.0
168+
offset = 0
169+
top = 1
170+
171+
172+
class _ProbTransformMixin(Transform):
143173
input_dims = 1
144174
output_dims = 1
145175
is_separable = True
146176
has_inverse = True
147177

148-
def __init__(self, dist):
178+
def __init__(self, dist, as_pct=True, nonpos='mask'):
149179
Transform.__init__(self)
150180
self.dist = dist
181+
if as_pct:
182+
self.factor = 100.0
183+
else:
184+
self.factor = 1.0
151185

152-
def transform_non_affine(self, a):
153-
return self.dist.ppf(a / 100.)
186+
if nonpos == 'mask':
187+
self._handle_nonpos = _mask_non_positives
188+
elif nonpos == 'clip':
189+
self._handle_nonpos = _clip_non_positives
190+
else:
191+
raise ValueError("`nonpos` muse be either 'mask' or 'clip'")
154192

155-
def inverted(self):
156-
return InvertedProbTransform(self.dist)
157193

194+
class ProbTransform(_ProbTransformMixin):
195+
def transform_non_affine(self, prob):
196+
q = self.dist.ppf(prob / self.factor)
197+
return q
158198

159-
class InvertedProbTransform(Transform):
160-
input_dims = 1
161-
output_dims = 1
162-
is_separable = True
163-
has_inverse = True
199+
def inverted(self):
200+
return QuantileTransform(self.dist, as_pct=self.as_pct, nonpos=self.nonpos)
164201

165-
def __init__(self, dist):
166-
self.dist = dist
167-
Transform.__init__(self)
168202

169-
def transform_non_affine(self, a):
170-
return self.dist.cdf(a) * 100.
203+
class QuantileTransform(_ProbTransformMixin):
204+
def transform_non_affine(self, q):
205+
prob = self.dist.cdf(q) * self.factor
206+
return prob
171207

172208
def inverted(self):
173-
return ProbTransform(self.dist)
209+
return ProbTransform(self.dist, as_pct=self.as_pct, nonpos=self.nonpos)
174210

175211

176212
class ProbScale(ScaleBase):
@@ -199,13 +235,19 @@ class ProbScale(ScaleBase):
199235

200236
def __init__(self, axis, **kwargs):
201237
self.dist = kwargs.pop('dist', _minimal_norm)
202-
self._transform = ProbTransform(self.dist)
238+
self.as_pct = kwargs.pop('as_pct', True)
239+
self.nonpos = kwargs.pop('nonpos', 'mask')
240+
self._transform = ProbTransform(self.dist, as_pct=self.as_pct)
203241

204242
@classmethod
205-
def _get_probs(cls, nobs):
243+
def _get_probs(cls, nobs, as_pct):
206244
""" Returns the x-axis labels for a probability plot based on
207245
the number of observations (`nobs`).
208246
"""
247+
if as_pct:
248+
factor = 1.0
249+
else:
250+
factor = 100.0
209251

210252
order = int(np.floor(np.log10(nobs)))
211253
base_probs = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90])
@@ -219,19 +261,23 @@ def _get_probs(cls, nobs):
219261
lower_fringe = np.array([1])
220262
upper_fringe = np.array([9])
221263

222-
new_lower = lower_fringe/10**(n)
223-
new_upper = upper_fringe/10**(n) + axis_probs.max()
264+
new_lower = lower_fringe / 10**(n)
265+
new_upper = upper_fringe / 10**(n) + axis_probs.max()
224266
axis_probs = np.hstack([new_lower, axis_probs, new_upper])
225-
226-
return axis_probs
267+
locs = axis_probs / factor
268+
return locs
227269

228270
def set_default_locators_and_formatters(self, axis):
229271
"""
230272
Set the locators and formatters to specialized versions for
231273
log scaling.
232274
"""
233-
axis.set_major_locator(FixedLocator(self._get_probs(1e10)))
234-
axis.set_major_formatter(FuncFormatter(ProbFormatter()))
275+
276+
axis.set_major_locator(FixedLocator(self._get_probs(1e8, self.as_pct)))
277+
if self.as_pct:
278+
axis.set_major_formatter(FuncFormatter(PctFormatter()))
279+
else:
280+
axis.set_major_formatter(FuncFormatter(ProbFormatter()))
235281
axis.set_minor_locator(NullLocator())
236282
axis.set_minor_formatter(NullFormatter())
237283

probscale/tests/test_probscale/test_probscale.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_cdf(self):
7373

7474

7575
class Mixin_ProbFormatter_sig_figs(object):
76-
fmt = probscale.ProbFormatter()
76+
fmt = probscale.PctFormatter()
7777
def teardown(self):
7878
pass
7979

@@ -110,6 +110,7 @@ def test_forceint(self):
110110

111111
def test__call__(self):
112112
nt.assert_equal(self.fmt(0.0301), '0.03')
113+
nt.assert_equal(self.fmt(0.2), '0.2')
113114
nt.assert_equal(self.fmt(0.1), '0.1')
114115
nt.assert_equal(self.fmt(10), '10')
115116
nt.assert_equal(self.fmt(5), '5')
@@ -181,7 +182,7 @@ def setup(self):
181182
self.known_tras_na = -2.569150498
182183

183184

184-
class Test_InvertedProbTransform(Mixin_Transform):
185+
class Test_QuantileTransform(Mixin_Transform):
185186
def setup(self):
186-
self.trans = probscale.InvertedProbTransform(probscale._minimal_norm)
187+
self.trans = probscale.QuantileTransform(probscale._minimal_norm)
187188
self.known_tras_na = 69.1464492

0 commit comments

Comments
 (0)