Skip to content

Commit 397dc91

Browse files
authored
Merge pull request #108 from bashtage/dirchelet-check
BUG: Add check for dirichelet parameters
2 parents f39a4ae + 4b04556 commit 397dc91

File tree

2 files changed

+52
-40
lines changed

2 files changed

+52
-40
lines changed

randomstate/randomstate.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4689,6 +4689,8 @@ cdef class RandomState:
46894689

46904690
k = len(alpha)
46914691
alpha_arr = <np.ndarray>np.PyArray_FROM_OTF(alpha, np.NPY_DOUBLE, np.NPY_ALIGNED)
4692+
if np.any(np.less_equal(alpha_arr, 0)):
4693+
raise ValueError('alpha <= 0')
46924694
alpha_data = <double*>np.PyArray_DATA(alpha_arr)
46934695

46944696
if size is None:

randomstate/tests/test_numpy_mt19937.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from __future__ import division, absolute_import, print_function
22

3-
import numpy as np
4-
from numpy.testing import (
5-
TestCase, run_module_suite, assert_, assert_raises, assert_equal,
6-
assert_warns, assert_no_warnings, assert_array_equal,
7-
assert_array_almost_equal)
83
import sys
94
import warnings
5+
6+
import numpy as np
7+
from numpy.testing import (
8+
TestCase, run_module_suite, assert_, assert_raises, assert_equal,
9+
assert_warns, assert_no_warnings, assert_array_equal,
10+
assert_array_almost_equal)
11+
1012
import randomstate as random
1113
from randomstate.compat import suppress_warnings
1214
from randomstate.prng.mt19937 import mt19937
1315

16+
1417
class TestSeed(TestCase):
1518
def test_scalar(self):
1619
s = mt19937.RandomState(0)
@@ -90,7 +93,7 @@ def setUp(self):
9093
self.seed = 1234567890
9194
self.prng = random.RandomState(self.seed)
9295
self.state = self.prng.get_state()
93-
self.legacy_state = self.prng.get_state(legacy=True) # Use legacy to get old NumPy state
96+
self.legacy_state = self.prng.get_state(legacy=True) # Use legacy to get old NumPy state
9497

9598
def test_basic(self):
9699
old = self.prng.tomaxint(16)
@@ -135,7 +138,6 @@ def test_negative_binomial(self):
135138

136139

137140
class TestRandint(TestCase):
138-
139141
rfunc = random.randint
140142

141143
# valid integer/boolean types
@@ -185,10 +187,10 @@ def test_full_range(self):
185187
def test_in_bounds_fuzz(self):
186188
# Don't use fixed seed
187189
mt19937.seed()
188-
190+
189191
for dt in self.itype[1:]:
190192
for ubnd in [4, 8, 16]:
191-
vals = self.rfunc(2, ubnd, size=2**16, dtype=dt)
193+
vals = self.rfunc(2, ubnd, size=2 ** 16, dtype=dt)
192194
assert_(vals.max() < ubnd)
193195
assert_(vals.min() >= 2)
194196

@@ -481,9 +483,9 @@ def test_beta(self):
481483
mt19937.seed(self.seed)
482484
actual = mt19937.beta(.1, .9, size=(3, 2))
483485
desired = np.array(
484-
[[1.45341850513746058e-02, 5.31297615662868145e-04],
485-
[1.85366619058432324e-06, 4.19214516800110563e-03],
486-
[1.58405155108498093e-04, 1.26252891949397652e-04]])
486+
[[1.45341850513746058e-02, 5.31297615662868145e-04],
487+
[1.85366619058432324e-06, 4.19214516800110563e-03],
488+
[1.58405155108498093e-04, 1.26252891949397652e-04]])
487489
assert_array_almost_equal(actual, desired, decimal=15)
488490

489491
def test_binomial(self):
@@ -513,6 +515,8 @@ def test_dirichlet(self):
513515
[[0.59266909280647828, 0.40733090719352177],
514516
[0.56974431743975207, 0.43025568256024799]]])
515517
assert_array_almost_equal(actual, desired, decimal=15)
518+
bad_alpha = np.array([5.4e-01, -1.0e-16])
519+
assert_raises(ValueError, mt19937.dirichlet, bad_alpha)
516520

