Skip to content

Commit 0a69412

Browse files
committed
ENH: Allow broadcasting in randint
Allow array broadcasting for random integers
1 parent 397dc91 commit 0a69412

File tree

5 files changed

+382
-32
lines changed

5 files changed

+382
-32
lines changed

appveyor.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@ build_script:
2121
- SET PATH=C:\Py;C:\Py\Scripts;C:\Py\Library\bin;%PATH%
2222
- conda config --set always_yes yes
2323
- conda update conda --quiet
24-
- conda install numpy cython nose pandas --quiet
24+
- conda install numpy cython nose pandas pytest --quiet
2525
- python setup.py develop
2626
- set "GIT_DIR=%cd%"
2727

2828
test_script:
29-
- cd ..
30-
- nosetests randomstate
29+
- pytest randomstate
3130

3231
on_success:
3332
- cd %GIT_DIR%\randomstate

doc/source/change-log.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
Change Log
44
==========
5+
Since 1.13.2
6+
------------
7+
* Allow (:meth:`~randomstate.prng.mt19937.randint`) to broadcast inputs
8+
* Sync with upstream NumPy changes
9+
* Add protection against negative inputs in (:meth:`~randomstate.prng.mt19937.dirichlet`)
510

611
Version 1.13.2
712
--------------

randomstate/bounded_integers.pxi.in

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,241 @@ cdef object _rand_{{nptype}}(low, high, size, aug_state *state, lock):
8484

8585
{{endfor}}
8686

87+
88+
cdef inline uint64_t _gen_mask(uint64_t max_val) nogil:
    # Return the smallest bit mask >= max_val.
    #
    # OR-ing the value with itself shifted right by 1, 2, 4, 8, 16 and 32
    # bits copies the highest set bit into every lower bit position.
    cdef uint64_t mask = max_val
    cdef int shift = 1
    while shift <= 32:
        mask |= mask >> shift
        shift <<= 1
    return mask
