Skip to content

Commit 1567a8e

Browse files
Sheppard, KevinSheppard, Kevin
authored andcommitted
BUG: Add check for dirichelet parameters
Verify parameters are > 0 as required Port fo NumPy
1 parent a8ac5f4 commit 1567a8e

File tree

2 files changed

+52
-40
lines changed

2 files changed

+52
-40
lines changed

randomstate/randomstate.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4689,6 +4689,8 @@ cdef class RandomState:
46894689

46904690
k = len(alpha)
46914691
alpha_arr = <np.ndarray>np.PyArray_FROM_OTF(alpha, np.NPY_DOUBLE, np.NPY_ALIGNED)
4692+
if np.any(np.less_equal(alpha_arr, 0)):
4693+
raise ValueError('alpha <= 0')
46924694
alpha_data = <double*>np.PyArray_DATA(alpha_arr)
46934695

46944696
if size is None:

randomstate/tests/test_numpy_mt19937.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from __future__ import division, absolute_import, print_function
22

3-
import numpy as np
4-
from numpy.testing import (
5-
TestCase, run_module_suite, assert_, assert_raises, assert_equal,
6-
assert_warns, assert_no_warnings, assert_array_equal,
7-
assert_array_almost_equal)
83
import sys
94
import warnings
5+
6+
import numpy as np
7+
from numpy.testing import (
8+
TestCase, run_module_suite, assert_, assert_raises, assert_equal,
9+
assert_warns, assert_no_warnings, assert_array_equal,
10+
assert_array_almost_equal)
11+
1012
import randomstate as random
1113
from randomstate.compat import suppress_warnings
1214
from randomstate.prng.mt19937 import mt19937
1315

16+
1417
class TestSeed(TestCase):
1518
def test_scalar(self):
1619
s = mt19937.RandomState(0)
@@ -90,7 +93,7 @@ def setUp(self):
9093
self.seed = 1234567890
9194
self.prng = random.RandomState(self.seed)
9295
self.state = self.prng.get_state()
93-
self.legacy_state = self.prng.get_state(legacy=True) # Use legacy to get old NumPy state
96+
self.legacy_state = self.prng.get_state(legacy=True) # Use legacy to get old NumPy state
9497

9598
def test_basic(self):
9699
old = self.prng.tomaxint(16)
@@ -135,7 +138,6 @@ def test_negative_binomial(self):
135138

136139

137140
class TestRandint(TestCase):
138-
139141
rfunc = random.randint
140142

141143
# valid integer/boolean types
@@ -185,10 +187,10 @@ def test_full_range(self):
185187
def test_in_bounds_fuzz(self):
186188
# Don't use fixed seed
187189
mt19937.seed()
188-
190+
189191
for dt in self.itype[1:]:
190192
for ubnd in [4, 8, 16]:
191-
vals = self.rfunc(2, ubnd, size=2**16, dtype=dt)
193+
vals = self.rfunc(2, ubnd, size=2 ** 16, dtype=dt)
192194
assert_(vals.max() < ubnd)
193195
assert_(vals.min() >= 2)
194196

@@ -480,9 +482,9 @@ def test_beta(self):
480482
mt19937.seed(self.seed)
481483
actual = mt19937.beta(.1, .9, size=(3, 2))
482484
desired = np.array(
483-
[[1.45341850513746058e-02, 5.31297615662868145e-04],
484-
[1.85366619058432324e-06, 4.19214516800110563e-03],
485-
[1.58405155108498093e-04, 1.26252891949397652e-04]])
485+
[[1.45341850513746058e-02, 5.31297615662868145e-04],
486+
[1.85366619058432324e-06, 4.19214516800110563e-03],
487+
[1.58405155108498093e-04, 1.26252891949397652e-04]])
486488
assert_array_almost_equal(actual, desired, decimal=15)
487489