517521
def test_dirichlet_size(self):
518522
# gh-3173
@@ -661,7 +665,7 @@ def test_multivariate_normal(self):
661665
cov = [[1, 0], [0, 1]]
662666
size = (3, 2)
663667
actual = mt19937.multivariate_normal(mean, cov, size)
664-
desired = np.array([[[1.463620246718631, 11.73759122771936 ],
668+
desired = np.array([[[1.463620246718631, 11.73759122771936],
665669
[1.622445133300628, 9.771356667546383]],
666670
[[2.154490787682787, 12.170324946056553],
667671
[1.719909438201865, 9.230548443648306]],
@@ -721,7 +725,7 @@ def test_noncentral_chisquare(self):
721725
def test_noncentral_f(self):
722726
mt19937.seed(self.seed)
723727
actual = mt19937.noncentral_f(dfnum=5, dfden=2, nonc=1,
724-
size=(3, 2))
728+
size=(3, 2))
725729
desired = np.array([[1.40598099674926669, 0.34207973179285761],
726730
[3.57715069265772545, 7.92632662577829805],
727731
[0.43741599463544162, 1.1774208752428319]])
@@ -743,9 +747,9 @@ def test_pareto(self):
743747
mt19937.seed(self.seed)
744748
actual = mt19937.pareto(a=.123456789, size=(3, 2))
745749
desired = np.array(
746-
[[2.46852460439034849e+03, 1.41286880810518346e+03],
747-
[5.28287797029485181e+07, 6.57720981047328785e+07],
748-
[1.40840323350391515e+02, 1.98390255135251704e+05]])
750+
[[2.46852460439034849e+03, 1.41286880810518346e+03],
751+
[5.28287797029485181e+07, 6.57720981047328785e+07],
752+
[1.40840323350391515e+02, 1.98390255135251704e+05]])
749753
# For some reason on 32-bit x86 Ubuntu 12.10 the [1, 0] entry in this
750754
# matrix differs by 24 nulps. Discussion:
751755
# http://mail.scipy.org/pipermail/numpy-discussion/2012-September/063801.html
@@ -766,9 +770,9 @@ def test_poisson_exceptions(self):
766770
lambig = np.iinfo('l').max
767771
lamneg = -1
768772
assert_raises(ValueError, mt19937.poisson, lamneg)
769-
assert_raises(ValueError, mt19937.poisson, [lamneg]*10)
773+
assert_raises(ValueError, mt19937.poisson, [lamneg] * 10)
770774
assert_raises(ValueError, mt19937.poisson, lambig)
771-
assert_raises(ValueError, mt19937.poisson, [lambig]*10)
775+
assert_raises(ValueError, mt19937.poisson, [lambig] * 10)
772776

773777
def test_power(self):
774778
mt19937.seed(self.seed)
@@ -857,8 +861,8 @@ def test_uniform_range_bounds(self):
857861

858862
func = mt19937.uniform
859863
assert_raises(OverflowError, func, -np.inf, 0)
860-
assert_raises(OverflowError, func, 0, np.inf)
861-
assert_raises(OverflowError, func, fmin, fmax)
864+
assert_raises(OverflowError, func, 0, np.inf)
865+
assert_raises(OverflowError, func, fmin, fmax)
862866
assert_raises(OverflowError, func, [-np.inf], [0])
863867
assert_raises(OverflowError, func, [0], [np.inf])
864868

@@ -887,7 +891,7 @@ def __int__(self):
887891

888892
throwing_int = np.array(1).view(ThrowingInteger)
889893
assert_raises(TypeError, mt19937.hypergeometric, throwing_int, 1, 1)
890-
894+
891895
def test_vonmises(self):
892896
mt19937.seed(self.seed)
893897
actual = mt19937.vonmises(mu=1.23, kappa=1.54, size=(3, 2))
@@ -1486,6 +1490,7 @@ def test_logseries(self):
14861490
assert_raises(ValueError, logseries, bad_p_one * 3)
14871491
assert_raises(ValueError, logseries, bad_p_two * 3)
14881492

1493+
14891494
class TestThread(TestCase):
14901495
# make sure each state produces the same sequence even in threads
14911496
def setUp(self):
@@ -1516,18 +1521,22 @@ def check_function(self, function, sz):
15161521
def test_normal(self):
15171522
def gen_random(state, out):
15181523
out[...] = state.normal(size=10000)
1524+
15191525
self.check_function(gen_random, sz=(10000,))
15201526

15211527
def test_exp(self):
15221528
def gen_random(state, out):
15231529
out[...] = state.exponential(scale=np.ones((100, 1000)))
1530+
15241531
self.check_function(gen_random, sz=(100, 1000))
15251532

15261533
def test_multinomial(self):
15271534
def gen_random(state, out):
1528-
out[...] = state.multinomial(10, [1/6.]*6, size=10000)
1535+
out[...] = state.multinomial(10, [1 / 6.] * 6, size=10000)
1536+
15291537
self.check_function(gen_random, sz=(10000, 6))
15301538

1539+
15311540
# See Issue #4263
15321541
class TestSingleEltArrayInput(TestCase):
15331542
def setUp(self):
@@ -1582,23 +1591,23 @@ def test_two_arg_funcs(self):
15821591
out = func(self.argOne, argTwo[0])
15831592
self.assertEqual(out.shape, self.tgtShape)
15841593

1585-
# TODO: Uncomment once randint can broadcast arguments
1586-
# def test_randint(self):
1587-
# itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
1588-
# np.int32, np.uint32, np.int64, np.uint64]
1589-
# func = mt19937.randint
1590-
# high = np.array([1])
1591-
# low = np.array([0])
1592-
#
1593-
# for dt in itype:
1594-
# out = func(low, high, dtype=dt)
1595-
# self.assert_equal(out.shape, self.tgtShape)
1596-
#
1597-
# out = func(low[0], high, dtype=dt)
1598-
# self.assert_equal(out.shape, self.tgtShape)
1599-
#
1600-
# out = func(low, high[0], dtype=dt)
1601-
# self.assert_equal(out.shape, self.tgtShape)
1594+
# TODO: Uncomment once randint can broadcast arguments
1595+
# def test_randint(self):
1596+
# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16,
1597+
# np.int32, np.uint32, np.int64, np.uint64]
1598+
# func = mt19937.randint
1599+
# high = np.array([1])
1600+
# low = np.array([0])
1601+
#
1602+
# for dt in itype:
1603+
# out = func(low, high, dtype=dt)
1604+
# self.assert_equal(out.shape, self.tgtShape)
1605+
#
1606+
# out = func(low[0], high, dtype=dt)
1607+
# self.assert_equal(out.shape, self.tgtShape)
1608+
#
1609+
# out = func(low, high[0], dtype=dt)
1610+
# self.assert_equal(out.shape, self.tgtShape)
16021611

16031612
def test_three_arg_funcs(self):
16041613
funcs = [mt19937.noncentral_f, mt19937.triangular,
@@ -1614,5 +1623,6 @@ def test_three_arg_funcs(self):
16141623
out = func(self.argOne, self.argTwo[0], self.argThree)
16151624
self.assertEqual(out.shape, self.tgtShape)
16161625

1626+
16171627
if __name__ == "__main__":
16181628
run_module_suite()

0 commit comments

Comments
 (0)