4
4
import warnings
5
5
6
6
import numpy as np
7
+ import randomstate as random
7
8
from numpy .testing import (
8
9
run_module_suite , assert_ , assert_raises , assert_equal ,
9
10
assert_warns , assert_no_warnings , assert_array_equal ,
10
11
assert_array_almost_equal )
11
-
12
- import randomstate as random
13
12
from randomstate .compat import suppress_warnings
14
13
from randomstate .prng .mt19937 import mt19937
15
14
@@ -89,13 +88,11 @@ def test_size(self):
89
88
90
89
91
90
class TestSetState (object ):
92
-
93
- @classmethod
94
- def setup_class (cls ):
95
- cls .seed = 1234567890
96
- cls .prng = random .RandomState (cls .seed )
97
- cls .state = cls .prng .get_state ()
98
- cls .legacy_state = cls .prng .get_state (legacy = True ) # Use legacy to get old NumPy state
91
+ def setup (self ):
92
+ self .seed = 1234567890
93
+ self .prng = random .RandomState (self .seed )
94
+ self .state = self .prng .get_state ()
95
+ self .legacy_state = self .prng .get_state (legacy = True ) # Use legacy to get old NumPy state
99
96
100
97
def test_basic (self ):
101
98
old = self .prng .tomaxint (16 )
@@ -105,7 +102,6 @@ def test_basic(self):
105
102
106
103
def test_gaussian_reset (self ):
107
104
# Make sure the cached every-other-Gaussian is reset.
108
- self .prng .set_state (self .state )
109
105
old = self .prng .standard_normal (size = 3 )
110
106
self .prng .set_state (self .state )
111
107
new = self .prng .standard_normal (size = 3 )
@@ -126,7 +122,6 @@ def test_backwards_compatibility(self):
126
122
# Make sure we can accept old state tuples that do not have the
127
123
# cached Gaussian value.
128
124
old_state = self .legacy_state [:- 2 ]
129
- self .prng .set_state (self .legacy_state )
130
125
x1 = self .prng .standard_normal (size = 16 )
131
126
self .prng .set_state (old_state )
132
127
x2 = self .prng .standard_normal (size = 16 )
@@ -160,6 +155,11 @@ def test_bounds_checking(self):
160
155
assert_raises (ValueError , self .rfunc , ubnd , lbnd , dtype = dt )
161
156
assert_raises (ValueError , self .rfunc , 1 , 0 , dtype = dt )
162
157
158
+ assert_raises (ValueError , self .rfunc , [lbnd - 1 ], ubnd , dtype = dt )
159
+ assert_raises (ValueError , self .rfunc , [lbnd ], [ubnd + 1 ], dtype = dt )
160
+ assert_raises (ValueError , self .rfunc , [ubnd ], [lbnd ], dtype = dt )
161
+ assert_raises (ValueError , self .rfunc , 1 , [0 ], dtype = dt )
162
+
163
163
def test_bounds_checking_array (self ):
164
164
for dt in self .itype :
165
165
lbnd = 0 if dt is bool else np .iinfo (dt ).min
@@ -176,12 +176,15 @@ def test_rng_zero_and_extremes(self):
176
176
177
177
tgt = ubnd - 1
178
178
assert_equal (self .rfunc (tgt , tgt + 1 , size = 1000 , dtype = dt ), tgt )
179
+ assert_equal (self .rfunc ([tgt ], tgt + 1 , size = 1000 , dtype = dt ), tgt )
179
180
180
181
tgt = lbnd
181
182
assert_equal (self .rfunc (tgt , tgt + 1 , size = 1000 , dtype = dt ), tgt )
183
+ assert_equal (self .rfunc (tgt , [tgt + 1 ], size = 1000 , dtype = dt ), tgt )
182
184
183
185
tgt = (lbnd + ubnd ) // 2
184
186
assert_equal (self .rfunc (tgt , tgt + 1 , size = 1000 , dtype = dt ), tgt )
187
+ assert_equal (self .rfunc ([tgt ], [tgt + 1 ], size = 1000 , dtype = dt ), tgt )
185
188
186
189
def test_rng_zero_and_extremes_array (self ):
187
190
size = 1000
@@ -191,8 +194,8 @@ def test_rng_zero_and_extremes_array(self):
191
194
192
195
tgt = ubnd - 1
193
196
assert_equal (self .rfunc ([tgt ], [tgt + 1 ], size = size , dtype = dt ), tgt )
194
- assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , dtype = dt ), tgt )
195
- assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , size = size , dtype = dt ), tgt )
197
+ assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , dtype = dt ), tgt )
198
+ assert_equal (self .rfunc ([tgt ] * size , [tgt + 1 ] * size , size = size , dtype = dt ), tgt )
196
199
197
200
tgt = lbnd
198
201
assert_equal (self .rfunc ([tgt ], [tgt + 1 ], size = size , dtype = dt ), tgt )
@@ -226,12 +229,27 @@ def test_full_range_array(self):
226
229
ubnd = 2 if dt is bool else np .iinfo (dt ).max + 1
227
230
228
231
try :
229
- self .rfunc ([lbnd ], [ubnd ], dtype = dt )
232
+ self .rfunc ([lbnd ] * 2 , [ubnd ], dtype = dt )
230
233
except Exception as e :
231
234
raise AssertionError ("No error should have been raised, "
232
235
"but one was with the following "
233
236
"message:\n \n %s" % str (e ))
234
237
238
+ def test_in_bounds_fuzz (self ):
239
+ # Don't use fixed seed
240
+ mt19937 .seed ()
241
+
242
+ for dt in self .itype [1 :]:
243
+ for ubnd in [4 , 8 , 16 ]:
244
+ vals = self .rfunc (2 , ubnd , size = 2 ** 16 , dtype = dt )
245
+ assert_ (vals .max () < ubnd )
246
+ assert_ (vals .min () >= 2 )
247
+
248
+ vals = self .rfunc (0 , 2 , size = 2 ** 16 , dtype = bool )
249
+
250
+ assert_ (vals .max () < 2 )
251
+ assert_ (vals .min () >= 0 )
252
+
235
253
def test_scalar_array_equiv (self ):
236
254
for dt in self .itype :
237
255
lbnd = 0 if dt is bool else np .iinfo (dt ).min
@@ -242,29 +260,13 @@ def test_scalar_array_equiv(self):
242
260
scalar = self .rfunc (lbnd , ubnd , size = size , dtype = dt )
243
261
244
262
mt19937 .seed (1234 )
245
- scalar_array = self .rfunc (lbnd , ubnd , size = size , dtype = dt )
263
+ scalar_array = self .rfunc ([ lbnd ], [ ubnd ] , size = size , dtype = dt )
246
264
247
265
mt19937 .seed (1234 )
248
266
array = self .rfunc ([lbnd ] * size , [ubnd ] * size , size = size , dtype = dt )
249
267
assert_array_equal (scalar , scalar_array )
250
268
assert_array_equal (scalar , array )
251
269
252
-
253
- def test_in_bounds_fuzz (self ):
254
- # Don't use fixed seed
255
- mt19937 .seed ()
256
-
257
- for dt in self .itype [1 :]:
258
- for ubnd in [4 , 8 , 16 ]:
259
- vals = self .rfunc (2 , ubnd , size = 2 ** 16 , dtype = dt )
260
- assert_ (vals .max () < ubnd )
261
- assert_ (vals .min () >= 2 )
262
-
263
- vals = self .rfunc (0 , 2 , size = 2 ** 16 , dtype = bool )
264
-
265
- assert_ (vals .max () < 2 )
266
- assert_ (vals .min () >= 0 )
267
-
268
270
def test_repeatability (self ):
269
271
import hashlib
270
272
# We use a md5 hash of generated sequences of 1000 samples
@@ -301,7 +303,6 @@ def test_repeatability(self):
301
303
def test_repeatability_broadcasting (self ):
302
304
303
305
for dt in self .itype :
304
-
305
306
lbnd = 0 if dt in (np .bool , bool , np .bool_ ) else np .iinfo (dt ).min
306
307
ubnd = 2 if dt in (np .bool , bool , np .bool_ ) else np .iinfo (dt ).max + 1
307
308
@@ -361,7 +362,6 @@ def test_respect_dtype_singleton(self):
361
362
assert not hasattr (sample , 'dtype' )
362
363
assert_equal (type (sample ), dt )
363
364
364
-
365
365
def test_respect_dtype_array (self ):
366
366
# See gh-7203
367
367
for dt in self .itype :
@@ -374,21 +374,21 @@ def test_respect_dtype_array(self):
374
374
sample = self .rfunc ([lbnd ] * 2 , [ubnd ] * 2 , dtype = dt )
375
375
assert_equal (sample .dtype , dt )
376
376
377
- def test_empty (self ):
377
+ def test_zero_size (self ):
378
+ # See gh-7203
378
379
for dt in self .itype :
379
380
sample = self .rfunc (0 , 0 , (3 , 0 , 4 ), dtype = dt )
380
- assert_equal ( sample .shape , (3 , 0 , 4 ) )
381
- assert_equal ( self . rfunc ( 0 , - 10 , size = 0 , dtype = dt ). shape , ( 0 ,))
382
- assert_equal ( sample . dtype , dt )
381
+ assert sample .shape == (3 , 0 , 4 )
382
+ assert sample . dtype == dt
383
+ assert self . rfunc ( 0 , - 10 , 0 , dtype = dt ). shape == ( 0 , )
383
384
384
385
385
386
class TestRandomDist (object ):
386
387
# Make sure the random distribution returns the correct value for a
387
388
# given seed
388
389
389
- @classmethod
390
- def setup_class (cls ):
391
- cls .seed = 1234567890
390
+ def setup (self ):
391
+ self .seed = 1234567890
392
392
393
393
def test_rand (self ):
394
394
mt19937 .seed (self .seed )
@@ -638,6 +638,11 @@ def test_dirichlet_size(self):
638
638
639
639
assert_raises (TypeError , mt19937 .dirichlet , p , float (1 ))
640
640
641
+ def test_dirichlet_bad_alpha (self ):
642
+ # gh-2089
643
+ alpha = np .array ([5.4e-01 , - 1.0e-16 ])
644
+ assert_raises (ValueError , mt19937 .dirichlet , alpha )
645
+
641
646
def test_exponential (self ):
642
647
mt19937 .seed (self .seed )
643
648
actual = mt19937 .exponential (1.1234 , size = (3 , 2 ))
@@ -1046,9 +1051,8 @@ def test_zipf(self):
1046
1051
class TestBroadcast (object ):
1047
1052
# tests that functions that broadcast behave
1048
1053
# correctly when presented with non-scalar arguments
1049
- @classmethod
1050
- def setup_class (cls ):
1051
- cls .seed = 123456789
1054
+ def setup (self ):
1055
+ self .seed = 123456789
1052
1056
1053
1057
def set_seed (self ):
1054
1058
random .seed (self .seed )
@@ -1603,9 +1607,8 @@ def test_logseries(self):
1603
1607
class TestThread (object ):
1604
1608
# make sure each state produces the same sequence even in threads
1605
1609
1606
- @classmethod
1607
- def setup_class (cls ):
1608
- cls .seeds = range (4 )
1610
+ def setup (self ):
1611
+ self .seeds = range (4 )
1609
1612
1610
1613
def check_function (self , function , sz ):
1611
1614
from threading import Thread
@@ -1650,12 +1653,11 @@ def gen_random(state, out):
1650
1653
1651
1654
# See Issue #4263
1652
1655
class TestSingleEltArrayInput (object ):
1653
- @classmethod
1654
- def setup_class (cls ):
1655
- cls .argOne = np .array ([2 ])
1656
- cls .argTwo = np .array ([3 ])
1657
- cls .argThree = np .array ([4 ])
1658
- cls .tgtShape = (1 ,)
1656
+ def setup (self ):
1657
+ self .argOne = np .array ([2 ])
1658
+ self .argTwo = np .array ([3 ])
1659
+ self .argThree = np .array ([4 ])
1660
+ self .tgtShape = (1 ,)
1659
1661
1660
1662
def test_one_arg_funcs (self ):
1661
1663
funcs = (mt19937 .exponential , mt19937 .standard_gamma ,
0 commit comments