Skip to content

Commit d3fb85b

Browse files
committed
SYNC: Sync with NumPy changes
Sync randint testing with NumPy Resotre setup method to simplify testing
1 parent b234110 commit d3fb85b

File tree

1 file changed

+54
-52
lines changed

1 file changed

+54
-52
lines changed

randomstate/tests/test_numpy_mt19937.py

+54-52
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import warnings
55

66
import numpy as np
7+
import randomstate as random
78
from numpy.testing import (
89
run_module_suite, assert_, assert_raises, assert_equal,
910
assert_warns, assert_no_warnings, assert_array_equal,
1011
assert_array_almost_equal)
11-
12-
import randomstate as random
1312
from randomstate.compat import suppress_warnings
1413
from randomstate.prng.mt19937 import mt19937
1514

@@ -89,13 +88,11 @@ def test_size(self):
8988

9089

9190
class TestSetState(object):
92-
93-
@classmethod
94-
def setup_class(cls):
95-
cls.seed = 1234567890
96-
cls.prng = random.RandomState(cls.seed)
97-
cls.state = cls.prng.get_state()
98-
cls.legacy_state = cls.prng.get_state(legacy=True) # Use legacy to get old NumPy state
91+
def setup(self):
92+
self.seed = 1234567890
93+
self.prng = random.RandomState(self.seed)
94+
self.state = self.prng.get_state()
95+
self.legacy_state = self.prng.get_state(legacy=True) # Use legacy to get old NumPy state
9996

10097
def test_basic(self):
10198
old = self.prng.tomaxint(16)
@@ -105,7 +102,6 @@ def test_basic(self):
105102

106103
def test_gaussian_reset(self):
107104
# Make sure the cached every-other-Gaussian is reset.
108-
self.prng.set_state(self.state)
109105
old = self.prng.standard_normal(size=3)
110106
self.prng.set_state(self.state)
111107
new = self.prng.standard_normal(size=3)
@@ -126,7 +122,6 @@ def test_backwards_compatibility(self):
126122
# Make sure we can accept old state tuples that do not have the
127123
# cached Gaussian value.
128124
old_state = self.legacy_state[:-2]
129-
self.prng.set_state(self.legacy_state)
130125
x1 = self.prng.standard_normal(size=16)
131126
self.prng.set_state(old_state)
132127
x2 = self.prng.standard_normal(size=16)
@@ -160,6 +155,11 @@ def test_bounds_checking(self):
160155
assert_raises(ValueError, self.rfunc, ubnd, lbnd, dtype=dt)
161156
assert_raises(ValueError, self.rfunc, 1, 0, dtype=dt)
162157

158+
assert_raises(ValueError, self.rfunc, [lbnd - 1], ubnd, dtype=dt)
159+
assert_raises(ValueError, self.rfunc, [lbnd], [ubnd + 1], dtype=dt)
160+
assert_raises(ValueError, self.rfunc, [ubnd], [lbnd], dtype=dt)
161+
assert_raises(ValueError, self.rfunc, 1, [0], dtype=dt)
162+
163163
def test_bounds_checking_array(self):
164164
for dt in self.itype:
165165
lbnd = 0 if dt is bool else np.iinfo(dt).min
@@ -176,12 +176,15 @@ def test_rng_zero_and_extremes(self):
176176

177177
tgt = ubnd - 1
178178
assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)
179+
assert_equal(self.rfunc([tgt], tgt + 1, size=1000, dtype=dt), tgt)
179180

180181
tgt = lbnd
181182
assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)
183+
assert_equal(self.rfunc(tgt, [tgt + 1], size=1000, dtype=dt), tgt)
182184

183185
tgt = (lbnd + ubnd) // 2
184186
assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)
187+
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
185188

186189
def test_rng_zero_and_extremes_array(self):
187190
size = 1000
@@ -191,8 +194,8 @@ def test_rng_zero_and_extremes_array(self):
191194

192195
tgt = ubnd - 1
193196
assert_equal(self.rfunc([tgt], [tgt + 1], size=size, dtype=dt), tgt)
194-
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, dtype=dt), tgt)
195-
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, size=size, dtype=dt), tgt)
197+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, dtype=dt), tgt)
198+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, size=size, dtype=dt), tgt)
196199

197200
tgt = lbnd
198201
assert_equal(self.rfunc([tgt], [tgt + 1], size=size, dtype=dt), tgt)
@@ -226,12 +229,27 @@ def test_full_range_array(self):
226229
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
227230

228231
try:
229-
self.rfunc([lbnd], [ubnd], dtype=dt)
232+
self.rfunc([lbnd] * 2, [ubnd], dtype=dt)
230233
except Exception as e:
231234
raise AssertionError("No error should have been raised, "
232235
"but one was with the following "
233236
"message:\n\n%s" % str(e))
234237

238+
def test_in_bounds_fuzz(self):
239+
# Don't use fixed seed
240+
mt19937.seed()
241+
242+
for dt in self.itype[1:]:
243+
for ubnd in [4, 8, 16]:
244+
vals = self.rfunc(2, ubnd, size=2 ** 16, dtype=dt)
245+
assert_(vals.max() < ubnd)
246+
assert_(vals.min() >= 2)
247+
248+
vals = self.rfunc(0, 2, size=2 ** 16, dtype=bool)
249+
250+
assert_(vals.max() < 2)
251+
assert_(vals.min() >= 0)
252+
235253
def test_scalar_array_equiv(self):
236254
for dt in self.itype:
237255
lbnd = 0 if dt is bool else np.iinfo(dt).min
@@ -242,29 +260,13 @@ def test_scalar_array_equiv(self):
242260
scalar = self.rfunc(lbnd, ubnd, size=size, dtype=dt)
243261

