Skip to content

Commit dedc331

Browse files
committed
REF: Make entropy init seed visible
Use public API to get entropy initialization information Save seed/stream info into object Remvoe extraneous print statements
1 parent 0aaee30 commit dedc331

File tree

12 files changed

+84
-45
lines changed

12 files changed

+84
-45
lines changed

randomstate/defaults.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ DEF RS_RNG_SEED = 1
33
DEF RS_RNG_ADVANCEABLE = 0
44
DEF RS_RNG_JUMPABLE = 0
55
DEF RS_RNG_STATE_LEN = 4
6+
DEF RS_SEED_NBYTES = 2

randomstate/entropy.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def random_entropy(size=None, source='system'):
6565
else:
6666
n = compute_numel(size)
6767
randoms = np.zeros(n, dtype=np.uint32)
68-
print(n)
6968
if source == 'system':
7069
success = entropy_getbytes(<void *>(&randoms[0]), 4 * n)
7170
else:

randomstate/interface.pyx

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free
2121
import randomstate
2222
from binomial cimport binomial_t
2323
from cython_overrides cimport PyFloat_AsDouble, PyInt_AsLong, PyErr_Occurred, PyErr_Clear
24+
from randomstate.entropy import random_entropy
2425

2526
np.import_array()
2627

@@ -119,6 +120,20 @@ include "array_utilities.pxi"
119120
include "bounded_integers.pxi"
120121
include "aligned_malloc.pxi"
121122

123+
cdef object _generate_seed(nbytes):
124+
try:
125+
seeds = random_entropy(nbytes)
126+
except:
127+
seeds = random_entropy(nbytes, 'fallback')
128+
if nbytes == 1:
129+
return seeds[0]
130+
131+
seed = long(0)
132+
for i in range(nbytes):
133+
scale = 2 ** (32 * i)
134+
seed += scale * long(seeds[i])
135+
return seed
136+
122137
cdef double kahan_sum(double *darr, np.npy_intp n):
123138
cdef double c, y, t, sum
124139
cdef np.npy_intp i
@@ -148,7 +163,8 @@ cdef class RandomState:
148163
cdef object lock
149164
poisson_lam_max = POISSON_LAM_MAX
150165
__MAXSIZE = <uint64_t>sys.maxsize
151-
166+
cdef object __seed
167+
cdef object __stream
152168

153169
IF RS_RNG_SEED==1:
154170
def __init__(self, seed=None):
@@ -157,15 +173,22 @@ cdef class RandomState:
157173
IF RS_RNG_MOD_NAME == 'dsfmt':
158174
self.rng_state.buffered_uniforms = <double *>PyArray_malloc_aligned(2 * DSFMT_N * sizeof(double))
159175
self.lock = Lock()
176+
self.__seed = seed
177+
self.__stream = None
178+
160179
self._reset_state_variables()
161180
self.seed(seed)
162181
ELSE:
163-
def __init__(self, seed=None, inc=None):
182+
def __init__(self, seed=None, stream=None):
164183
self.rng_state.rng = <rng_t *>PyArray_malloc_aligned(sizeof(rng_t))
165184
self.rng_state.binomial = &self.binomial_info
166185
self.lock = Lock()
186+
self.__seed = seed
187+
self.__stream = stream
188+
167189
self._reset_state_variables()
168-
self.seed(seed, inc)
190+
self.seed(seed, stream)
191+
169192

170193
def __dealloc__(self):
171194
PyArray_free_aligned(self.rng_state.rng)
@@ -196,8 +219,9 @@ cdef class RandomState:
196219
# cdef ndarray obj "arrayObject_obj"
197220
try:
198221
if seed is None:
222+
self.__seed = seed = _generate_seed(1)
199223
with self.lock:
200-
entropy_init(&self.rng_state)
224+
set_seed(&self.rng_state, seed)
201225
else:
202226
idx = operator.index(seed)
203227
if idx > int(2**32 - 1) or idx < 0:
@@ -214,7 +238,7 @@ cdef class RandomState:
214238
self._reset_state_variables()
215239

