Skip to content

Commit 8f44a5e

Browse files
antonwolfyoleksandr-pavlyk
authored andcommitted
Added support for out=arg by temporary copy
1 parent f032154 commit 8f44a5e

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ def __call__(self, x, out=None, order="K"):
5252
if not isinstance(x, dpt.usm_ndarray):
5353
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
5454

55+
if order not in ["C", "F", "K", "A"]:
56+
order = "K"
57+
buf_dt, res_dt = _find_buf_dtype(
58+
x.dtype, self.result_type_resolver_fn_, x.sycl_device
59+
)
60+
if res_dt is None:
61+
raise RuntimeError
62+
63+
orig_out = out
5564
if out is not None:
5665
if not isinstance(out, dpt.usm_ndarray):
5766
raise TypeError(
@@ -64,8 +73,17 @@ def __call__(self, x, out=None, order="K"):
6473
f"Expected output shape is {x.shape}, got {out.shape}"
6574
)
6675

67-
if ti._array_overlap(x, out):
68-
raise TypeError("Input and output arrays have memory overlap")
76+
if res_dt != out.dtype:
77+
raise TypeError(
78+
f"Output array of type {res_dt} is needed,"
79+
f" got {out.dtype}"
80+
)
81+
82+
if buf_dt is None and ti._array_overlap(x, out):
83+
# Allocate a temporary buffer to avoid memory overlapping.
84+
# Note if `buf_dt` is not None, a temporary copy of `x` will be
85+
# created, so the array overlap check isn't needed.
86+
out = dpt.empty_like(out)
6987

7088
if (
7189
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
@@ -75,13 +93,6 @@ def __call__(self, x, out=None, order="K"):
7593
"Input and output allocation queues are not compatible"
7694
)
7795

78-
if order not in ["C", "F", "K", "A"]:
79-
order = "K"
80-
buf_dt, res_dt = _find_buf_dtype(
81-
x.dtype, self.result_type_resolver_fn_, x.sycl_device
82-
)
83-
if res_dt is None:
84-
raise RuntimeError
8596
exec_q = x.sycl_queue
8697
if buf_dt is None:
8798
if out is None:
@@ -91,17 +102,20 @@ def __call__(self, x, out=None, order="K"):
91102
if order == "A":
92103
order = "F" if x.flags.f_contiguous else "C"
93104
out = dpt.empty_like(x, dtype=res_dt, order=order)
94-
else:
95-
if res_dt != out.dtype:
96-
raise TypeError(
97-
f"Output array of type {res_dt} is needed,"
98-
f" got {out.dtype}"
99-
)
100105

101-
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
102-
ht.wait()
106+
ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)
107+
108+
if not (orig_out is None or orig_out is out):
109+
# Copy the out data from temporary buffer to original memory
110+
ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
111+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
112+
)
113+
ht_copy_ev.wait()
114+
out = orig_out
103115

116+
ht_unary_ev.wait()
104117
return out
118+
105119
if order == "K":
106120
buf = _empty_like_orderK(x, buf_dt)
107121
else:
@@ -117,11 +131,6 @@ def __call__(self, x, out=None, order="K"):
117131
out = _empty_like_orderK(buf, res_dt)
118132
else:
119133
out = dpt.empty_like(buf, dtype=res_dt, order=order)
120-
else:
121-
if buf_dt != out.dtype:
122-
raise TypeError(
123-
f"Output array of type {buf_dt} is needed, got {out.dtype}"
124-
)
125134

126135
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
127136
ht_copy_ev.wait()

0 commit comments

Comments
 (0)