244262
mt19937.seed(1234)
245-
scalar_array = self.rfunc(lbnd, ubnd, size=size, dtype=dt)
263+
scalar_array = self.rfunc([lbnd], [ubnd], size=size, dtype=dt)
246264

247265
mt19937.seed(1234)
248266
array = self.rfunc([lbnd] * size, [ubnd] * size, size=size, dtype=dt)
249267
assert_array_equal(scalar, scalar_array)
250268
assert_array_equal(scalar, array)
251269

252-
253-
def test_in_bounds_fuzz(self):
254-
# Don't use fixed seed
255-
mt19937.seed()
256-
257-
for dt in self.itype[1:]:
258-
for ubnd in [4, 8, 16]:
259-
vals = self.rfunc(2, ubnd, size=2 ** 16, dtype=dt)
260-
assert_(vals.max() < ubnd)
261-
assert_(vals.min() >= 2)
262-
263-
vals = self.rfunc(0, 2, size=2 ** 16, dtype=bool)
264-
265-
assert_(vals.max() < 2)
266-
assert_(vals.min() >= 0)
267-
268270
def test_repeatability(self):
269271
import hashlib
270272
# We use a md5 hash of generated sequences of 1000 samples
@@ -301,7 +303,6 @@ def test_repeatability(self):
301303
def test_repeatability_broadcasting(self):
302304

303305
for dt in self.itype:
304-
305306
lbnd = 0 if dt in (np.bool, bool, np.bool_) else np.iinfo(dt).min
306307
ubnd = 2 if dt in (np.bool, bool, np.bool_) else np.iinfo(dt).max + 1
307308

@@ -361,7 +362,6 @@ def test_respect_dtype_singleton(self):
361362
assert not hasattr(sample, 'dtype')
362363
assert_equal(type(sample), dt)
363364

364-
365365
def test_respect_dtype_array(self):
366366
# See gh-7203
367367
for dt in self.itype:
@@ -374,21 +374,21 @@ def test_respect_dtype_array(self):
374374
sample = self.rfunc([lbnd] * 2, [ubnd] * 2, dtype=dt)
375375
assert_equal(sample.dtype, dt)
376376

377-
def test_empty(self):
377+
def test_zero_size(self):
378+
# See gh-7203
378379
for dt in self.itype:
379380
sample = self.rfunc(0, 0, (3, 0, 4), dtype=dt)
380-
assert_equal(sample.shape, (3, 0, 4))
381-
assert_equal(self.rfunc(0, -10, size=0, dtype=dt).shape, (0,))
382-
assert_equal(sample.dtype, dt)
381+
assert sample.shape == (3, 0, 4)
382+
assert sample.dtype == dt
383+
assert self.rfunc(0, -10, 0, dtype=dt).shape == (0,)
383384

384385

385386
class TestRandomDist(object):
386387
# Make sure the random distribution returns the correct value for a
387388
# given seed
388389

389-
@classmethod
390-
def setup_class(cls):
391-
cls.seed = 1234567890
390+
def setup(self):
391+
self.seed = 1234567890
392392

393393
def test_rand(self):
394394
mt19937.seed(self.seed)
@@ -638,6 +638,11 @@ def test_dirichlet_size(self):
638638

639639
assert_raises(TypeError, mt19937.dirichlet, p, float(1))
640640

641+
def test_dirichlet_bad_alpha(self):
642+
# gh-2089
643+
alpha = np.array([5.4e-01, -1.0e-16])
644+
assert_raises(ValueError, mt19937.dirichlet, alpha)
645+
641646
def test_exponential(self):
642647
mt19937.seed(self.seed)
643648
actual = mt19937.exponential(1.1234, size=(3, 2))
@@ -1046,9 +1051,8 @@ def test_zipf(self):
10461051
class TestBroadcast(object):
10471052
# tests that functions that broadcast behave
10481053
# correctly when presented with non-scalar arguments
1049-
@classmethod
1050-
def setup_class(cls):
1051-
cls.seed = 123456789
1054+
def setup(self):
1055+
self.seed = 123456789
10521056

10531057
def set_seed(self):
10541058
random.seed(self.seed)
@@ -1603,9 +1607,8 @@ def test_logseries(self):
16031607
class TestThread(object):
16041608
# make sure each state produces the same sequence even in threads
16051609

1606-
@classmethod
1607-
def setup_class(cls):
1608-
cls.seeds = range(4)
1610+
def setup(self):
1611+
self.seeds = range(4)
16091612

16101613
def check_function(self, function, sz):
16111614
from threading import Thread
@@ -1650,12 +1653,11 @@ def gen_random(state, out):
16501653

16511654
# See Issue #4263
16521655
class TestSingleEltArrayInput(object):
1653-
@classmethod
1654-
def setup_class(cls):
1655-
cls.argOne = np.array([2])
1656-
cls.argTwo = np.array([3])
1657-
cls.argThree = np.array([4])
1658-
cls.tgtShape = (1,)
1656+
def setup(self):
1657+
self.argOne = np.array([2])
1658+
self.argTwo = np.array([3])
1659+
self.argThree = np.array([4])
1660+
self.tgtShape = (1,)
16591661

16601662
def test_one_arg_funcs(self):
16611663
funcs = (mt19937.exponential, mt19937.standard_gamma,

0 commit comments

Comments
 (0)