Skip to content

Commit be0ce93

Browse files
committed
BUG: Fix incorrect cast
Fix incorrect cast in bounded integers Add range testing for broadcast bounded integers
1 parent 24f30ef commit be0ce93

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

randomstate/bounded_integers.pxi.in

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
241241
raise ValueError('high is out of bounds for {{nptype}}')
242242
if closed_upper < {{lb}}:
243243
raise ValueError('low >= high')
244-
245244
highm1_data[i] = <{{nptype}}_t>closed_upper
246245

247246
if np.any(np.greater(low_arr, highm1_arr)):
@@ -264,11 +263,8 @@ cdef object _rand_{{nptype}}(object low, object high, object size, aug_state *st
264263
for i in range(n):
265264
low_v = (<{{nptype}}_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
266265
high_v = (<{{nptype}}_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
267-
rng = <{{nptype}}_t>(high_v - low_v) # No -1 here since implemented above
268-
off = low_v
269-
if rng == 0:
270-
out_data[i] = 0
271-
continue
266+
rng = <{{utype}}_t>(high_v - low_v) # No -1 here since implemented above
267+
off = <{{utype}}_t>(<{{nptype}}_t>low_v)
272268

273269
if rng != last_rng:
274270
mask = _gen_mask(rng)

randomstate/tests/test_numpy_mt19937.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ def test_bounds_checking(self):
156156
assert_raises(ValueError, self.rfunc, ubnd, lbnd, dtype=dt)
157157
assert_raises(ValueError, self.rfunc, 1, 0, dtype=dt)
158158

159+
def test_bounds_checking_array(self):
160+
for dt in self.itype:
161+
lbnd = 0 if dt is bool else np.iinfo(dt).min
162+
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
163+
assert_raises(ValueError, self.rfunc, [lbnd - 1], [ubnd], dtype=dt)
164+
assert_raises(ValueError, self.rfunc, [lbnd], [ubnd + 1], dtype=dt)
165+
assert_raises(ValueError, self.rfunc, ubnd, [lbnd], dtype=dt)
166+
assert_raises(ValueError, self.rfunc, [1], 0, dtype=dt)
167+
159168
def test_rng_zero_and_extremes(self):
160169
for dt in self.itype:
161170
lbnd = 0 if dt is bool else np.iinfo(dt).min
@@ -170,6 +179,21 @@ def test_rng_zero_and_extremes(self):
170179
tgt = (lbnd + ubnd) // 2
171180
assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)
172181

182+
def test_rng_zero_and_extremes_array(self):
183+
for dt in self.itype:
184+
lbnd = 0 if dt is bool else np.iinfo(dt).min
185+
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
186+
187+
tgt = ubnd - 1
188+
print(dt)
189+
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
190+
191+
tgt = lbnd
192+
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
193+
194+
tgt = (lbnd + ubnd) // 2
195+
assert_equal(self.rfunc([tgt], [tgt + 1], size=1000, dtype=dt), tgt)
196+
173197
def test_full_range(self):
174198
# Test for ticket #1690
175199

@@ -184,6 +208,20 @@ def test_full_range(self):
184208
"but one was with the following "
185209
"message:\n\n%s" % str(e))
186210

211+
def test_full_range_array(self):
212+
# Test for ticket #1690
213+
214+
for dt in self.itype:
215+
lbnd = 0 if dt is bool else np.iinfo(dt).min
216+
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
217+
218+
try:
219+
self.rfunc([lbnd], [ubnd], dtype=dt)
220+
except Exception as e:
221+
raise AssertionError("No error should have been raised, "
222+
"but one was with the following "
223+
"message:\n\n%s" % str(e))
224+
187225
def test_in_bounds_fuzz(self):
188226
# Don't use fixed seed
189227
mt19937.seed()
@@ -275,6 +313,17 @@ def test_respect_dtype_singleton(self):
275313
self.assertEqual(type(sample), dt)
276314

277315

316+
def test_respect_dtype_array(self):
317+
# See gh-7203
318+
for dt in self.itype:
319+
lbnd = 0 if dt is bool else np.iinfo(dt).min
320+
ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
321+
dt = np.bool_ if dt is bool else dt
322+
323+
sample = self.rfunc([lbnd], [ubnd], dtype=dt)
324+
self.assertEqual(sample.dtype, dt)
325+
326+
278327
class TestRandomDist(TestCase):
279328
# Make sure the random distribution returns the correct value for a
280329
# given seed

0 commit comments

Comments
 (0)