Skip to content

Commit 6607b17

Browse files
committed
REF: Refactor to a common path for more data types
Refactor random integers to a commmon path
1 parent 118c3b6 commit 6607b17

File tree

5 files changed

+152
-4
lines changed

5 files changed

+152
-4
lines changed

randomstate/bounded_integers.pxi.in

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ cdef object _rand_{{nptype}}(low, high, size, aug_state *state, lock):
8484

8585
{{endfor}}
8686

87-
8887
cdef inline uint64_t _gen_mask(uint64_t max_val) nogil:
8988
# Smallest bit mask >= max
9089
cdef uint64_t mask = max_val
@@ -96,6 +95,137 @@ cdef inline uint64_t _gen_mask(uint64_t max_val) nogil:
9695
mask |= mask >> 32
9796
return mask
9897

98+
{{
99+
py:
100+
bc_ctypes = (('uint32', 'uint32', 'uint64', 'NPY_UINT64', 0, 0, 0, '0X100000000ULL'),
101+
('uint16', 'uint16', 'uint32', 'NPY_UINT32', 1, 16, 0, '0X10000UL'),
102+
('uint8', 'uint8', 'uint16', 'NPY_UINT16', 3, 8, 0, '0X100UL'),
103+
('bool','uint8', 'uint8', 'NPY_UINT8', 31, 1, 0, '0x2UL'),
104+
('int32', 'uint32', 'uint64', 'NPY_INT64', 0, 0, '-0x80000000LL', '0x80000000LL'),
105+
('int16', 'uint16', 'uint32', 'NPY_INT32', 1, 16, '-0x8000LL', '0x8000LL' ),
106+
('int8', 'uint8', 'uint16', 'NPY_INT16', 3, 8, '-0x80LL', '0x80LL' ),
107+
)}}
108+
109+
{{for nptype, utype, nptype_up, npctype, remaining, bitshift, lb, ub in bc_ctypes}}
110+
111+
{{ py: otype = nptype + '_' if nptype == 'bool' else nptype }}
112+
113+
cdef object _rand_{{nptype}}_combined(object low, object high, object size, aug_state *state, object lock):
114+
"""
115+
_rand_{{nptype}}_combined(low, high, size, *state, lock)
116+
117+
Return random np.{{nptype}} integers between `low` and `high`, inclusive.
118+
119+
Return random integers from the "discrete uniform" distribution in the
120+
closed interval [`low`, `high`). If `high` is None (the default),
121+
then results are from [0, `low`). On entry the arguments are presumed
122+
to have been validated for size and order for the np.{{nptype}} type.
123+
124+
Parameters
125+
----------
126+
low : int
127+
Lowest (signed) integer to be drawn from the distribution (unless
128+
``high=None``, in which case this parameter is the *highest* such
129+
integer).
130+
high : int
131+
If provided, the largest (signed) integer to be drawn from the
132+
distribution (see above for behavior if ``high=None``).
133+
size : int or tuple of ints
134+
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
135+
``m * n * k`` samples are drawn. Default is None, in which case a
136+
single value is returned.
137+
rngstate : encapsulated pointer to rk_state
138+
The specific type depends on the python version. In Python 2 it is
139+
a PyCObject, in Python 3 a PyCapsule object.
140+
141+
Returns
142+
-------
143+
out : python scalar or ndarray of np.{{nptype}}
144+
`size`-shaped array of random integers from the appropriate
145+
distribution, or a single such random int if `size` not provided.
146+
147+
"""
148+
cdef {{utype}}_t off, val, mask
149+
cdef uint32_t buf
150+
cdef {{utype}}_t *out
151+
cdef {{utype}}_t *out_data
152+
cdef {{nptype_up}}_t rng, last_rng, low_v, high_v
153+
cdef np.ndarray low_arr, high_arr, out_arr
154+
cdef np.npy_intp i, cnt
155+
cdef np.broadcast it
156+
cdef int buf_rem = 0
157+
158+
159+
low = np.asarray(low)
160+
high = np.asarray(high)
161+
if low.shape == high.shape == ():
162+
low = int(low) # TODO: Cast appropriately?
163+
high = int(high) # TODO: Cast appropriately?
164+
165+
if low < {{lb}}:
166+
raise ValueError("low is out of bounds for {{nptype}}")
167+
if high > {{ub}}:
168+
raise ValueError("high is out of bounds for {{nptype}}")
169+
if low >= high:
170+
raise ValueError("low >= high")
171+
172+
high -= 1
173+
rng = <{{utype}}_t>(high - low)
174+
off = <{{utype}}_t>(<{{nptype}}_t>low)
175+
if size is None:
176+
with lock:
177+
random_bounded_{{utype}}_fill(state, off, rng, 1, out)
178+
return np.{{otype}}(<{{nptype}}_t>out[0])
179+
else:
180+
out_arr = <np.ndarray>np.empty(size, np.{{nptype}})
181+
cnt = np.PyArray_SIZE(out_arr)
182+
out = <{{utype}}_t *>np.PyArray_DATA(out_arr)
183+
with lock, nogil:
184+
random_bounded_{{utype}}_fill(state, off, rng, cnt, out)
185+
return out_arr
186+
187+
# Array path
188+
low_arr = <np.ndarray>low
189+
high_arr = <np.ndarray>high
190+
if np.any(np.less(low_arr, {{lb}})):
191+
raise ValueError('low is out of bounds for {{nptype}}')
192+
if np.any(np.greater(high_arr, {{ub}})):
193+
raise ValueError('high is out of bounds for {{nptype}}')
194+
if np.any(np.greater_equal(low_arr, high_arr)):
195+
raise ValueError('low >= high')
196+
197+
low_arr = <np.ndarray>np.PyArray_FROM_OTF(low, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)
198+
high_arr = <np.ndarray>np.PyArray_FROM_OTF(high, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)
199+
200+
if size is not None:
201+
out_arr = <np.ndarray>np.empty(size, np.{{otype}})
202+
else:
203+
it = np.PyArray_MultiIterNew2(low_arr, high_arr)
204+
out_arr = <np.ndarray>np.empty(it.shape, np.{{otype}})
205+
206+
it = np.PyArray_MultiIterNew3(low_arr, high_arr, out_arr)
207+
out_data = <{{utype}}_t *>np.PyArray_DATA(out_arr)
208+
n = np.PyArray_SIZE(out_arr)
209+
mask = last_rng = 0
210+
with lock, nogil:
211+
for i in range(n):
212+
low_v = (<{{nptype_up}}_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
213+
high_v = (<{{nptype_up}}_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
214+
rng = (high_v - 1) - low_v
215+
off = <{{utype}}_t>(<{{nptype_up}}_t>low_v)
216+
217+
if rng != last_rng:
218+
# Smallest bit mask >= max
219+
mask = <{{utype}}_t>_gen_mask(rng)
220+
221+
out[i] = random_buffered_bounded_{{utype}}(state, off, rng, mask, &buf_rem, &buf)
222+
223+
np.PyArray_MultiIter_NEXT(it)
224+
225+
return out_arr
226+
{{endfor}}
227+
228+
99229
{{
100230
py:
101231
bc_ctypes = (('uint32', 'uint32', 'uint64', 'NPY_UINT64', 0, 0, 0, '0X100000000ULL'),

randomstate/distributions.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,8 +1419,14 @@ inline uint64_t random_bounded_uint64(aug_state *state, uint64_t off, uint64_t r
14191419
}
14201420

14211421

1422-
inline uint32_t random_bounded_uint32(aug_state *state, uint32_t off, uint32_t rng, uint32_t mask)
1422+
inline uint32_t random_buffered_bounded_uint32(aug_state *state, uint32_t off, uint32_t rng, uint32_t mask, int *bcnt, uint32_t *buf)
14231423
{
1424+
/*
1425+
* The buffer and buffer count are not used here but are included to allow
1426+
* this function to be templated with the similar uint8 and uint16
1427+
* functions
1428+
*/
1429+
14241430
uint32_t val;
14251431
if (rng == 0)
14261432
return off;
@@ -1489,11 +1495,13 @@ void random_bounded_uint32_fill(aug_state *state, uint32_t off, uint32_t rng, np
14891495
{
14901496
uint32_t val, mask;
14911497
npy_intp i;
1498+
uint32_t buf = 0;
1499+
int bcnt = 0;
14921500

14931501
/* Smallest bit mask >= max */
14941502
mask = (uint32_t)gen_mask(rng);
14951503
for (i = 0; i < cnt; i++) {
1496-
out[i] = random_bounded_uint32(state, off, rng, mask);
1504+
out[i] = random_buffered_bounded_uint32(state, off, rng, mask, &bcnt, &buf);
14971505
}
14981506
}
14991507

randomstate/distributions.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,9 @@ extern void random_standard_exponential_zig_float_fill(aug_state* state, npy_int
175175
extern double random_standard_gamma_zig_double(aug_state* state, double shape);
176176

177177
extern float random_standard_gamma_zig_float(aug_state* state, float shape);
178+
179+
inline uint32_t random_buffered_bounded_uint32(aug_state *state, uint32_t off, uint32_t rng, uint32_t mask, int *bcnt, uint32_t *buf);
180+
181+
inline uint16_t random_buffered_bounded_uint16(aug_state *state, uint16_t off, uint16_t rng, uint16_t mask, int *bcnt, uint32_t *buf);
182+
183+
inline uint8_t random_buffered_bounded_uint8(aug_state *state, uint8_t off, uint8_t rng, uint8_t mask, int *bcnt, uint32_t *buf);

randomstate/randomstate.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ cdef extern from "distributions.h":
6464
cdef uint64_t random_uint64(aug_state* state) nogil
6565
cdef uint32_t random_uint32(aug_state* state) nogil
6666
cdef uint64_t random_raw_values(aug_state* state) nogil
67+
68+
cdef uint32_t random_buffered_bounded_uint32(aug_state *state, uint32_t off, uint32_t rng, uint32_t mask, int *bcnt, uint32_t *buf) nogil
69+
cdef uint16_t random_buffered_bounded_uint16(aug_state *state, uint16_t off, uint16_t rng, uint16_t mask, int *bcnt, uint32_t *buf) nogil
70+
cdef uint8_t random_buffered_bounded_uint8(aug_state *state, uint8_t off, uint8_t rng, uint8_t mask, int *bcnt, uint32_t *buf) nogil
6771

6872
cdef long random_positive_int(aug_state* state) nogil
6973
cdef unsigned long random_uint(aug_state* state) nogil

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
'RNG_PCG32', 'RNG_PCG64', 'RNG_XORSHIFT128', 'RNG_XOROSHIRO128PLUS',
3535
'RNG_XORSHIFT1024', 'RNG_SFMT']
3636

37-
compile_rngs = rngs[:]
37+
compile_rngs = rngs[:1]
3838

3939
extra_defs = [('_CRT_SECURE_NO_WARNINGS', '1')] if os.name == 'nt' else []
4040
extra_link_args = ['/LTCG', 'Advapi32.lib', 'Kernel32.lib'] if os.name == 'nt' else []

0 commit comments

Comments
 (0)