Skip to content

Commit 828cba8

Browse files
committed
Added missing tests for usm_ndarray.flags
1 parent bd3eb4d commit 828cba8

File tree

1 file changed

+40
-23
lines changed

1 file changed

+40
-23
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,35 @@ def test_allocate_usm_ndarray(shape, usm_type):
5858

5959
def test_usm_ndarray_flags():
6060
get_queue_or_skip()
61-
assert dpt.usm_ndarray((5,), dtype="i4").flags.fc
62-
assert dpt.usm_ndarray((5, 2), dtype="i4").flags.c_contiguous
63-
assert not dpt.usm_ndarray((5, 2), dtype="i4").flags.fnc
64-
assert dpt.usm_ndarray((5, 2), dtype="i4", order="F").flags.f_contiguous
65-
assert dpt.usm_ndarray((5, 1, 2), dtype="i4", order="F").flags.f_contiguous
66-
assert dpt.usm_ndarray((5, 1, 2), dtype="i4", order="F").flags.fnc
67-
assert dpt.usm_ndarray(
68-
(5, 1, 2), dtype="i4", strides=(2, 0, 1)
69-
).flags.c_contiguous
61+
f = dpt.usm_ndarray((5,), dtype="i4").flags
62+
assert f.fc
63+
assert f.forc
64+
65+
f = dpt.usm_ndarray((5, 2), dtype="i4").flags
66+
assert f.c_contiguous
67+
assert f.forc
68+
69+
f = dpt.usm_ndarray((5, 2), dtype="i4", order="F").flags
70+
assert f.f_contiguous
71+
assert f.forc
72+
assert f.fnc
73+
74+
f = dpt.usm_ndarray((5, 1, 2), dtype="i4", strides=(2, 0, 1)).flags
75+
assert f.c_contiguous
76+
assert f.forc
77+
78+
f = dpt.usm_ndarray((5, 1, 2), dtype="i4", strides=(1, 0, 5)).flags
79+
assert f.f_contiguous
80+
assert f.forc
81+
assert f.fnc
82+
83+
f = dpt.usm_ndarray((5, 1, 1), dtype="i4", strides=(1, 0, 1)).flags
84+
assert f.fc
85+
assert f.forc
7086
assert not dpt.usm_ndarray(
71-
(5, 1, 2), dtype="i4", strides=(2, 0, 1)
72-
).flags.fnc
73-
assert dpt.usm_ndarray(
74-
(5, 1, 2), dtype="i4", strides=(1, 0, 5)
75-
).flags.f_contiguous
76-
assert dpt.usm_ndarray((5, 1, 2), dtype="i4", strides=(1, 0, 5)).flags.fnc
77-
assert dpt.usm_ndarray((5, 1, 1), dtype="i4", strides=(1, 0, 1)).flags.fc
87+
(5, 1, 1), dtype="i4", strides=(2, 0, 1)
88+
).flags.forc
89+
7890
x = dpt.empty(5, dtype="u2")
7991
assert x.flags.writable is True
8092
x.flags.writable = False
@@ -87,6 +99,11 @@ def test_usm_ndarray_flags():
8799
assert x.flags.writable is True
88100
x[:] = 0
89101

102+
with pytest.raises(TypeError):
103+
x.flags.writable = dict()
104+
with pytest.raises(ValueError):
105+
x.flags["C"] = False
106+
90107

91108
@pytest.mark.parametrize(
92109
"dtype",
@@ -1834,13 +1851,13 @@ def test_flags():
18341851
x = dpt.empty(tuple(), "i4")
18351852
f = x.flags
18361853
f.__repr__()
1837-
f.c_contiguous
1838-
f.f_contiguous
1839-
f.contiguous
1840-
f.fc
1841-
f.fnc
1842-
f.forc
1843-
f.writable
1854+
assert f.c_contiguous == f["C"]
1855+
assert f.f_contiguous == f["F"]
1856+
assert f.contiguous == f["CONTIGUOUS"]
1857+
assert f.fc == f["FC"]
1858+
assert f.forc == f["FORC"]
1859+
assert f.fnc == f["FNC"]
1860+
assert f.writable == f["W"]
18441861
# check comparison with generic types
18451862
f == Ellipsis
18461863

0 commit comments

Comments
 (0)