Skip to content

Commit e0dd5bd

Browse files
authored
Merge pull request #2123 from IntelPython/resolve-gh-2119
Allow type casting of zero-sized array to any dtype
2 parents 8c94751 + 4947f5c commit e0dd5bd

File tree

2 files changed

+37
-34
lines changed

2 files changed

+37
-34
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,9 @@ cdef class usm_ndarray:
382382
else:
383383
self._cleanup()
384384
raise ValueError("buffer='{}' was not understood.".format(buffer))
385-
if (_offset + ary_min_displacement < 0 or
386-
(_offset + ary_max_displacement + 1) * itemsize > _buffer.nbytes):
385+
if (shape_to_elem_count(nd, shape_ptr) > 0 and
386+
(_offset + ary_min_displacement < 0 or
387+
(_offset + ary_max_displacement + 1) * itemsize > _buffer.nbytes)):
387388
self._cleanup()
388389
raise ValueError(("buffer='{}' can not accommodate "
389390
"the requested array.").format(buffer))

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@
2828

2929
from .helper import get_queue_or_skip, skip_if_dtype_not_supported
3030

31+
_all_dtypes = [
32+
"b1",
33+
"i1",
34+
"u1",
35+
"i2",
36+
"u2",
37+
"i4",
38+
"u4",
39+
"i8",
40+
"u8",
41+
"f2",
42+
"f4",
43+
"f8",
44+
"c8",
45+
"c16",
46+
]
47+
3148

3249
@pytest.mark.parametrize(
3350
"shape",
@@ -150,6 +167,21 @@ def test_usm_ndarray_writable_flag_views():
150167
assert not a.imag.flags.writable
151168

152169

170+
@pytest.mark.parametrize("dt1", _all_dtypes)
171+
@pytest.mark.parametrize("dt2", _all_dtypes)
172+
def test_usm_ndarray_from_zero_sized_usm_ndarray(dt1, dt2):
173+
q = get_queue_or_skip()
174+
skip_if_dtype_not_supported(dt1, q)
175+
skip_if_dtype_not_supported(dt2, q)
176+
177+
x1 = dpt.ones((0,), dtype=dt1, sycl_queue=q)
178+
x2 = dpt.usm_ndarray(x1.shape, dtype=dt2, buffer=x1)
179+
assert x2.dtype == dt2
180+
assert x2.sycl_queue == q
181+
assert x2._pointer == x1._pointer
182+
assert x2.shape == x1.shape
183+
184+
153185
def test_usm_ndarray_from_usm_ndarray_readonly():
154186
get_queue_or_skip()
155187

@@ -161,20 +193,8 @@ def test_usm_ndarray_from_usm_ndarray_readonly():
161193

162194
@pytest.mark.parametrize(
163195
"dtype",
164-
[
165-
"u1",
166-
"i1",
167-
"u2",
168-
"i2",
169-
"u4",
170-
"i4",
171-
"u8",
172-
"i8",
173-
"f2",
174-
"f4",
175-
"f8",
176-
"c8",
177-
"c16",
196+
_all_dtypes
197+
+ [
178198
b"float32",
179199
dpt.dtype("d"),
180200
np.half,
@@ -1103,24 +1123,6 @@ def test_pyx_capi_check_constants():
11031123
assert cdouble_typenum == dpt.dtype(np.cdouble).num
11041124

11051125

1106-
_all_dtypes = [
1107-
"b1",
1108-
"i1",
1109-
"u1",
1110-
"i2",
1111-
"u2",
1112-
"i4",
1113-
"u4",
1114-
"i8",
1115-
"u8",
1116-
"f2",
1117-
"f4",
1118-
"f8",
1119-
"c8",
1120-
"c16",
1121-
]
1122-
1123-
11241126
@pytest.mark.parametrize(
11251127
"shape", [tuple(), (1,), (5,), (2, 3), (2, 3, 4), (2, 2, 2, 2, 2)]
11261128
)

0 commit comments

Comments
 (0)