98+
99+
{{
100+
py:
# Template table for the broadcasting generators of integer types up to 32 bits.
# Each row is (nptype, utype, nptype_up, npctype, remaining, bitshift, lb, ub):
#   nptype    - name used in the generated function name and error messages
#   utype     - unsigned C type of the output buffer and of the masked draw
#   nptype_up - wider C type used for low, high and the range
#   npctype   - NumPy type code that low and high are converted to
#   remaining - value loaded into buf_rem after a fresh 32-bit draw
#   bitshift  - bits the buffered draw is shifted by when it is reused
#   lb, ub    - bounds used in the 'low'/'high' out-of-bounds checks
101+
bc_ctypes = (('uint32', 'uint32', 'uint64', 'NPY_UINT64', 0, 0, 0, '0X100000000ULL'),
102+
('uint16', 'uint16', 'uint32', 'NPY_UINT32', 1, 16, 0, '0X10000UL'),
103+
('uint8', 'uint8', 'uint16', 'NPY_UINT16', 3, 8, 0, '0X100UL'),
104+
('bool','uint8', 'uint8', 'NPY_UINT8', 31, 1, 0, '0x2UL'),
105+
('int32', 'uint32', 'uint64', 'NPY_INT64', 0, 0, '-0x80000000LL', '0x80000000LL'),
106+
('int16', 'uint16', 'uint32', 'NPY_INT32', 1, 16, '-0x8000LL', '0x8000LL' ),
107+
('int8', 'uint8', 'uint16', 'NPY_INT16', 3, 8, '-0x80LL', '0x80LL' ),
108+
)}}
109+
110+
{{for nptype, utype, nptype_up, npctype, remaining, bitshift, lb, ub in bc_ctypes}}
111+
112+
{{ py: otype = nptype + '_' if nptype == 'bool' else nptype }}
113+
114+
cdef object _rand_{{nptype}}_broadcast(object low, object high, object size, aug_state *state, object lock):
    """
    Generate bounded random {{nptype}} values using broadcasting

    Parameters
    ----------
    low : int or array-like
        Array containing the lowest (signed) integers to be drawn from the
        distribution.
    high : int or array-like
        Array containing the open interval bound for the distribution.
    size : int or tuple of ints
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case
        the output shape is determined by the broadcast shapes of low and
        high
    state : augmented random state
        State to use in the core random number generators
    lock : threading.Lock
        Lock to prevent multiple threads from using a single RandomState
        simultaneously

    Returns
    -------
    out : ndarray of np.{{nptype}}
        array of random integers from the appropriate distribution where the
        size is determined by size if provided or the broadcast shape of low
        and high

    Raises
    ------
    ValueError
        If any element of low is less than {{lb}}, any element of high is
        greater than {{ub}}, or any element of low is greater than or equal
        to the matching element of high.
    """
    cdef np.ndarray low_arr, high_arr, out_arr
    cdef np.npy_intp i, n
    cdef np.broadcast it
    cdef int buf_rem = 0

    cdef {{utype}}_t *out_data
    cdef {{utype}}_t val, mask, off
    cdef {{nptype_up}}_t rng, last_rng, low_v, high_v
    cdef uint32_t buf

    # TODO: Direct error check? Probably not
    # TODO: Make constant?
    # NOTE(review): the casts below are unchecked, so low and high are
    # presumably already ndarrays when this is called -- confirm at caller.
    low_arr = <np.ndarray>low
    high_arr = <np.ndarray>high
    if np.any(np.less(low_arr, {{lb}})):
        raise ValueError('low is out of bounds for {{nptype}}')
    if np.any(np.greater(high_arr, {{ub}})):
        raise ValueError('high is out of bounds for {{nptype}}')
    if np.any(np.greater_equal(low_arr, high_arr)):
        raise ValueError('low >= high')

    # Convert to the wider type so that high - 1 - low cannot overflow
    low_arr = <np.ndarray>np.PyArray_FROM_OTF(low, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)
    high_arr = <np.ndarray>np.PyArray_FROM_OTF(high, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)

    if size is not None:
        out_arr = <np.ndarray>np.empty(size, np.{{otype}})
    else:
        # No explicit size: use the broadcast shape of low and high
        it = np.PyArray_MultiIterNew2(low_arr, high_arr)
        out_arr = <np.ndarray>np.empty(it.shape, np.{{otype}})

    it = np.PyArray_MultiIterNew3(low_arr, high_arr, out_arr)
    out_data = <{{utype}}_t *>np.PyArray_DATA(out_arr)
    n = np.PyArray_SIZE(out_arr)
    mask = last_rng = 0
    with lock, nogil:
        for i in range(n):
            low_v = (<{{nptype_up}}_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
            high_v = (<{{nptype_up}}_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
            # Values are generated on the closed interval [off, off + rng]
            rng = (high_v - 1) - low_v
            off = <{{utype}}_t>(<{{nptype_up}}_t>low_v)

            if rng == 0:
                # Only one admissible value (low itself); no draw is needed.
                val = 0
            else:
                if rng != last_rng:
                    # Smallest bit mask >= max; remember the range it was
                    # built for so repeated ranges reuse the mask.
                    mask = <{{utype}}_t>_gen_mask(rng)
                    last_rng = rng

                # Masked rejection: redraw until the masked value is <= rng.
                # A 32-bit draw is reused {{remaining}} more time(s) by
                # shifting it {{bitshift}} bits before a new one is taken.
                while True:
                    if not buf_rem:
                        buf = random_uint32(state)
                        buf_rem = {{remaining}}
                    else:
                        buf >>= {{bitshift}}
                        buf_rem -= 1
                    val = <{{utype}}_t>buf & mask
                    if val <= rng:
                        break

            out_data[i] = off + val

            # Always advance, including on the rng == 0 path, so the next
            # element reads its own low/high pair.
            np.PyArray_MultiIter_NEXT(it)

    return out_arr
206+
207+
{{endfor}}
208+
209+
210+
{{
211+
py:
# Template table for the 64-bit broadcasting generators.
# Each row is (nptype, utype, npctype, lb, ub):
#   nptype  - name used in the generated function name, dtype and C casts
#   utype   - unsigned counterpart of nptype
#   npctype - NumPy type code that low is converted to
#   lb      - bound used in the 'low is out of bounds' check
#   ub      - bound compared against high - 1 in the 'high is out of bounds' check
212+
big_bc_ctypes = (('uint64', 'uint64', 'NPY_UINT64', '0x0ULL', '0xFFFFFFFFFFFFFFFFULL'),
213+
('int64', 'uint64', 'NPY_INT64', '-0x8000000000000000LL', '0x7FFFFFFFFFFFFFFFLL' )
214+
)}}
215+
216+
{{for nptype, utype, npctype, lb, ub in big_bc_ctypes}}
217+
218+
{{ py: otype = nptype}}
219+
220+
cdef object _rand_{{nptype}}_broadcast(object low, object high, object size, aug_state *state, object lock):
    """
    Generate bounded random {{nptype}} values using broadcasting

    Parameters
    ----------
    low : int or array-like
        Array containing the lowest (signed) integers to be drawn from the
        distribution.
    high : int or array-like
        Array containing the open interval bound for the distribution.
    size : int or tuple of ints
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case
        the output shape is determined by the broadcast shapes of low and
        high
    state : augmented random state
        State to use in the core random number generators
    lock : threading.Lock
        Lock to prevent multiple threads from using a single RandomState
        simultaneously

    Returns
    -------
    out : ndarray of np.{{nptype}}
        array of random integers from the appropriate distribution where the
        size is determined by size if provided or the broadcast shape of low
        and high

    Raises
    ------
    ValueError
        If any element of low is less than {{lb}}, if ``high - 1`` is greater
        than {{ub}} or less than {{lb}} for any element of high, or if any
        element of low is greater than the matching ``high - 1``.
    """
    cdef np.ndarray low_arr, high_arr, out_arr, highm1_arr
    cdef np.npy_intp i, n
    cdef np.broadcast it
    cdef object closed_upper

    cdef uint64_t *out_data
    cdef {{nptype}}_t *highm1_data
    cdef {{nptype}}_t low_v, high_v
    cdef uint64_t rng, last_rng, val, mask, off

    # NOTE(review): the casts below are unchecked, so low and high are
    # presumably already ndarrays when this is called -- confirm at caller.
    low_arr = <np.ndarray>low
    high_arr = <np.ndarray>high

    if np.any(np.less(low_arr, {{lb}})):
        raise ValueError('low is out of bounds for {{nptype}}')

    # Build high - 1 element by element using Python integers, so the
    # open bound may exceed the range of {{nptype}} by one without
    # overflowing before it is validated.
    highm1_arr = <np.ndarray>np.empty_like(high_arr, dtype=np.{{nptype}})
    highm1_data = <{{nptype}}_t *>np.PyArray_DATA(highm1_arr)
    n = np.PyArray_SIZE(high_arr)
    flat = high_arr.flat
    for i in range(n):
        closed_upper = int(flat[i]) - 1
        if closed_upper > {{ub}}:
            raise ValueError('high is out of bounds for {{nptype}}')
        if closed_upper < {{lb}}:
            raise ValueError('low >= high')

        highm1_data[i] = <{{nptype}}_t>closed_upper

    if np.any(np.greater(low_arr, highm1_arr)):
        raise ValueError('low >= high')

    # From here on high_arr holds the closed upper bound (high - 1)
    high_arr = highm1_arr
    low_arr = <np.ndarray>np.PyArray_FROM_OTF(low, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)

    if size is not None:
        out_arr = <np.ndarray>np.empty(size, np.{{nptype}})
    else:
        # No explicit size: use the broadcast shape of low and high
        it = np.PyArray_MultiIterNew2(low_arr, high_arr)
        out_arr = <np.ndarray>np.empty(it.shape, np.{{nptype}})

    it = np.PyArray_MultiIterNew3(low_arr, high_arr, out_arr)
    out_data = <uint64_t *>np.PyArray_DATA(out_arr)
    n = np.PyArray_SIZE(out_arr)
    mask = last_rng = 0
    with lock, nogil:
        for i in range(n):
            low_v = (<{{nptype}}_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
            high_v = (<{{nptype}}_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
            # No -1 here since it was applied when building highm1_arr.
            # Subtract as unsigned so a full-width signed range wraps to
            # the correct width instead of overflowing a signed integer.
            rng = <uint64_t>high_v - <uint64_t>low_v
            off = <uint64_t>low_v

            if rng == 0:
                # Only one admissible value (low itself); no draw is needed.
                val = 0
            else:
                if rng != last_rng:
                    # Smallest bit mask >= max; remember the range it was
                    # built for so repeated ranges reuse the mask.
                    mask = <uint64_t>_gen_mask(rng)
                    last_rng = rng

                # Masked rejection: redraw until the masked value is <= rng.
                # Ranges that fit in 32 bits use 32-bit draws.
                if rng <= 0xFFFFFFFFULL:
                    while True:
                        val = random_uint32(state) & mask
                        if val <= rng:
                            break
                else:
                    while True:
                        val = random_uint64(state) & mask
                        if val <= rng:
                            break

            out_data[i] = off + val

            # Always advance, including on the rng == 0 path, so the next
            # element reads its own low/high pair.
            np.PyArray_MultiIter_NEXT(it)

    return out_arr
323+
324+
{{endfor}}

0 commit comments

Comments
 (0)