216240
ELIF RS_RNG_SEED==1:
217-
def seed(self, val=None):
241+
def seed(self, seed=None):
218242
"""
219243
seed(seed=None)
220244
@@ -225,7 +249,7 @@ cdef class RandomState:
225249
226250
Parameters
227251
----------
228-
val : int, optional
252+
seed : int, optional
229253
Seed for ``RandomState``.
230254
231255
Notes
@@ -239,18 +263,18 @@ cdef class RandomState:
239263
--------
240264
RandomState
241265
"""
242-
if val is not None:
243-
if val < 0:
244-
raise ValueError('val < 0')
245-
set_seed(&self.rng_state, val)
266+
if seed is not None:
267+
if seed < 0:
268+
raise ValueError('seed < 0')
246269
else:
247-
entropy_init(&self.rng_state)
270+
self.__seed = seed = _generate_seed(RS_SEED_NBYTES)
271+
set_seed(&self.rng_state, seed)
248272
self._reset_state_variables()
249273

250274
ELSE:
251-
def seed(self, val=None, inc=None):
275+
def seed(self, seed=None, stream=None):
252276
"""
253-
seed(val=None, inc=None)
277+
seed(seed=None, stream=None)
254278
255279
Seed the generator.
256280
@@ -259,31 +283,33 @@ cdef class RandomState:
259283
260284
Parameters
261285
----------
262-
val : int, optional
286+
seed : int, optional
263287
Seed for ``RandomState``.
264-
inc : int, optional
265-
Increment to use for producing multiple streams
288+
stream : int, optional
289+
Generator stream to use
266290
267291
See Also
268292
--------
269293
RandomState
270294
"""
271-
if val is not None and inc is not None:
272-
if val < 0:
273-
raise ValueError('val < 0')
274-
if inc < 0:
275-
raise ValueError('inc < 0')
276-
IF RS_RNG_MOD_NAME == 'pcg64':
277-
IF RS_PCG128_EMULATED:
278-
set_seed(&self.rng_state,
279-
pcg128_from_pylong(val),
280-
pcg128_from_pylong(inc))
281-
ELSE:
282-
set_seed(&self.rng_state, val, inc)
295+
if seed is None:
296+
self.__seed = seed = _generate_seed(RS_SEED_NBYTES)
297+
elif seed < 0:
298+
raise ValueError('seed < 0')
299+
if stream is None:
300+
self.__stream = stream = 1
301+
elif stream < 0:
302+
raise ValueError('stream < 0')
303+
304+
IF RS_RNG_MOD_NAME == 'pcg64':
305+
IF RS_PCG128_EMULATED:
306+
set_seed(&self.rng_state,
307+
pcg128_from_pylong(seed),
308+
pcg128_from_pylong(stream))
283309
ELSE:
284-
set_seed(&self.rng_state, val, inc)
285-
else:
286-
entropy_init(&self.rng_state)
310+
set_seed(&self.rng_state, seed, stream)
311+
ELSE:
312+
set_seed(&self.rng_state, seed, stream)
287313
self._reset_state_variables()
288314

289315
def _reset_state_variables(self):
@@ -400,11 +426,14 @@ cdef class RandomState:
400426
+ _get_state(self.rng_state) \
401427
+ (self.rng_state.has_gauss, self.rng_state.gauss)
402428

403-
return {'name': rng_name,
429+
state = {'name': rng_name,
404430
'state': _get_state(self.rng_state),
405431
'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss},
406-
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger}
407-
}
432+
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger},
433+
'seed': self.__seed}
434+
if self.__stream is not None:
435+
state['stream'] = self.__stream
436+
return state
408437
ELSE:
409438
def get_state(self):
410439
"""
@@ -438,11 +467,14 @@ cdef class RandomState:
438467
component, see the class documentation.
439468
"""
440469
rng_name = _ensure_string(RS_RNG_NAME)
441-
return {'name': rng_name,
470+
state = {'name': rng_name,
442471
'state': _get_state(self.rng_state),
443472
'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss},
444-
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger}
445-
}
473+
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger},
474+
'seed': self.__seed}
475+
if self.__stream is not None:
476+
state['stream'] = self.__stream
477+
return state
446478

