Skip to content

Commit de3c5eb

Browse files
committed
ENH: Add fast past for randint
Add fast path for 1d inputs Add tests to ensure all paths are tested
1 parent be0ce93 commit de3c5eb

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

randomstate/bounded_integers.pxi.in

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
8080

8181
low = np.asarray(low)
8282
high = np.asarray(high)
83-
if low.shape == high.shape == ():
83+
low_ndim = np.PyArray_NDIM(<np.ndarray>low)
84+
high_ndim = np.PyArray_NDIM(<np.ndarray>high)
85+
if ((low_ndim == 0 or (low_ndim==1 and low.size==1 and size is not None)) and
86+
(high_ndim == 0 or (high_ndim==1 and high.size==1 and size is not None))):
8487
low = int(low) # TODO: Cast appropriately?
8588
high = int(high) # TODO: Cast appropriately?
8689

@@ -199,7 +202,10 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
199202

200203
low = np.asarray(low)
201204
high = np.asarray(high)
202-
if low.shape == high.shape == ():
205+
low_ndim = np.PyArray_NDIM(<np.ndarray>low)
206+
high_ndim = np.PyArray_NDIM(<np.ndarray>high)
207+
if ((low_ndim == 0 or (low_ndim==1 and low.size==1 and size is not None)) and
208+
(high_ndim == 0 or (high_ndim==1 and high.size==1 and size is not None))):
203209
low = int(low) # TODO: Cast appropriately?
204210
high = int(high) # TODO: Cast appropriately?
205211
high -= 1 # Use a closed interval

randomstate/tests/test_numpy_mt19937.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ def test_bounds_checking_array(self):
160160
for dt in self.itype:
161161
lbnd = 0 if dt is bool else np.iinfo(dt).min
162162
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
163-
assert_raises(ValueError, self.rfunc, [lbnd - 1], [ubnd], dtype=dt)
164-
assert_raises(ValueError, self.rfunc, [lbnd], [ubnd + 1], dtype=dt)
165-
assert_raises(ValueError, self.rfunc, ubnd, [lbnd], dtype=dt)
166-
assert_raises(ValueError, self.rfunc, [1], 0, dtype=dt)
163+
assert_raises(ValueError, self.rfunc, [lbnd - 1] * 2, [ubnd] * 2, dtype=dt)
164+
assert_raises(ValueError, self.rfunc, [lbnd] * 2, [ubnd + 1] * 2, dtype=dt)
165+
assert_raises(ValueError, self.rfunc, ubnd, [lbnd] * 2, dtype=dt)
166+
assert_raises(ValueError, self.rfunc, [1] * 2, 0, dtype=dt)
167167

168168
def test_rng_zero_and_extremes(self):
169169
for dt in self.itype:
@@ -180,19 +180,25 @@ def test_rng_zero_and_extremes(self):
180180
assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)
181181

182182
def test_rng_zero_and_extremes_array(self):
183+
size = 1000
183184
for dt in self.itype:
184185
lbnd = 0 if dt is bool else np.iinfo(dt).min
185186
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
186187

187188
tgt = ubnd - 1
188-
print(dt)
189-
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
189+
assert_equal(self.rfunc([tgt], [tgt + 1], size=size, dtype=dt), tgt)
190+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, dtype=dt), tgt)
191+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, size=size, dtype=dt), tgt)
190192

191193
tgt = lbnd
192-
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
194+
assert_equal(self.rfunc([tgt], [tgt + 1], size=size, dtype=dt), tgt)
195+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, dtype=dt), tgt)
196+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, size=size, dtype=dt), tgt)
193197

194198
tgt = (lbnd + ubnd) // 2
195-
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
199+
assert_equal(self.rfunc([tgt], [tgt + 1], size=size, dtype=dt), tgt)
200+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, dtype=dt), tgt)
201+
assert_equal(self.rfunc([tgt] * size, [tgt + 1] * size, size=size, dtype=dt), tgt)
196202

197203
def test_full_range(self):
198204
# Test for ticket #1690
@@ -222,6 +228,24 @@ def test_full_range_array(self):
222228
"but one was with the following "
223229
"message:\n\n%s" % str(e))
224230

231+
def test_scalar_array_equiv(self):
232+
for dt in self.itype:
233+
lbnd = 0 if dt is bool else np.iinfo(dt).min
234+
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
235+
236+
size = 1000
237+
mt19937.seed(1234)
238+
scalar = self.rfunc(lbnd, ubnd, size=size, dtype=dt)
239+
240+
mt19937.seed(1234)
241+
scalar_array = self.rfunc(lbnd, ubnd, size=size, dtype=dt)
242+
243+
mt19937.seed(1234)
244+
array = self.rfunc([lbnd] * size, [ubnd] * size, size=size, dtype=dt)
245+
assert_array_equal(scalar, scalar_array)
246+
assert_array_equal(scalar, array)
247+
248+
225249
def test_in_bounds_fuzz(self):
226250
# Don't use fixed seed
227251
mt19937.seed()
@@ -270,6 +294,27 @@ def test_repeatability(self):
270294
res = hashlib.md5(val).hexdigest()
271295
assert_(tgt[np.dtype(bool).name] == res)
272296

297+
def test_repeatability_broadcasting(self):
298+
299+
for dt in self.itype:
300+
301+
lbnd = 0 if dt in (np.bool, bool, np.bool_) else np.iinfo(dt).min
302+
ubnd = 2 if dt in (np.bool, bool, np.bool_) else np.iinfo(dt).max + 1
303+
304+
# view as little endian for hash
305+
mt19937.seed(1234)
306+
val = self.rfunc(lbnd, ubnd, size=1000, dtype=dt)
307+
308+
mt19937.seed(1234)
309+
val_bc = self.rfunc([lbnd] * 1000, ubnd, dtype=dt)
310+
311+
assert_array_equal(val, val_bc)
312+
313+
mt19937.seed(1234)
314+
val_bc = self.rfunc([lbnd] * 1000, [ubnd] * 1000, dtype=dt)
315+
316+
assert_array_equal(val, val_bc)
317+
273318
def test_int64_uint64_corner_case(self):
274319
# When stored in Numpy arrays, `lbnd` is casted
275320
# as np.int64, and `ubnd` is casted as np.uint64.
@@ -322,6 +367,8 @@ def test_respect_dtype_array(self):
322367

323368
sample = self.rfunc([lbnd], [ubnd], dtype=dt)
324369
self.assertEqual(sample.dtype, dt)
370+
sample = self.rfunc([lbnd] * 2, [ubnd] * 2, dtype=dt)
371+
self.assertEqual(sample.dtype, dt)
325372

326373

327374
class TestRandomDist(TestCase):

0 commit comments

Comments
 (0)