diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index b73edef82e..6e84b7c801 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -382,8 +382,9 @@ cdef class usm_ndarray: else: self._cleanup() raise ValueError("buffer='{}' was not understood.".format(buffer)) - if (_offset + ary_min_displacement < 0 or - (_offset + ary_max_displacement + 1) * itemsize > _buffer.nbytes): + if (shape_to_elem_count(nd, shape_ptr) > 0 and + (_offset + ary_min_displacement < 0 or + (_offset + ary_max_displacement + 1) * itemsize > _buffer.nbytes)): self._cleanup() raise ValueError(("buffer='{}' can not accommodate " "the requested array.").format(buffer)) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index c6e33b600c..ce0f072288 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -28,6 +28,23 @@ from .helper import get_queue_or_skip, skip_if_dtype_not_supported +_all_dtypes = [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] + @pytest.mark.parametrize( "shape", @@ -150,6 +167,21 @@ def test_usm_ndarray_writable_flag_views(): assert not a.imag.flags.writable +@pytest.mark.parametrize("dt1", _all_dtypes) +@pytest.mark.parametrize("dt2", _all_dtypes) +def test_usm_ndarray_from_zero_sized_usm_ndarray(dt1, dt2): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + x1 = dpt.ones((0,), dtype=dt1, sycl_queue=q) + x2 = dpt.usm_ndarray(x1.shape, dtype=dt2, buffer=x1) + assert x2.dtype == dt2 + assert x2.sycl_queue == q + assert x2._pointer == x1._pointer + assert x2.shape == x1.shape + + def test_usm_ndarray_from_usm_ndarray_readonly(): get_queue_or_skip() @@ -161,20 +193,8 @@ def test_usm_ndarray_from_usm_ndarray_readonly(): @pytest.mark.parametrize( "dtype", - [ - "u1", - "i1", - "u2", - "i2", - "u4", - "i4", - "u8", - "i8", - "f2", - "f4", - "f8", - "c8", - "c16", + _all_dtypes + + [ b"float32", dpt.dtype("d"), np.half, @@ -1103,24 +1123,6 @@ def test_pyx_capi_check_constants(): assert cdouble_typenum == dpt.dtype(np.cdouble).num -_all_dtypes = [ - "b1", - "i1", - "u1", - "i2", - "u2", - "i4", - "u4", - "i8", - "u8", - "f2", - "f4", - "f8", - "c8", - "c16", -] - - @pytest.mark.parametrize( "shape", [tuple(), (1,), (5,), (2, 3), (2, 3, 4), (2, 2, 2, 2, 2)] )