Skip to content

Commit 0008a7f

Browse files
Implemented logic to mutate array's writable flag
Example: ``` import dpctl.tensor as dpt, pytest x = dpt.ones((2,2)) x.flags.writable = False with pytest.raises(ValueError): x[:] = 0 # can not mutate read-only array x.flags["W"] = True x[:] = 0 # works fine now ```
1 parent 26b2eaa commit 0008a7f

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

dpctl/tensor/_flags.pyx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ cdef class Flags:
7575
"""
7676
return _check_bit(self.flags_, USM_ARRAY_WRITABLE)
7777

78+
@writable.setter
79+
def writable(self, new_val):
80+
if not isinstance(new_val, bool):
81+
raise TypeError("Expecting a boolean value")
82+
self.arr_._set_writable_flag(new_val)
83+
7884
@property
7985
def fc(self):
8086
"""
@@ -129,6 +135,14 @@ cdef class Flags:
129135
elif name == "CONTIGUOUS":
130136
return self.forc
131137

138+
def __setitem__(self, name, val):
139+
if name in ["WRITABLE", "W"]:
140+
self.writable = val
141+
else:
142+
raise ValueError(
143+
"Only writable ('W' or 'WRITABLE') flag can be set"
144+
)
145+
132146
def __repr__(self):
133147
out = []
134148
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":

dpctl/tensor/_usmarray.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,6 @@ cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
7272
cdef dpctl.DPCTLSyclQueueRef get_queue_ref(self) except *
7373
cdef dpctl.SyclQueue get_sycl_queue(self)
7474

75+
cdef _set_writable_flag(self, int)
76+
7577
cdef __cythonbufferdefaults__ = {"mode": "strided"}

dpctl/tensor/_usmarray.pyx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,12 @@ cdef class usm_ndarray:
532532
"""
533533
return _flags.Flags(self, self.flags_)
534534

535+
cdef _set_writable_flag(self, int flag):
536+
cdef int arr_fl = self.flags_
537+
arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag
538+
arr_fl |= (USM_ARRAY_WRITABLE if flag else 0)
539+
self.flags_ = arr_fl
540+
535541
@property
536542
def usm_type(self):
537543
"""
@@ -1390,12 +1396,10 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
13901396
allocation"""
13911397
return arr.get_offset()
13921398

1399+
13931400
cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag):
13941401
"""Set/unset USM_ARRAY_WRITABLE in the given array `arr`."""
1395-
cdef int arr_fl = arr.flags_
1396-
arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag
1397-
arr_fl |= (USM_ARRAY_WRITABLE if flag else 0)
1398-
arr.flags_ = arr_fl
1402+
arr._set_writable_flag(flag)
13991403

14001404
cdef api object UsmNDArray_MakeSimpleFromMemory(
14011405
int nd, const Py_ssize_t *shape, int typenum,

0 commit comments

Comments
 (0)