488490
def test_binomial(self):
@@ -512,6 +514,8 @@ def test_dirichlet(self):
512514
[[0.59266909280647828, 0.40733090719352177],
513515
[0.56974431743975207, 0.43025568256024799]]])
514516
assert_array_almost_equal(actual, desired, decimal=15)
517+
bad_alpha = np.array([5.4e-01, -1.0e-16])
518+
assert_raises(ValueError, mt19937.dirichlet, bad_alpha)
515519

516520
def test_dirichlet_size(self):
517521
# gh-3173
@@ -660,7 +664,7 @@ def test_multivariate_normal(self):
660664
cov = [[1, 0], [0, 1]]
661665
size = (3, 2)
662666
actual = mt19937.multivariate_normal(mean, cov, size)
663-
desired = np.array([[[1.463620246718631, 11.73759122771936 ],
667+
desired = np.array([[[1.463620246718631, 11.73759122771936],
664668
[1.622445133300628, 9.771356667546383]],
665669
[[2.154490787682787, 12.170324946056553],
666670
[1.719909438201865, 9.230548443648306]],
@@ -720,7 +724,7 @@ def test_noncentral_chisquare(self):
720724
def test_noncentral_f(self):
721725
mt19937.seed(self.seed)
722726
actual = mt19937.noncentral_f(dfnum=5, dfden=2, nonc=1,
723-
size=(3, 2))
727+
size=(3, 2))
724728
desired = np.array([[1.40598099674926669, 0.34207973179285761],
725729
[3.57715069265772545, 7.92632662577829805],
726730
[0.43741599463544162, 1.1774208752428319]])
@@ -742,9 +746,9 @@ def test_pareto(self):
742746
mt19937.seed(self.seed)
743747
actual = mt19937.pareto(a=.123456789, size=(3, 2))
744748
desired = np.array(
745-
[[2.46852460439034849e+03, 1.41286880810518346e+03],
746-
[5.28287797029485181e+07, 6.57720981047328785e+07],
747-
[1.40840323350391515e+02, 1.98390255135251704e+05]])
749+
[[2.46852460439034849e+03, 1.41286880810518346e+03],
750+
[5.28287797029485181e+07, 6.57720981047328785e+07],
751+
[1.40840323350391515e+02, 1.98390255135251704e+05]])
748752
# For some reason on 32-bit x86 Ubuntu 12.10 the [1, 0] entry in this
749753
# matrix differs by 24 nulps. Discussion:
750754
# http://mail.scipy.org/pipermail/numpy-discussion/2012-September/063801.html
@@ -765,9 +769,9 @@ def test_poisson_exceptions(self):
765769
lambig = np.iinfo('l').max
766770
lamneg = -1
767771
assert_raises(ValueError, mt19937.poisson, lamneg)
768-
assert_raises(ValueError, mt19937.poisson, [lamneg]*10)
772+
assert_raises(ValueError, mt19937.poisson, [lamneg] * 10)
769773
assert_raises(ValueError, mt19937.poisson, lambig)
770-
assert_raises(ValueError, mt19937.poisson, [lambig]*10)
774+
assert_raises(ValueError, mt19937.poisson, [lambig] * 10)
771775

772776
def test_power(self):
773777
mt19937.seed(self.seed)
@@ -856,8 +860,8 @@ def test_uniform_range_bounds(self):
856860

857861
func = mt19937.uniform
858862
assert_raises(OverflowError, func, -np.inf, 0)
859-
assert_raises(OverflowError, func, 0, np.inf)
860-
assert_raises(OverflowError, func, fmin, fmax)
863+
assert_raises(OverflowError, func, 0, np.inf)
864+
assert_raises(OverflowError, func, fmin, fmax)
861865
assert_raises(OverflowError, func, [-np.inf], [0])
862866
assert_raises(OverflowError, func, [0], [np.inf])
863867

@@ -886,7 +890,7 @@ def __int__(self):
886890

887891
throwing_int = np.array(1).view(ThrowingInteger)
888892
assert_raises(TypeError, mt19937.hypergeometric, throwing_int, 1, 1)
889-
893+
890894
def test_vonmises(self):
891895
mt19937.seed(self.seed)
892896
actual = mt19937.vonmises(mu=1.23, kappa=1.54, size=(3, 2))
@@ -1485,6 +1489,7 @@ def test_logseries(self):
14851489
assert_raises(ValueError, logseries, bad_p_one * 3)
14861490
assert_raises(ValueError, logseries, bad_p_two * 3)
14871491

1492+
14881493
class TestThread(TestCase):
14891494
# make sure each state produces the same sequence even in threads
14901495
def setUp(self):
@@ -1515,18 +1520,22 @@ def check_function(self, function, sz):
15151520
def test_normal(self):
15161521
def gen_random(state, out):
15171522
out[...] = state.normal(size=10000)
1523+
15181524
self.check_function(gen_random, sz=(10000,))
15191525

15201526
def test_exp(self):
15211527
def gen_random(state, out):
15221528
out[...] = state.exponential(scale=np.ones((100, 1000)))
1529+
15231530
self.check_function(gen_random, sz=(100, 1000))
15241531

15251532
def test_multinomial(self):
15261533
def gen_random(state, out):
1527-
out[...] = state.multinomial(10, [1/6.]*6, size=10000)
1534+
out[...] = state.multinomial(10, [1 / 6.] * 6, size=10000)
1535+
15281536
self.check_function(gen_random, sz=(10000, 6))
15291537

1538+
15301539
# See Issue #4263
15311540
class TestSingleEltArrayInput(TestCase):
15321541
def setUp(self):
@@ -1581,23 +1590,23 @@ def test_two_arg_funcs(self):
15811590
out = func(self.argOne, argTwo[0])
15821591
self.assertEqual(out.shape, self.tgtShape)
15831592

1584-
# TODO: Uncomment once randint can broadcast arguments
1585-
# def test_randint(self):
1586-
# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16,
1587-
# np.int32, np.uint32, np.int64, np.uint64]
1588-
# func = mt19937.randint
1589-
# high = np.array([1])
1590-
# low = np.array([0])
1591-
#
1592-
# for dt in itype:
1593-
# out = func(low, high, dtype=dt)
1594-
# self.assert_equal(out.shape, self.tgtShape)
1595-
#
1596-
# out = func(low[0], high, dtype=dt)
1597-
# self.assert_equal(out.shape, self.tgtShape)
1598-
#
1599-
# out = func(low, high[0], dtype=dt)
1600-
# self.assert_equal(out.shape, self.tgtShape)
1593+
# TODO: Uncomment once randint can broadcast arguments
1594+
# def test_randint(self):
1595+
# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16,
1596+
# np.int32, np.uint32, np.int64, np.uint64]
1597+
# func = mt19937.randint
1598+
# high = np.array([1])
1599+
# low = np.array([0])
1600+
#
1601+
# for dt in itype:
1602+
# out = func(low, high, dtype=dt)
1603+
# self.assert_equal(out.shape, self.tgtShape)
1604+
#
1605+
# out = func(low[0], high, dtype=dt)
1606+
# self.assert_equal(out.shape, self.tgtShape)
1607+
#
1608+
# out = func(low, high[0], dtype=dt)
1609+
# self.assert_equal(out.shape, self.tgtShape)
16011610

16021611
def test_three_arg_funcs(self):
16031612
funcs = [mt19937.noncentral_f, mt19937.triangular,
@@ -1613,5 +1622,6 @@ def test_three_arg_funcs(self):
16131622
out = func(self.argOne, self.argTwo[0], self.argThree)
16141623
self.assertEqual(out.shape, self.tgtShape)
16151624

1625+
16161626
if __name__ == "__main__":
16171627
run_module_suite()

0 commit comments

Comments
 (0)