44import warnings
55
66import numpy as np
7+ import randomstate as random
78from numpy .testing import (
89 run_module_suite , assert_ , assert_raises , assert_equal ,
910 assert_warns , assert_no_warnings , assert_array_equal ,
1011 assert_array_almost_equal )
11-
12- import randomstate as random
1312from randomstate .compat import suppress_warnings
1413from randomstate .prng .mt19937 import mt19937
1514
@@ -89,13 +88,11 @@ def test_size(self):
8988
9089
9190class TestSetState (object ):
92-
93- @classmethod
94- def setup_class (cls ):
95- cls .seed = 1234567890
96- cls .prng = random .RandomState (cls .seed )
97- cls .state = cls .prng .get_state ()
98- cls .legacy_state = cls .prng .get_state (legacy = True ) # Use legacy to get old NumPy state
91+ def setup (self ):
92+ self .seed = 1234567890
93+ self .prng = random .RandomState (self .seed )
94+ self .state = self .prng .get_state ()
95+ self .legacy_state = self .prng .get_state (legacy = True ) # Use legacy to get old NumPy state
9996
10097 def test_basic (self ):
10198 old = self .prng .tomaxint (16 )
@@ -105,7 +102,6 @@ def test_basic(self):
105102
106103 def test_gaussian_reset (self ):
107104 # Make sure the cached every-other-Gaussian is reset.
108- self .prng .set_state (self .state )
109105 old = self .prng .standard_normal (size = 3 )
110106 self .prng .set_state (self .state )
111107 new = self .prng .standard_normal (size = 3 )
@@ -126,7 +122,6 @@ def test_backwards_compatibility(self):
126122 # Make sure we can accept old state tuples that do not have the
127123 # cached Gaussian value.
128124 old_state = self .legacy_state [:- 2 ]
129- self .prng .set_state (self .legacy_state )
130125 x1 = self .prng .standard_normal (size = 16 )
131126 self .prng .set_state (old_state )
132127 x2 = self .prng .standard_normal (size = 16 )
@@ -160,6 +155,11 @@ def test_bounds_checking(self):
160155 assert_raises (ValueError , self .rfunc , ubnd , lbnd , dtype = dt )
161156 assert_raises (ValueError , self .rfunc , 1 , 0 , dtype = dt )
162157
158+ assert_raises (ValueError , self .rfunc , [lbnd - 1 ], ubnd , dtype = dt )
159+ assert_raises (ValueError , self .rfunc , [lbnd ], [ubnd + 1 ], dtype = dt )
160+ assert_raises (ValueError , self .rfunc , [ubnd ], [lbnd ], dtype = dt )
161+ assert_raises (ValueError , self .rfunc , 1 , [0 ], dtype = dt )
162+
163163 def test_bounds_checking_array (self ):
164164 for dt in self .itype :
165165 lbnd = 0 if dt is bool else np .iinfo (dt ).min
@@ -176,12 +176,15 @@ def test_rng_zero_and_extremes(self):
176176
177177 tgt = ubnd - 1
178178 assert_equal (self .rfunc (tgt , tgt + 1 , size = 1000 , dtype = dt ), tgt )
179+ assert_equal (self .rfunc ([tgt ], tgt + 1 , size = 1000 , dtype = dt ), tgt )
179180
180181 tgt = lbnd
181182 assert_equal (self .rfunc (tgt , tgt + 1 , size = 1000 , dtype = dt ), tgt )
183+ assert_equal (self .rfunc (tgt , [tgt + 1 ], size = 1000 , dtype = dt ), tgt )
182184
183185 tgt = (lbnd + ubnd ) // 2
184186 assert_equal (self .rfunc (tgt , tgt + 1 , size = 1000 , dtype = dt ), tgt )
187+ assert_equal (self .rfunc ([tgt ], [tgt + 1 ], size = 1000 , dtype = dt ), tgt )
185188
186189 def test_rng_zero_and_extremes_array (self ):
187190 size = 1000
@@ -191,8 +194,8 @@ def test_rng_zero_and_extremes_array(self):
191194
192195 tgt = ubnd - 1
193196 assert_equal (self .rfunc ([tgt ], [tgt + 1 ], size = size , dtype = dt ), tgt )
194- assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , dtype = dt ), tgt )
195- assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , size = size , dtype = dt ), tgt )
197+ assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , dtype = dt ), tgt )
198+ assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , size = size , dtype = dt ), tgt )
196199
197200 tgt = lbnd
198201 assert_equal (self .rfunc ([tgt ], [tgt + 1 ], size = size , dtype = dt ), tgt )
@@ -226,12 +229,27 @@ def test_full_range_array(self):
226229 ubnd = 2 if dt is bool else np .iinfo (dt ).max + 1
227230
228231 try :
229- self .rfunc ([lbnd ], [ubnd ], dtype = dt )
232+ self .rfunc ([lbnd ] * 2 , [ubnd ], dtype = dt )
230233 except Exception as e :
231234 raise AssertionError ("No error should have been raised, "
232235 "but one was with the following "
233236 "message:\n \n %s" % str (e ))
234237
238+ def test_in_bounds_fuzz (self ):
239+ # Don't use fixed seed
240+ mt19937 .seed ()
241+
242+ for dt in self .itype [1 :]:
243+ for ubnd in [4 , 8 , 16 ]:
244+ vals = self .rfunc (2 , ubnd , size = 2 ** 16 , dtype = dt )
245+ assert_ (vals .max () < ubnd )
246+ assert_ (vals .min () >= 2 )
247+
248+ vals = self .rfunc (0 , 2 , size = 2 ** 16 , dtype = bool )
249+
250+ assert_ (vals .max () < 2 )
251+ assert_ (vals .min () >= 0 )
252+
235253 def test_scalar_array_equiv (self ):
236254 for dt in self .itype :
237255 lbnd = 0 if dt is bool else np .iinfo (dt ).min
@@ -242,29 +260,13 @@ def test_scalar_array_equiv(self):
242260 scalar = self .rfunc (lbnd , ubnd , size = size , dtype = dt )
243261
244262 mt19937 .seed (1234 )
245- scalar_array = self .rfunc (lbnd , ubnd , size = size , dtype = dt )
263+ scalar_array = self .rfunc ([ lbnd ], [ ubnd ] , size = size , dtype = dt )
246264
247265 mt19937 .seed (1234 )
248266 array = self .rfunc ([lbnd ] * size , [ubnd ] * size , size = size , dtype = dt )
249267 assert_array_equal (scalar , scalar_array )
250268 assert_array_equal (scalar , array )
251269
252-
253- def test_in_bounds_fuzz (self ):
254- # Don't use fixed seed
255- mt19937 .seed ()
256-
257- for dt in self .itype [1 :]:
258- for ubnd in [4 , 8 , 16 ]:
259- vals = self .rfunc (2 , ubnd , size = 2 ** 16 , dtype = dt )
260- assert_ (vals .max () < ubnd )
261- assert_ (vals .min () >= 2 )
262-
263- vals = self .rfunc (0 , 2 , size = 2 ** 16 , dtype = bool )
264-
265- assert_ (vals .max () < 2 )
266- assert_ (vals .min () >= 0 )
267-
268270 def test_repeatability (self ):
269271 import hashlib
270272 # We use a md5 hash of generated sequences of 1000 samples
@@ -301,7 +303,6 @@ def test_repeatability(self):
301303 def test_repeatability_broadcasting (self ):
302304
303305 for dt in self .itype :
304-
305306 lbnd = 0 if dt in (np .bool , bool , np .bool_ ) else np .iinfo (dt ).min
306307 ubnd = 2 if dt in (np .bool , bool , np .bool_ ) else np .iinfo (dt ).max + 1
307308
@@ -361,7 +362,6 @@ def test_respect_dtype_singleton(self):
361362 assert not hasattr (sample , 'dtype' )
362363 assert_equal (type (sample ), dt )
363364
364-
365365 def test_respect_dtype_array (self ):
366366 # See gh-7203
367367 for dt in self .itype :
@@ -374,21 +374,21 @@ def test_respect_dtype_array(self):
374374 sample = self .rfunc ([lbnd ] * 2 , [ubnd ] * 2 , dtype = dt )
375375 assert_equal (sample .dtype , dt )
376376
377- def test_empty (self ):
377+ def test_zero_size (self ):
378+ # See gh-7203
378379 for dt in self .itype :
379380 sample = self .rfunc (0 , 0 , (3 , 0 , 4 ), dtype = dt )
380- assert_equal ( sample .shape , (3 , 0 , 4 ) )
381- assert_equal ( self . rfunc ( 0 , - 10 , size = 0 , dtype = dt ). shape , ( 0 ,))
382- assert_equal ( sample . dtype , dt )
381+ assert sample .shape == (3 , 0 , 4 )
382+ assert sample . dtype == dt
383+ assert self . rfunc ( 0 , - 10 , 0 , dtype = dt ). shape == ( 0 , )
383384
384385
385386class TestRandomDist (object ):
386387 # Make sure the random distribution returns the correct value for a
387388 # given seed
388389
389- @classmethod
390- def setup_class (cls ):
391- cls .seed = 1234567890
390+ def setup (self ):
391+ self .seed = 1234567890
392392
393393 def test_rand (self ):
394394 mt19937 .seed (self .seed )
@@ -638,6 +638,11 @@ def test_dirichlet_size(self):
638638
639639 assert_raises (TypeError , mt19937 .dirichlet , p , float (1 ))
640640
641+ def test_dirichlet_bad_alpha (self ):
642+ # gh-2089
643+ alpha = np .array ([5.4e-01 , - 1.0e-16 ])
644+ assert_raises (ValueError , mt19937 .dirichlet , alpha )
645+
641646 def test_exponential (self ):
642647 mt19937 .seed (self .seed )
643648 actual = mt19937 .exponential (1.1234 , size = (3 , 2 ))
@@ -1046,9 +1051,8 @@ def test_zipf(self):
10461051class TestBroadcast (object ):
10471052 # tests that functions that broadcast behave
10481053 # correctly when presented with non-scalar arguments
1049- @classmethod
1050- def setup_class (cls ):
1051- cls .seed = 123456789
1054+ def setup (self ):
1055+ self .seed = 123456789
10521056
10531057 def set_seed (self ):
10541058 random .seed (self .seed )
@@ -1603,9 +1607,8 @@ def test_logseries(self):
16031607class TestThread (object ):
16041608 # make sure each state produces the same sequence even in threads
16051609
1606- @classmethod
1607- def setup_class (cls ):
1608- cls .seeds = range (4 )
1610+ def setup (self ):
1611+ self .seeds = range (4 )
16091612
16101613 def check_function (self , function , sz ):
16111614 from threading import Thread
@@ -1650,12 +1653,11 @@ def gen_random(state, out):
16501653
16511654# See Issue #4263
16521655class TestSingleEltArrayInput (object ):
1653- @classmethod
1654- def setup_class (cls ):
1655- cls .argOne = np .array ([2 ])
1656- cls .argTwo = np .array ([3 ])
1657- cls .argThree = np .array ([4 ])
1658- cls .tgtShape = (1 ,)
1656+ def setup (self ):
1657+ self .argOne = np .array ([2 ])
1658+ self .argTwo = np .array ([3 ])
1659+ self .argThree = np .array ([4 ])
1660+ self .tgtShape = (1 ,)
16591661
16601662 def test_one_arg_funcs (self ):
16611663 funcs = (mt19937 .exponential , mt19937 .standard_gamma ,
0 commit comments