Skip to content

Commit 7fdfbf5

Browse files
committed
REF: Complete bounded integer refactor
Complete refactor by implementing uint64 and int64 refactorings
1 parent aff0f59 commit 7fdfbf5

File tree

2 files changed

+142
-60
lines changed

2 files changed

+142
-60
lines changed

randomstate/bounded_integers.pxi.in

Lines changed: 135 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ cdef object _rand_{{nptype}}_combined(object low, object high, object size, aug_
147147
"""
148148
cdef {{utype}}_t rng, last_rng, off, val, mask, out_val
149149
cdef uint32_t buf
150-
cdef {{utype}}_t *out
151150
cdef {{utype}}_t *out_data
152151
cdef {{nptype_up}}_t low_v, high_v
153152
cdef np.ndarray low_arr, high_arr, out_arr
@@ -179,9 +178,9 @@ cdef object _rand_{{nptype}}_combined(object low, object high, object size, aug_
179178
else:
180179
out_arr = <np.ndarray>np.empty(size, np.{{nptype}})
181180
cnt = np.PyArray_SIZE(out_arr)
182-
out = <{{utype}}_t *>np.PyArray_DATA(out_arr)
181+
out_data = <{{utype}}_t *>np.PyArray_DATA(out_arr)
183182
with lock, nogil:
184-
random_bounded_{{utype}}_fill(state, off, rng, cnt, out)
183+
random_bounded_{{utype}}_fill(state, off, rng, cnt, out_data)
185184
return out_arr
186185

187186
# Array path
@@ -432,20 +431,140 @@ cdef object _rand_{{nptype}}_broadcast(object low, object high, object size, aug
432431
continue
433432

434433
if rng != last_rng:
435-
mask = <uint64_t>_gen_mask(rng)
436-
437-
if rng <= 0xFFFFFFFFULL:
438-
while True:
439-
val = random_uint32(state) & mask
440-
if val <= rng:
441-
break
442-
else:
443-
while True:
444-
val = random_uint64(state) & mask
445-
if val <= rng:
446-
break
434+
mask = _gen_mask(rng)
435+
out_data[i] = random_bounded_uint64(state, off, rng, mask)
447436

448-
out_data[i] = off + val
437+
np.PyArray_MultiIter_NEXT(it)
438+
439+
return out_arr
440+
441+
{{endfor}}
442+
443+
444+
445+
{{
446+
py:
447+
big_bc_ctypes = (('uint64', 'uint64', 'NPY_UINT64', '0x0ULL', '0xFFFFFFFFFFFFFFFFULL'),
448+
('int64', 'uint64', 'NPY_INT64', '-0x8000000000000000LL', '0x7FFFFFFFFFFFFFFFLL' )
449+
)}}
450+
451+
{{for nptype, utype, npctype, lb, ub in big_bc_ctypes}}
452+
453+
{{ py: otype = nptype}}
454+
455+
cdef object _rand_{{nptype}}_combined(object low, object high, object size, aug_state *state, object lock):
456+
"""
457+
Generate bounded random {{nptype}} values using broadcasting
458+
459+
Parameters
460+
----------
461+
low : int or array-like
462+
Array containing the lowest (signed) integers to be drawn from the
463+
distribution.
464+
high : int or array-like
465+
Array containing the open interval bound for the distribution.
466+
size : int or tuple of ints
467+
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
468+
``m * n * k`` samples are drawn. Default is None, in which case
469+
the output shape is determined by the broadcast shapes of low and
470+
high
471+
state : augmented random state
472+
State to use in the core random number generators
473+
lock : threading.Lock
474+
Lock to prevent multiple threads from using a single RandomState simultaneously
475+
476+
Returns
477+
-------
478+
out : ndarray of np.{{nptype}}
479+
array of random integers from the appropriate distribution where the
480+
size is determined by size if provided or the broadcast shape of low
481+
and high
482+
"""
483+
cdef np.ndarray low_arr, high_arr, out_arr, highm1_arr
484+
cdef np.npy_intp i, cnt
485+
cdef np.broadcast it
486+
cdef object closed_upper
487+
cdef uint64_t *out_data
488+
cdef {{nptype}}_t *highm1_data
489+
cdef {{nptype}}_t low_v, high_v
490+
cdef uint64_t rng, last_rng, val, mask, off, out_val
491+
492+
low = np.asarray(low)
493+
high = np.asarray(high)
494+
if low.shape == high.shape == ():
495+
low = int(low) # TODO: Cast appropriately?
496+
high = int(high) # TODO: Cast appropriately?
497+
high -= 1 # Use a closed interval
498+
499+
if low < {{lb}}:
500+
raise ValueError("low is out of bounds for {{nptype}}")
501+
if high > {{ub}}:
502+
raise ValueError("high is out of bounds for {{nptype}}")
503+
if low > high:
504+
raise ValueError("low >= high")
505+
506+
rng = <{{utype}}_t>(high - low)
507+
off = <{{utype}}_t>(<{{nptype}}_t>low)
508+
if size is None:
509+
with lock:
510+
random_bounded_{{utype}}_fill(state, off, rng, 1, &out_val)
511+
return np.{{otype}}(<{{nptype}}_t>out_val)
512+
else:
513+
out_arr = <np.ndarray>np.empty(size, np.{{nptype}})
514+
cnt = np.PyArray_SIZE(out_arr)
515+
out_data = <{{utype}}_t *>np.PyArray_DATA(out_arr)
516+
with lock, nogil:
517+
random_bounded_{{utype}}_fill(state, off, rng, cnt, out_data)
518+
return out_arr
519+
520+
low_arr = <np.ndarray>low
521+
high_arr = <np.ndarray>high
522+
523+
if np.any(np.less(low_arr, {{lb}})):
524+
raise ValueError('low is out of bounds for {{nptype}}')
525+
526+
highm1_arr = <np.ndarray>np.empty_like(high_arr, dtype=np.{{nptype}})
527+
highm1_data = <{{nptype}}_t *>np.PyArray_DATA(highm1_arr)
528+
n = np.PyArray_SIZE(high_arr)
529+
flat = high_arr.flat
530+
for i in range(n):
531+
closed_upper = int(flat[i]) - 1
532+
if closed_upper > {{ub}}:
533+
raise ValueError('high is out of bounds for {{nptype}}')
534+
if closed_upper < {{lb}}:
535+
raise ValueError('low >= high')
536+
537+
highm1_data[i] = <{{nptype}}_t>closed_upper
538+
539+
if np.any(np.greater(low_arr, highm1_arr)):
540+
raise ValueError('low >= high')
541+
542+
high_arr = highm1_arr
543+
low_arr = <np.ndarray>np.PyArray_FROM_OTF(low, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)
544+
545+
if size is not None:
546+
out_arr = <np.ndarray>np.empty(size, np.{{nptype}})
547+
else:
548+
it = np.PyArray_MultiIterNew2(low_arr, high_arr)
549+
out_arr = <np.ndarray>np.empty(it.shape, np.{{nptype}})
550+
551+
it = np.PyArray_MultiIterNew3(low_arr, high_arr, out_arr)
552+
out_data = <uint64_t *>np.PyArray_DATA(out_arr)
553+
n = np.PyArray_SIZE(out_arr)
554+
mask = last_rng = 0
555+
with lock, nogil:
556+
for i in range(n):
557+
low_v = (<{{nptype}}_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
558+
high_v = (<{{nptype}}_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
559+
rng = <{{nptype}}_t>(high_v - low_v) # No -1 here since implemented above
560+
off = low_v
561+
if rng == 0:
562+
out_data[i] = 0
563+
continue
564+
565+
if rng != last_rng:
566+
mask = _gen_mask(rng)
567+
out_data[i] = random_bounded_uint64(state, off, rng, mask)
449568

450569
np.PyArray_MultiIter_NEXT(it)
451570

randomstate/randomstate.pyx

Lines changed: 7 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cdef extern from "distributions.h":
6565
cdef uint32_t random_uint32(aug_state* state) nogil
6666
cdef uint64_t random_raw_values(aug_state* state) nogil
6767

68+
cdef uint64_t random_bounded_uint64(aug_state *state, uint64_t off, uint64_t rng, uint64_t mask) nogil
6869
cdef uint32_t random_buffered_bounded_uint32(aug_state *state, uint32_t off, uint32_t rng, uint32_t mask, int *bcnt, uint32_t *buf) nogil
6970
cdef uint16_t random_buffered_bounded_uint16(aug_state *state, uint16_t off, uint16_t rng, uint16_t mask, int *bcnt, uint32_t *buf) nogil
7071
cdef uint8_t random_buffered_bounded_uint8(aug_state *state, uint8_t off, uint8_t rng, uint8_t mask, int *bcnt, uint32_t *buf) nogil
@@ -1024,10 +1025,14 @@ cdef class RandomState:
10241025

10251026
if key == 'int32':
10261027
ret = _rand_int32_combined(low, high, size, &self.rng_state, self.lock)
1028+
elif key == 'int64':
1029+
ret = _rand_int64_combined(low, high, size, &self.rng_state, self.lock)
10271030
elif key == 'int16':
10281031
ret = _rand_int16_combined(low, high, size, &self.rng_state, self.lock)
10291032
elif key == 'int8':
10301033
ret = _rand_int8_combined(low, high, size, &self.rng_state, self.lock)
1034+
elif key == 'uint64':
1035+
ret = _rand_uint64_combined(low, high, size, &self.rng_state, self.lock)
10311036
elif key == 'uint32':
10321037
ret = _rand_uint32_combined(low, high, size, &self.rng_state, self.lock)
10331038
elif key == 'uint16':
@@ -1037,52 +1042,10 @@ cdef class RandomState:
10371042
elif key == 'bool':
10381043
ret = _rand_bool_combined(low, high, size, &self.rng_state, self.lock)
10391044

1040-
if key != 'int64' and key != 'uint64':
1041-
if size is None and dtype in (np.bool, np.int, np.long):
1042-
if np.array(ret).shape == ():
1043-
return dtype(ret)
1044-
return ret
1045-
1046-
lowbnd, highbnd = _randint_type[key]
1047-
1048-
low = np.asarray(low)
1049-
high = np.asarray(high)
1050-
1051-
if low.shape == high.shape == ():
1052-
# TODO: Do not cast these inputs to Python int
1053-
#
1054-
# This is a workaround until gh-8851 is resolved (bug in NumPy
1055-
# integer comparison and subtraction involving uint64 and non-
1056-
# uint64). Afterwards, remove these two lines.
1057-
ilow = int(low)
1058-
ihigh = int(high)
1059-
1060-
if ilow < lowbnd:
1061-
raise ValueError("low is out of bounds for %s" % (key,))
1062-
if ihigh > highbnd:
1063-
raise ValueError("high is out of bounds for %s" % (key,))
1064-
if ilow >= ihigh:
1065-
raise ValueError("low >= high")
1066-
1067-
if key == 'int64':
1068-
ret = _rand_int64(ilow, ihigh - 1, size, &self.rng_state, self.lock)
1069-
elif key == 'uint64':
1070-
ret = _rand_uint64(ilow, ihigh - 1, size, &self.rng_state, self.lock)
1071-
elif key == 'bool':
1072-
ret = _rand_bool(ilow, ihigh - 1, size, &self.rng_state, self.lock)
10731045

1074-
if size is None:
1075-
if dtype in (np.bool, np.int, np.long):
1046+
if size is None and dtype in (np.bool, np.int, np.long):
1047+
if np.array(ret).shape == ():
10761048
return dtype(ret)
1077-
return ret
1078-
1079-
if key == 'int64':
1080-
ret = _rand_int64_broadcast(low, high, size, &self.rng_state, self.lock)
1081-
elif key == 'uint64':
1082-
ret = _rand_uint64_broadcast(low, high, size, &self.rng_state, self.lock)
1083-
elif key == 'bool':
1084-
ret = _rand_bool_broadcast(low, high, size, &self.rng_state, self.lock)
1085-
10861049
return ret
10871050

10881051
def bytes(self, np.npy_intp length):

0 commit comments

Comments
 (0)