Skip to content

Commit a55fa1c

Browse files
committed
TST: Add tests for 32-bit generators and raw data
Add smoke tests for 32 bit uniforms and normals
1 parent a765a98 commit a55fa1c

File tree

5 files changed

+106
-31
lines changed

5 files changed

+106
-31
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ install:
5353
script:
5454
- nosetests randomstate
5555
- cd $BUILD_DIR/randomstate
56-
- if [ ${PYTHON} = "2.7" ]; then python performance.py; fi
56+
- if [ ${PYTHON} = "3.5" ]; then python performance.py; fi

README.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ New Functions
8888

8989
- ``random_entropy`` - Read from the system entropy provider, which is
9090
commonly used in cryptographic applications
91-
- ``random_uintegers`` - unsigned integers ``[0, 2**64-1]``
9291
- ``random_raw`` - Direct access to the values produced by the
9392
underlying PRNG. The range of the values returned depends on the
9493
specifics of the PRNG implementation.

randomstate/performance.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import timeit
44

55
import pandas as pd
6+
import numpy as np
67
from numpy.random import RandomState
78

89
rs = RandomState()
910

1011
SETUP = '''
12+
import numpy as np
1113
import {mod}.{rng}
1214
rs = {mod}.{rng}.RandomState()
1315
rs.random_sample()
@@ -70,6 +72,26 @@ def timer_uniform():
7072
run_timer(dist, command, None, SETUP, 'Uniforms')
7173

7274

75+
def timer_32bit():
76+
info = np.iinfo(np.uint32)
77+
min, max = info.min, info.max
78+
dist = 'randint'
79+
command = 'rs.{dist}({min}, {max}+1, 1000000, dtype=np.uint64)'
80+
command = command.format(dist='{dist}', min=min, max=max)
81+
command_numpy = command
82+
run_timer(dist, command, None, SETUP, '32-bit unsigned integers')
83+
84+
85+
def timer_64bit():
86+
info = np.iinfo(np.uint64)
87+
min, max = info.min, info.max
88+
dist = 'randint'
89+
command = 'rs.{dist}({min}, {max}+1, 1000000, dtype=np.uint64)'
90+
command = command.format(dist='{dist}', min=min, max=max)
91+
command_numpy = command
92+
run_timer(dist, command, None, SETUP, '64-bit unsigned integers')
93+
94+
7395
def timer_normal():
7496
command = 'rs.{dist}(1000000, method="bm")'
7597
command_numpy = 'rs.{dist}(1000000)'

randomstate/tests/test_direct.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from randomstate.prng.xorshift128 import xorshift128
1515
from randomstate.prng.xoroshiro128plus import xoroshiro128plus
1616
from randomstate.prng.dsfmt import dsfmt
17-
from numpy.testing import assert_equal, assert_allclose
17+
from numpy.testing import assert_equal, assert_allclose, assert_array_equal
1818

1919
if (sys.version_info > (3, 0)):
2020
long = int
@@ -142,18 +142,14 @@ def _read_csv(cls, filename):
142142
data.append(long(line.split(',')[-1]))
143143
return {'seed': seed, 'data': np.array(data, dtype=cls.dtype)}
144144

145-
def test_double(self):
145+
def test_raw(self):
146146
rs = self.RandomState(*self.data1['seed'])
147-
vals = uniform_from_uint(self.data1['data'], self.bits)
148-
uniforms = rs.random_sample(len(vals))
149-
assert_allclose(uniforms, vals)
150-
assert_equal(uniforms.dtype, np.float64)
147+
uints = rs.random_raw(1000)
148+
assert_equal(uints, self.data1['data'])
151149

152150
rs = self.RandomState(*self.data2['seed'])
153-
vals = uniform_from_uint(self.data2['data'], self.bits)
154-
uniforms = rs.random_sample(len(vals))
155-
assert_allclose(uniforms, vals)
156-
assert_equal(uniforms.dtype, np.float64)
151+
uints = rs.random_raw(1000)
152+
assert_equal(uints, self.data2['data'])
157153

158154
def test_gauss_inv(self):
159155
n = 25
@@ -167,7 +163,20 @@ def test_gauss_inv(self):
167163
assert_allclose(gauss,
168164
gauss_from_uint(self.data2['data'], n, self.bits))
169165

170-
def test_32bit_uniform(self):
166+
def test_uniform_double(self):
167+
rs = self.RandomState(*self.data1['seed'])
168+
vals = uniform_from_uint(self.data1['data'], self.bits)
169+
uniforms = rs.random_sample(len(vals))
170+
assert_allclose(uniforms, vals)
171+
assert_equal(uniforms.dtype, np.float64)
172+
173+
rs = self.RandomState(*self.data2['seed'])
174+
vals = uniform_from_uint(self.data2['data'], self.bits)
175+
uniforms = rs.random_sample(len(vals))
176+
assert_allclose(uniforms, vals)
177+
assert_equal(uniforms.dtype, np.float64)
178+
179+
def test_uniform_float(self):
171180
rs = self.RandomState(*self.data1['seed'])
172181
vals = uniform32_from_uint(self.data1['data'], self.bits)
173182
uniforms = rs.random_sample(len(vals), dtype=np.float32)
@@ -261,6 +270,17 @@ def setUpClass(cls):
261270
cls.data1 = cls._read_csv(join(pwd, './data/mlfg-testset-1.csv'))
262271
cls.data2 = cls._read_csv(join(pwd, './data/mlfg-testset-2.csv'))
263272

273+
def test_raw(self):
274+
rs = self.RandomState(*self.data1['seed'])
275+
mod_data = self.data1['data'] >> np.uint64(1)
276+
uints = rs.random_raw(1000)
277+
assert_equal(uints, mod_data)
278+
279+
rs = self.RandomState(*self.data2['seed'])
280+
mod_data = self.data2['data'] >> np.uint64(1)
281+
uints = rs.random_raw(1000)
282+
assert_equal(uints, mod_data)
283+
264284

265285
class TestDSFMT(Base, TestCase):
266286
@classmethod
@@ -271,10 +291,11 @@ def setUpClass(cls):
271291
cls.data1 = cls._read_csv(join(pwd, './data/dSFMT-testset-1.csv'))
272292
cls.data2 = cls._read_csv(join(pwd, './data/dSFMT-testset-2.csv'))
273293

274-
def test_double(self):
294+
def test_uniform_double(self):
275295
rs = self.RandomState(*self.data1['seed'])
276-
assert_equal(uniform_from_dsfmt(self.data1['data']),
277-
rs.random_sample(1000))
296+
aa = uniform_from_dsfmt(self.data1['data'])
297+
assert_array_equal(uniform_from_dsfmt(self.data1['data']),
298+
rs.random_sample(1000))
278299

279300
rs = self.RandomState(*self.data2['seed'])
280301
assert_equal(uniform_from_dsfmt(self.data2['data']),

randomstate/tests/test_smoke.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from randomstate.prng.xorshift128 import xorshift128
2020
from randomstate.prng.xoroshiro128plus import xoroshiro128plus
2121
from randomstate.prng.dsfmt import dsfmt
22-
from numpy.testing import assert_almost_equal, assert_equal, assert_raises, assert_
22+
from numpy.testing import assert_almost_equal, assert_equal, assert_raises, assert_, assert_array_equal
2323

2424
from nose import SkipTest
2525

@@ -84,6 +84,19 @@ def comp_state(state1, state2):
8484
return identical
8585

8686

87+
def warmup(rs, n=None):
88+
if n is None:
89+
n = 11 + np.random.randint(0, 20)
90+
rs.standard_normal(n, method='bm')
91+
rs.standard_normal(n, method='zig')
92+
rs.standard_normal(n, method='bm', dtype=np.float32)
93+
rs.randint(0, 2 ** 24, n, dtype=np.uint64)
94+
rs.randint(0, 2 ** 48, n, dtype=np.uint64)
95+
rs.standard_gamma(11, n)
96+
rs.random_sample(n, dtype=np.float64)
97+
rs.random_sample(n, dtype=np.float32)
98+
99+
87100
class RNG(object):
88101
@classmethod
89102
def _extra_setup(cls):
@@ -121,7 +134,7 @@ def test_jump(self):
121134

122135
def test_random_raw(self):
123136
assert_(len(self.rs.random_raw(10)) == 10)
124-
assert_(self.rs.random_raw((10,10)).shape == (10,10))
137+
assert_(self.rs.random_raw((10, 10)).shape == (10, 10))
125138

126139
def test_uniform(self):
127140
r = self.rs.uniform(-1.0, 0.0, size=10)
@@ -202,17 +215,17 @@ def test_reset_state_gauss(self):
202215
rs2 = self.mod.RandomState()
203216
rs2.set_state(state)
204217
n2 = rs2.standard_normal(size=10)
205-
assert_((n1 == n2).all())
218+
assert_array_equal(n1, n2)
206219

207220
def test_reset_state_uint32(self):
208221
rs = self.mod.RandomState(*self.seed)
209-
rs.randint(0, 2 ** 24, dtype=np.uint32)
222+
rs.randint(0, 2 ** 24, 120, dtype=np.uint32)
210223
state = rs.get_state()
211-
n1 = rs.randint(0, 2**24, 10, dtype=np.uint32)
224+
n1 = rs.randint(0, 2 ** 24, 10, dtype=np.uint32)
212225
rs2 = self.mod.RandomState()
213226
rs2.set_state(state)
214-
n2 = rs.randint(0, 2**24, 10, dtype=np.uint32)
215-
assert_((n1 == n2).all())
227+
n2 = rs2.randint(0, 2 ** 24, 10, dtype=np.uint32)
228+
assert_array_equal(n1, n2)
216229

217230
def test_shuffle(self):
218231
original = np.arange(200, 0, -1)
@@ -487,10 +500,10 @@ def test_seed_array(self):
487500
def test_seed_array_error(self):
488501
if self.seed_vector_bits == 32:
489502
dtype = np.uint32
490-
out_of_bounds = 2**32
503+
out_of_bounds = 2 ** 32
491504
else:
492505
dtype = np.uint64
493-
out_of_bounds = 2**64
506+
out_of_bounds = 2 ** 64
494507

495508
seed = -1
496509
assert_raises(ValueError, self.rs.seed, seed)
@@ -504,6 +517,32 @@ def test_seed_array_error(self):
504517
seed = np.array([1, 2, 3, out_of_bounds])
505518
assert_raises(ValueError, self.rs.seed, seed)
506519

520+
def test_uniform_float(self):
521+
rs = self.mod.RandomState(12345)
522+
warmup(rs)
523+
state = rs.get_state()
524+
r1 = rs.random_sample(11, dtype=np.float32)
525+
rs2 = self.mod.RandomState()
526+
warmup(rs2)
527+
rs2.set_state(state)
528+
r2 = rs2.random_sample(11, dtype=np.float32)
529+
assert_array_equal(r1, r2)
530+
assert_equal(r1.dtype, np.float32)
531+
assert_(comp_state(rs.get_state(), rs2.get_state()))
532+
533+
def test_normal_floats(self):
534+
rs = self.mod.RandomState()
535+
warmup(rs)
536+
state = rs.get_state()
537+
r1 = rs.standard_normal(11, method='bm', dtype=np.float32)
538+
rs2 = self.mod.RandomState()
539+
warmup(rs2)
540+
rs2.set_state(state)
541+
r2 = rs2.standard_normal(11, method='bm', dtype=np.float32)
542+
assert_array_equal(r1, r2)
543+
assert_equal(r1.dtype, np.float32)
544+
assert_(comp_state(rs.get_state(), rs2.get_state()))
545+
507546

508547
class TestMT19937(RNG):
509548
@classmethod
@@ -642,9 +681,3 @@ def test_fallback(self):
642681
time.sleep(0.1)
643682
e2 = entropy.random_entropy(source='fallback')
644683
assert_((e1 != e2))
645-
646-
647-
if __name__ == '__main__':
648-
import nose
649-
650-
nose.run(argv=[__file__, '-vv'])

0 commit comments

Comments
 (0)