Skip to content

Commit 735986c

Browse files
authored
Merge pull request #83 from bashtage/fix-test-input-shape
TST: Fix test input shape
2 parents dcce8f7 + aea9b61 commit 735986c

File tree

5 files changed

+217
-52
lines changed

5 files changed

+217
-52
lines changed

randomstate/interface/pcg-64/pcg-64-emulated.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ ctypedef pcg64_random_t rng_t
4141
cdef object pcg128_to_pylong(pcg128_t x):
4242
return PyLong_FromUnsignedLongLong(x.high) * 2**64 + PyLong_FromUnsignedLongLong(x.low)
4343

44-
cdef pcg128_t pcg128_from_pylong(object x):
44+
cdef pcg128_t pcg128_from_pylong(object x) except *:
4545
cdef pcg128_t out
4646
out.high = PyLong_AsUnsignedLongLong(x // (2 ** 64))
4747
out.low = PyLong_AsUnsignedLongLong(x % (2 ** 64))

randomstate/randomstate.pyx

Lines changed: 77 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,13 @@ cdef class RandomState:
275275
----------
276276
seed : int, optional
277277
Seed for ``RandomState``.
278+
279+
Raises
280+
------
281+
ValueError
282+
If seed values are out of range for the PRNG.
283+
TypeError
284+
If seed values are not integers.
278285
279286
Notes
280287
-----
@@ -287,38 +294,49 @@ cdef class RandomState:
287294
--------
288295
RandomState
289296
"""
290-
try:
291-
if seed is not None:
292-
if hasattr(seed, 'squeeze'):
293-
seed = seed.squeeze()
294-
idx = operator.index(seed)
295-
if idx < 0:
296-
raise ValueError('seed < 0')
297-
else:
298-
self.__seed = seed = _generate_seed(RS_SEED_NBYTES)
297+
if seed is None:
298+
self.__seed = seed = _generate_seed(RS_SEED_NBYTES)
299299
set_seed(&self.rng_state, seed)
300-
except TypeError:
300+
self._reset_state_variables()
301+
return
302+
303+
if hasattr(seed, 'squeeze'):
304+
seed = seed.squeeze()
305+
IF RS_SEED_ARRAY_BITS == 32:
306+
seed = np.asarray(seed).astype(np.object, casting='safe')
307+
if np.any((seed // 1) != seed):
308+
raise TypeError("Seed values must be integers between "
309+
"0 and 4294967295 (2**32-1)")
310+
if np.any((seed < int(0)) | (seed > int(2**32-1))):
311+
raise ValueError("Seed values must be integers between "
312+
"0 and 4294967295 (2**32-1)")
313+
seed = np.asarray(seed).astype(np.uint32, casting='unsafe')
314+
ELSE:
315+
seed = np.asarray(seed).astype(np.object, casting='safe')
316+
if np.any((seed // 1) != seed):
317+
raise TypeError("Seed values must be integers between 0 and "
318+
"18446744073709551616 (2**64-1)")
319+
if np.any((seed < int(0)) | (seed > int(2**64-1))):
320+
raise ValueError("Seed values must be integers between 0 and "
321+
"18446744073709551616 (2**64-1)")
322+
seed = np.asarray(seed).astype(np.uint64, casting='unsafe')
323+
324+
if seed.ndim == 0:
301325
IF RS_SEED_ARRAY_BITS == 32:
302-
seed = np.asarray(seed).astype(np.int64, casting='safe')
303-
if ((seed > int(2**32 - 1)) | (seed < 0)).any():
304-
raise ValueError("Seed values must be between 0 and "
305-
"4294967295 (2**32-1)")
306-
seed = seed.astype(np.uint32, casting='unsafe')
307-
with self.lock:
308-
set_seed_by_array(&self.rng_state,
309-
<uint32_t *>np.PyArray_DATA(seed),
310-
np.PyArray_DIM(seed, 0))
326+
seed = <uint32_t> seed
311327
ELSE:
312-
seed = np.asarray(seed).astype(np.object, casting='safe')
313-
if ((seed > int(2**64 - 1)) | (seed < 0)).any():
314-
raise ValueError("Seed values must be between 0 and "
315-
"18446744073709551616 (2**64-1)")
316-
seed = seed.astype(np.uint64, casting='unsafe')
317-
with self.lock:
318-
set_seed_by_array(&self.rng_state,
319-
<uint64_t *>np.PyArray_DATA(seed),
320-
np.PyArray_DIM(seed, 0))
321-
self.__seed = seed
328+
seed = <uint64_t> seed
329+
set_seed(&self.rng_state, seed)
330+
else:
331+
IF RS_SEED_ARRAY_BITS == 32:
332+
set_seed_by_array(&self.rng_state,
333+
<uint32_t *>np.PyArray_DATA(seed),
334+
np.PyArray_DIM(seed, 0))
335+
ELSE:
336+
set_seed_by_array(&self.rng_state,
337+
<uint64_t *>np.PyArray_DATA(seed),
338+
np.PyArray_DIM(seed, 0))
339+
self.__seed = seed
322340
self._reset_state_variables()
323341

324342
ELSE:
@@ -338,18 +356,41 @@ cdef class RandomState:
338356
stream : int, optional
339357
Generator stream to use
340358
359+
Raises
360+
------
361+
ValueError
362+
If seed values are out of range for the PRNG.
363+
TypeError
364+
If seed values are not scalar integers.
365+
341366
See Also
342367
--------
343368
RandomState
344369
"""
345-
if seed is None:
370+
ub = 2 ** (32 * RS_SEED_NBYTES)
371+
if seed is not None:
372+
error_msg = 'seed must be a scalar integer 0<=seed<{0}'.format(ub)
373+
_seed = np.asarray(seed, dtype=np.object)
374+
if _seed.ndim > 0:
375+
raise TypeError(error_msg)
376+
elif seed // 1 != seed:
377+
raise TypeError(error_msg)
378+
elif seed < 0 or seed >= ub:
379+
raise ValueError(error_msg)
380+
else:
346381
self.__seed = seed = _generate_seed(RS_SEED_NBYTES)
347-
elif seed < 0:
348-
raise ValueError('seed < 0')
349-
if stream is None:
382+
383+
if stream is not None:
384+
error_msg = 'stream must be a scalar integer 0<=stream<{0}'.format(ub)
385+
_stream= np.asarray(stream, dtype=np.object)
386+
if _stream.ndim > 0:
387+
raise TypeError(error_msg)
388+
elif stream // 1 != stream:
389+
raise TypeError(error_msg)
390+
elif stream < 0 or stream >= ub:
391+
raise ValueError(error_msg)
392+
else:
350393
self.__stream = stream = 1
351-
elif stream < 0:
352-
raise ValueError('stream < 0')
353394

354395
IF RS_RNG_MOD_NAME == 'pcg64':
355396
IF RS_PCG128_EMULATED:
@@ -359,7 +400,7 @@ cdef class RandomState:
359400
ELSE:
360401
set_seed(&self.rng_state, seed, stream)
361402
ELSE:
362-
set_seed(&self.rng_state, seed, stream)
403+
set_seed(&self.rng_state, <uint64_t>seed, <uint64_t>stream)
363404
self._reset_state_variables()
364405

365406
def _reset_state_variables(self):

randomstate/tests/test_against_numpy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,19 @@ def compare_2_input(f1, f2, is_np=False, is_scalar=False):
3838
((np.array([a] * 10), b), {}),
3939
((a, np.array([b] * 10)), {}),
4040
((a, np.array([b] * 10)), {'size': 10}),
41-
((np.array([[[a]] * 100]), np.array([b] * 10)), {'size': (100, 10)}),
41+
((np.reshape(np.array([[a] * 100]), (100,1)), np.array([b] * 10)), {'size': (100, 10)}),
4242
((np.ones((7, 31), dtype=dtype) * a, np.array([b] * 31)), {'size': (7, 31)}),
4343
((np.ones((7, 31), dtype=dtype) * a, np.array([b] * 31)), {'size': (10, 7, 31)})]
4444

