Skip to content

Commit 6d7e598

Browse files
Used dpctl.tensor constructors instead of NumPy's followed by from_numpy
This change reduced run-time of test_usm_ndarray_ctor by 3 seconds (from 16.5 to 13.5 on my machine) and marginally improves coverage of Python source file.
1 parent 8f2ba46 commit 6d7e598

File tree

1 file changed

+20
-29
lines changed

1 file changed

+20
-29
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,7 @@ def test_empty_slice():
275275

276276
def test_slice_constructor_1d():
277277
Xh = np.arange(37, dtype="i4")
278-
default_device = dpctl.select_default_device()
279-
Xusm = dpt.from_numpy(Xh, device=default_device, usm_type="device")
278+
Xusm = dpt.arange(Xh.size, dtype="i4")
280279
for ind in [
281280
slice(1, None, 2),
282281
slice(0, None, 3),
@@ -293,9 +292,8 @@ def test_slice_constructor_1d():
293292

294293

295294
def test_slice_constructor_3d():
296-
Xh = np.empty((37, 24, 35), dtype="i4")
297-
default_device = dpctl.select_default_device()
298-
Xusm = dpt.from_numpy(Xh, device=default_device, usm_type="device")
295+
Xh = np.ones((37, 24, 35), dtype="i4")
296+
Xusm = dpt.ones(Xh.shape, dtype=Xh.dtype)
299297
for ind in [
300298
slice(1, None, 2),
301299
slice(0, None, 3),
@@ -315,8 +313,7 @@ def test_slice_constructor_3d():
315313
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
316314
def test_slice_suai(usm_type):
317315
Xh = np.arange(0, 10, dtype="u1")
318-
default_device = dpctl.select_default_device()
319-
Xusm = dpt.from_numpy(Xh, device=default_device, usm_type=usm_type)
316+
Xusm = dpt.arange(0, 10, dtype="u1", usm_type=usm_type)
320317
for ind in [slice(2, 3, None), slice(5, 7, None), slice(3, 9, None)]:
321318
assert np.array_equal(
322319
dpm.as_usm_memory(Xusm[ind]).copy_to_host(), Xh[ind]
@@ -866,8 +863,7 @@ def test_pyx_capi_check_constants():
866863
def test_tofrom_numpy(shape, dtype, usm_type):
867864
q = get_queue_or_skip()
868865
skip_if_dtype_not_supported(dtype, q)
869-
Xnp = np.zeros(shape, dtype=dtype)
870-
Xusm = dpt.from_numpy(Xnp, usm_type=usm_type, sycl_queue=q)
866+
Xusm = dpt.zeros(shape, dtype=dtype, usm_type=usm_type, sycl_queue=q)
871867
Ynp = np.ones(shape, dtype=dtype)
872868
ind = (slice(None, None, None),) * Ynp.ndim
873869
Xusm[ind] = Ynp
@@ -883,35 +879,30 @@ def test_tofrom_numpy(shape, dtype, usm_type):
883879
def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
884880
q = get_queue_or_skip()
885881
skip_if_dtype_not_supported(dtype, q)
882+
shape = (2, 4, 3)
886883
Xnp = (
887-
np.random.randint(-10, 10, size=2 * 3 * 4)
884+
np.random.randint(-10, 10, size=np.prod(shape))
888885
.astype(dtype)
889-
.reshape((2, 4, 3))
886+
.reshape(shape)
890887
)
891-
Znp = np.zeros(
892-
(
893-
2,
894-
4,
895-
3,
896-
),
897-
dtype=dtype,
898-
)
899-
Zusm_0d = dpt.from_numpy(Znp[0, 0, 0], usm_type=dst_usm_type)
888+
X = dpt.from_numpy(Xnp, usm_type=src_usm_type)
889+
Z = dpt.zeros(shape, dtype=dtype, usm_type=dst_usm_type)
890+
Zusm_0d = dpt.copy(Z[0, 0, 0])
900891
ind = (-1, -1, -1)
901-
Xusm_0d = dpt.from_numpy(Xnp[ind], usm_type=src_usm_type)
892+
Xusm_0d = X[ind]
902893
Zusm_0d[Ellipsis] = Xusm_0d
903894
assert np.array_equal(dpt.to_numpy(Zusm_0d), Xnp[ind])
904-
Zusm_1d = dpt.from_numpy(Znp[0, 1:3, 0], usm_type=dst_usm_type)
895+
Zusm_1d = dpt.copy(Z[0, 1:3, 0])
905896
ind = (-1, slice(0, 2, None), -1)
906-
Xusm_1d = dpt.from_numpy(Xnp[ind], usm_type=src_usm_type)
897+
Xusm_1d = X[ind]
907898
Zusm_1d[Ellipsis] = Xusm_1d
908899
assert np.array_equal(dpt.to_numpy(Zusm_1d), Xnp[ind])
909-
Zusm_2d = dpt.from_numpy(Znp[:, 1:3, 0], usm_type=dst_usm_type)[::-1]
910-
Xusm_2d = dpt.from_numpy(Xnp[:, 1:4, -1], usm_type=src_usm_type)
900+
Zusm_2d = dpt.copy(Z[:, 1:3, 0])[::-1]
901+
Xusm_2d = X[:, 1:4, -1]
911902
Zusm_2d[:] = Xusm_2d[:, 0:2]
912903
assert np.array_equal(dpt.to_numpy(Zusm_2d), Xnp[:, 1:3, -1])
913-
Zusm_3d = dpt.from_numpy(Znp, usm_type=dst_usm_type)
914-
Xusm_3d = dpt.from_numpy(Xnp, usm_type=src_usm_type)
904+
Zusm_3d = dpt.copy(Z)
905+
Xusm_3d = X
915906
Zusm_3d[:] = Xusm_3d
916907
assert np.array_equal(dpt.to_numpy(Zusm_3d), Xnp)
917908
Zusm_3d[::-1] = Xusm_3d[::-1]
@@ -962,8 +953,8 @@ def test_setitem_errors():
962953
def test_setitem_different_dtypes(src_dt, dst_dt):
963954
q = get_queue_or_skip()
964955
skip_if_dtype_not_supported(dst_dt, q)
965-
X = dpt.from_numpy(np.ones(10, src_dt), sycl_queue=q)
966-
Y = dpt.from_numpy(np.zeros(10, src_dt), sycl_queue=q)
956+
X = dpt.ones(10, src_dt, sycl_queue=q)
957+
Y = dpt.zeros(10, src_dt, sycl_queue=q)
967958
Z = dpt.empty((20,), dtype=dst_dt, sycl_queue=q)
968959
Z[::2] = X
969960
Z[1::2] = Y

0 commit comments

Comments
 (0)