447479
def set_state(self, state):
448480
"""
@@ -505,6 +537,8 @@ cdef class RandomState:
505537
self.rng_state.gauss = state['gauss']['gauss']
506538
self.rng_state.has_uint32 = state['uint32']['has_uint32']
507539
self.rng_state.uinteger = state['uint32']['uint32']
540+
self.__seed = state['seed']
541+
self.__stream = state['stream'] if 'stream' in state else None
508542

509543
def random_uintegers(self, size=None, int bits=64):
510544
"""
@@ -566,7 +600,9 @@ cdef class RandomState:
566600
self.set_state(state)
567601

568602
def __reduce__(self):
569-
return (randomstate.prng.__generic_ctor, (RS_RNG_MOD_NAME,), self.get_state())
603+
return (randomstate.prng.__generic_ctor,
604+
(_ensure_string(RS_RNG_MOD_NAME),),
605+
self.get_state())
570606

571607
# Basic distributions:
572608
def random_sample(self, size=None):

randomstate/prng/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@ def __generic_ctor(mod_name='mt19937'):
2121
rs: RandomState
2222
RandomState from the module randomstate.prng.mod_name
2323
"""
24-
print(mod_name)
2524
try:
2625
mod_name = mod_name.decode('ascii')
2726
except AttributeError:
2827
pass
29-
print(mod_name)
3028
if mod_name == 'mt19937':
3129
mod = mt19937
3230
elif mod_name == 'mlfg_1279_861':

randomstate/shims/dSFMT/dSFMT.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ DEF RS_RNG_JUMPABLE = 1
33
DEF DSFMT_MEXP = 19937
44
DEF DSFMT_N = 191 # ((DSFMT_MEXP - 128) / 104 + 1)
55
DEF DSFMT_N_PLUS_1 = 192 # DSFMT_N + 1
6+
DEF RS_SEED_NBYTES = 1
67

78
ctypedef uint32_t rng_state_t
89

randomstate/shims/pcg-64/pcg-64-emulated.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ DEF RS_RNG_NAME = 'pcg64'
22
DEF RS_RNG_ADVANCEABLE = 1
33
DEF RS_RNG_SEED=2
44
DEF RS_PCG128_EMULATED = 1
5+
DEF RS_SEED_NBYTES = 4
56

67
from cpython cimport PyLong_FromUnsignedLongLong, PyLong_AsUnsignedLongLong
78

randomstate/shims/pcg-64/pcg-64.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ DEF RS_RNG_NAME = 'pcg64'
22
DEF RS_RNG_ADVANCEABLE = 1
33
DEF RS_RNG_SEED=2
44
DEF RS_PCG128_EMULATED = 0
5+
DEF RS_SEED_NBYTES = 4
56

67
cdef extern from "inttypes.h":
78
ctypedef unsigned long long __uint128_t

randomstate/shims/random-kit/random-kit.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
DEF RS_RNG_NAME = 'mt19937'
22
DEF RS_NORMAL_METHOD = 'inv'
33
DEF RK_STATE_LEN = 624
4+
DEF RS_SEED_NBYTES = 1
45

56
ctypedef uint32_t rng_state_t
67

randomstate/tests/test_against_numpy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,6 @@ def test_array(self):
510510
def test_dir(self):
511511
nprs_d = dir(self.nprs)
512512
rs_d = dir(self.rs)
513-
print(set(nprs_d).difference(rs_d))
514513
assert(len(set(nprs_d).difference(rs_d)) == 0)
515514

516515
npmod = dir(numpy.random)
@@ -519,7 +518,6 @@ def test_dir(self):
519518
'__RandomState_ctor', 'mtrand', 'test',
520519
'__warningregistry__']
521520
mod += known_exlcuded
522-
print(set(npmod).difference(mod))
523521
assert(len(set(npmod).difference(mod)) == 0)
524522

525523

randomstate/tests/test_smoke.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,15 @@ def test_pickle(self):
429429
pick = pickle.dumps(self.rs)
430430
unpick = pickle.loads(pick)
431431
assert (type(self.rs) == type(unpick))
432+
print(self.rs.get_state())
433+
print(unpick.get_state())
432434
assert comp_state(self.rs.get_state(), unpick.get_state())
433435

434436
pick = cPickle.dumps(self.rs)
435437
unpick = cPickle.loads(pick)
436438
assert (type(self.rs) == type(unpick))
439+
print(self.rs.get_state())
440+
print(unpick.get_state())
437441
assert comp_state(self.rs.get_state(), unpick.get_state())
438442

439443

0 commit comments

Comments
 (0)