4545
if is_scalar:
4646
inputs = inputs[:3]
4747

4848
for i in inputs:
49+
print(i[0], i[1])
4950
v1 = f1(*i[0], **i[1])
5051
v2 = f2(*i[0], **i[1])
5152
assert_allclose(v1, v2)
53+
print('OK!'*20)
5254

5355

5456
def compare_3_input(f1, f2, is_np=False):

randomstate/tests/test_direct.py

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from randomstate.prng.xorshift128 import xorshift128
1515
from randomstate.prng.xoroshiro128plus import xoroshiro128plus
1616
from randomstate.prng.dsfmt import dsfmt
17-
from numpy.testing import assert_equal, assert_allclose, assert_array_equal
17+
from numpy.testing import assert_equal, assert_allclose, assert_array_equal, \
18+
assert_raises
1819

1920
if (sys.version_info > (3, 0)):
2021
long = int
@@ -126,10 +127,11 @@ class Base(object):
126127
data2 = data1 = {}
127128

128129
@classmethod
129-
def setUpClass(cls):
130+
def setup_class(cls):
130131
cls.RandomState = xorshift128.RandomState
131132
cls.bits = 64
132133
cls.dtype = np.uint64
134+
cls.seed_error_type = TypeError
133135

134136
@classmethod
135137
def _read_csv(cls, filename):
@@ -189,86 +191,163 @@ def test_uniform_float(self):
189191
assert_allclose(uniforms, vals)
190192
assert_equal(uniforms.dtype, np.float32)
191193

