Skip to content

Commit f58ed17

Browse files
committed
ENH: Add empty support to randint
Add support for empty arrays to randint which allows empty choice Refactor tests to omit TestCase
1 parent 194724f commit f58ed17

File tree

2 files changed

+133
-110
lines changed

2 files changed

+133
-110
lines changed

randomstate/bounded_integers.pxi.in

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
7777
cdef np.broadcast it
7878
cdef int buf_rem = 0
7979

80+
if size is not None:
81+
if (np.prod(size) == 0):
82+
return np.empty(size, dtype=np.{{nptype}})
83+
8084
low = np.array(low, copy=False)
8185
high = np.array(high, copy=False)
8286
low_ndim = np.PyArray_NDIM(<np.ndarray>low)
@@ -129,10 +133,10 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
129133

130134
it = np.PyArray_MultiIterNew3(low_arr, high_arr, out_arr)
131135
out_data = <{{utype}}_t *>np.PyArray_DATA(out_arr)
132-
n = np.PyArray_SIZE(out_arr)
136+
cnt = np.PyArray_SIZE(out_arr)
133137
mask = last_rng = 0
134138
with lock, nogil:
135-
for i in range(n):
139+
for i in range(cnt):
136140
low_v = (<{{nptype_up}}_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
137141
high_v = (<{{nptype_up}}_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
138142
rng = <{{utype}}_t>((high_v - 1) - low_v)
@@ -199,6 +203,10 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
199203
cdef {{nptype}}_t low_v, high_v
200204
cdef uint64_t rng, last_rng, val, mask, off, out_val
201205

206+
if size is not None:
207+
if (np.prod(size) == 0):
208+
return np.empty(size, dtype=np.{{nptype}})
209+
202210
low = np.array(low, copy=False)
203211
high = np.array(high, copy=False)
204212
low_ndim = np.PyArray_NDIM(<np.ndarray>low)
@@ -238,9 +246,9 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
238246

239247
highm1_arr = <np.ndarray>np.empty_like(high_arr, dtype=np.{{nptype}})
240248
highm1_data = <{{nptype}}_t *>np.PyArray_DATA(highm1_arr)
241-
n = np.PyArray_SIZE(high_arr)
249+
cnt = np.PyArray_SIZE(high_arr)
242250
flat = high_arr.flat
243-
for i in range(n):
251+
for i in range(cnt):
244252
closed_upper = int(flat[i]) - 1
245253
if closed_upper > {{ub}}:
246254
raise ValueError('high is out of bounds for {{nptype}}')

0 commit comments

Comments
 (0)