194+
def test_seed_float(self):
195+
# GH #82
196+
rs = self.RandomState(*self.data1['seed'])
197+
assert_raises(self.seed_error_type, rs.seed, np.pi)
198+
assert_raises(self.seed_error_type, rs.seed, -np.pi)
199+
200+
def test_seed_float_array(self):
201+
# GH #82
202+
rs = self.RandomState(*self.data1['seed'])
203+
assert_raises(self.seed_error_type, rs.seed, np.array([np.pi]))
204+
assert_raises(self.seed_error_type, rs.seed, np.array([-np.pi]))
205+
assert_raises(self.seed_error_type, rs.seed, np.array([np.pi, -np.pi]))
206+
assert_raises(self.seed_error_type, rs.seed, np.array([0, np.pi]))
207+
assert_raises(self.seed_error_type, rs.seed, [np.pi])
208+
assert_raises(self.seed_error_type, rs.seed, [0, np.pi])
209+
210+
def test_seed_out_of_range(self):
211+
# GH #82
212+
rs = self.RandomState(*self.data1['seed'])
213+
assert_raises(ValueError, rs.seed, 2 ** (2 * self.bits+1))
214+
assert_raises(ValueError, rs.seed, -1)
215+
216+
def test_seed_out_of_range_array(self):
217+
# GH #82
218+
rs = self.RandomState(*self.data1['seed'])
219+
assert_raises(ValueError, rs.seed, [2 ** (2 * self.bits+1)])
220+
assert_raises(ValueError, rs.seed, [-1])
192221

193222
class TestXorshift128(Base, TestCase):
194223
@classmethod
195-
def setUpClass(cls):
224+
def setup_class(cls):
196225
cls.RandomState = xorshift128.RandomState
197226
cls.bits = 64
198227
cls.dtype = np.uint64
199228
cls.data1 = cls._read_csv(join(pwd, './data/xorshift128-testset-1.csv'))
200229
cls.data2 = cls._read_csv(join(pwd, './data/xorshift128-testset-2.csv'))
201230
cls.uniform32_func = uniform32_from_uint64
231+
cls.seed_error_type = TypeError
202232

203233

204234
class TestXoroshiro128plus(Base, TestCase):
205235
@classmethod
206-
def setUpClass(cls):
236+
def setup_class(cls):
207237
cls.RandomState = xoroshiro128plus.RandomState
208238
cls.bits = 64
209239
cls.dtype = np.uint64
210240
cls.data1 = cls._read_csv(join(pwd, './data/xoroshiro128plus-testset-1.csv'))
211241
cls.data2 = cls._read_csv(join(pwd, './data/xoroshiro128plus-testset-2.csv'))
212-
242+
cls.seed_error_type = TypeError
213243

214244
class TestXorshift1024(Base, TestCase):
215245
@classmethod
216-
def setUpClass(cls):
246+
def setup_class(cls):
217247
cls.RandomState = xorshift1024.RandomState
218248
cls.bits = 64
219249
cls.dtype = np.uint64
220250
cls.data1 = cls._read_csv(join(pwd, './data/xorshift1024-testset-1.csv'))
221251
cls.data2 = cls._read_csv(join(pwd, './data/xorshift1024-testset-2.csv'))
222-
252+
cls.seed_error_type = TypeError
223253

224254
class TestMT19937(Base, TestCase):
225255
@classmethod
226-
def setUpClass(cls):
256+
def setup_class(cls):
227257
cls.RandomState = mt19937.RandomState
228258
cls.bits = 32
229259
cls.dtype = np.uint32
230260
cls.data1 = cls._read_csv(join(pwd, './data/randomkit-testset-1.csv'))
231261
cls.data2 = cls._read_csv(join(pwd, './data/randomkit-testset-2.csv'))
262+
cls.seed_error_type = ValueError
263+
264+
def test_seed_out_of_range(self):
265+
# GH #82
266+
rs = self.RandomState(*self.data1['seed'])
267+
assert_raises(ValueError, rs.seed, 2 ** (self.bits + 1))
268+
assert_raises(ValueError, rs.seed, -1)
269+
assert_raises(ValueError, rs.seed, 2 ** (2 * self.bits+1))
270+
271+
def test_seed_out_of_range_array(self):
272+
# GH #82
273+
rs = self.RandomState(*self.data1['seed'])
274+
assert_raises(ValueError, rs.seed, [2 ** (self.bits + 1)])
275+
assert_raises(ValueError, rs.seed, [-1])
276+
assert_raises(TypeError, rs.seed, [2 ** (2 * self.bits+1)])
277+
278+
def test_seed_float(self):
279+
# GH #82
280+
rs = self.RandomState(*self.data1['seed'])
281+
assert_raises(TypeError, rs.seed, np.pi)
282+
assert_raises(TypeError, rs.seed, -np.pi)
283+
284+
def test_seed_float_array(self):
285+
# GH #82
286+
rs = self.RandomState(*self.data1['seed'])
287+
assert_raises(TypeError, rs.seed, np.array([np.pi]))
288+
assert_raises(TypeError, rs.seed, np.array([-np.pi]))
289+
assert_raises(TypeError, rs.seed, np.array([np.pi, -np.pi]))
290+
assert_raises(TypeError, rs.seed, np.array([0, np.pi]))
291+
assert_raises(TypeError, rs.seed, [np.pi])
292+
assert_raises(TypeError, rs.seed, [0, np.pi])
232293

233294

234295
class TestPCG32(Base, TestCase):
235296
@classmethod
236-
def setUpClass(cls):
297+
def setup_class(cls):
237298
cls.RandomState = pcg32.RandomState
238299
cls.bits = 32
239300
cls.dtype = np.uint32
240301
cls.data1 = cls._read_csv(join(pwd, './data/pcg32-testset-1.csv'))
241302
cls.data2 = cls._read_csv(join(pwd, './data/pcg32-testset-2.csv'))
303+
cls.seed_error_type = TypeError
304+
305+
def test_seed_out_of_range_array(self):
306+
# GH #82
307+
rs = self.RandomState(*self.data1['seed'])
308+
assert_raises(TypeError, rs.seed, [2 ** (self.bits + 1)])
309+
assert_raises(TypeError, rs.seed, [-1])
310+
assert_raises(TypeError, rs.seed, [2 ** (2 * self.bits+1)])
242311

243312

244313
class TestPCG64(Base, TestCase):
245314
@classmethod
246-
def setUpClass(cls):
315+
def setup_class(cls):
247316
cls.RandomState = pcg64.RandomState
248317
cls.bits = 64
249318
cls.dtype = np.uint64
250319
cls.data1 = cls._read_csv(join(pwd, './data/pcg64-testset-1.csv'))
251320
cls.data2 = cls._read_csv(join(pwd, './data/pcg64-testset-2.csv'))
321+
cls.seed_error_type = TypeError
322+
323+
def test_seed_out_of_range_array(self):
324+
# GH #82
325+
rs = self.RandomState(*self.data1['seed'])
326+
assert_raises(TypeError, rs.seed, [2 ** (self.bits + 1)])
327+
assert_raises(TypeError, rs.seed, [-1])
328+
assert_raises(TypeError, rs.seed, [2 ** (2 * self.bits+1)])
252329

253330

254331
class TestMRG32K3A(Base, TestCase):
255332
@classmethod
256-
def setUpClass(cls):
333+
def setup_class(cls):
257334
cls.RandomState = mrg32k3a.RandomState
258335
cls.bits = 32
259336
cls.dtype = np.uint32
260337
cls.data1 = cls._read_csv(join(pwd, './data/mrg32k3a-testset-1.csv'))
261338
cls.data2 = cls._read_csv(join(pwd, './data/mrg32k3a-testset-2.csv'))
339+
cls.seed_error_type = TypeError
262340

263341

264342
class TestMLFG(Base, TestCase):
265343
@classmethod
266-
def setUpClass(cls):
344+
def setup_class(cls):
267345
cls.RandomState = mlfg_1279_861.RandomState
268346
cls.bits = 63
269347
cls.dtype = np.uint64
270348
cls.data1 = cls._read_csv(join(pwd, './data/mlfg-testset-1.csv'))
271349
cls.data2 = cls._read_csv(join(pwd, './data/mlfg-testset-2.csv'))
350+
cls.seed_error_type = TypeError
272351

273352
def test_raw(self):
274353
rs = self.RandomState(*self.data1['seed'])
@@ -284,12 +363,13 @@ def test_raw(self):
284363

285364
class TestDSFMT(Base, TestCase):
286365
@classmethod
287-
def setUpClass(cls):
366+
def setup_class(cls):
288367
cls.RandomState = dsfmt.RandomState
289368
cls.bits = 53
290369
cls.dtype = np.uint64
291370
cls.data1 = cls._read_csv(join(pwd, './data/dSFMT-testset-1.csv'))
292371
cls.data2 = cls._read_csv(join(pwd, './data/dSFMT-testset-2.csv'))
372+
cls.seed_error_type = TypeError
293373

294374
def test_uniform_double(self):
295375
rs = self.RandomState(*self.data1['seed'])

0 commit comments

Comments
 (0)