Skip to content

Commit aef6de3

Browse files
committed
out keyword for elementwise functions
1 parent aebaf3a commit aef6de3

File tree

6 files changed

+395
-76
lines changed

6 files changed

+395
-76
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 170 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4747
self.unary_fn_ = unary_dp_impl_fn
4848
self.__doc__ = docs
4949

50-
def __call__(self, x, order="K"):
50+
def __call__(self, x, out=None, order="K"):
5151
if not isinstance(x, dpt.usm_ndarray):
5252
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
53+
54+
if out is not None:
55+
if not isinstance(out, dpt.usm_ndarray):
56+
raise TypeError(
57+
f"output array must be of usm_ndarray type, got {type(out)}"
58+
)
59+
60+
if out.shape != x.shape:
61+
raise TypeError(
62+
"The shape of input and output arrays are inconsistent."
63+
f"Expected output shape is {x.shape}, got {out.shape}"
64+
)
65+
66+
if ti._array_overlap(x, out):
67+
raise TypeError("Input and output arrays have memory overlap")
68+
69+
if (
70+
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
71+
is None
72+
):
73+
raise TypeError(
74+
"Input and output allocation queues are not compatible"
75+
)
76+
5377
if order not in ["C", "F", "K", "A"]:
5478
order = "K"
5579
buf_dt, res_dt = _find_buf_dtype(
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
5983
raise RuntimeError
6084
exec_q = x.sycl_queue
6185
if buf_dt is None:
62-
if order == "K":
63-
r = _empty_like_orderK(x, res_dt)
86+
if out is None:
87+
if order == "K":
88+
out = _empty_like_orderK(x, res_dt)
89+
else:
90+
if order == "A":
91+
order = "F" if x.flags.f_contiguous else "C"
92+
out = dpt.empty_like(x, dtype=res_dt, order=order)
6493
else:
65-
if order == "A":
66-
order = "F" if x.flags.f_contiguous else "C"
67-
r = dpt.empty_like(x, dtype=res_dt, order=order)
94+
if res_dt != out.dtype:
95+
raise TypeError(
96+
f"Expected output array of type {res_dt} is supported"
97+
f", got {out.dtype}"
98+
)
6899

69-
ht, _ = self.unary_fn_(x, r, sycl_queue=exec_q)
100+
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
70101
ht.wait()
71102

72-
return r
103+
return out
73104
if order == "K":
74105
buf = _empty_like_orderK(x, buf_dt)
75106
else:
@@ -80,16 +111,23 @@ def __call__(self, x, order="K"):
80111
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
81112
src=x, dst=buf, sycl_queue=exec_q
82113
)
83-
if order == "K":
84-
r = _empty_like_orderK(buf, res_dt)
114+
if out is None:
115+
if order == "K":
116+
out = _empty_like_orderK(buf, res_dt)
117+
else:
118+
out = dpt.empty_like(buf, dtype=res_dt, order=order)
85119
else:
86-
r = dpt.empty_like(buf, dtype=res_dt, order=order)
120+
if buf_dt != out.dtype:
121+
raise TypeError(
122+
f"Expected output array of type {buf_dt} is supported,"
123+
f"got {out.dtype}"
124+
)
87125

88-
ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev])
126+
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
89127
ht_copy_ev.wait()
90128
ht.wait()
91129

92-
return r
130+
return out
93131

94132

95133
def _get_queue_usm_type(o):
@@ -281,7 +319,7 @@ def __str__(self):
281319
def __repr__(self):
282320
return f"<BinaryElementwiseFunc '{self.name_}'>"
283321

284-
def __call__(self, o1, o2, order="K"):
322+
def __call__(self, o1, o2, out=None, order="K"):
285323
if order not in ["K", "C", "F", "A"]:
286324
order = "K"
287325
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -358,6 +396,31 @@ def __call__(self, o1, o2, order="K"):
358396
"supported types according to the casting rule ''safe''."
359397
)
360398

399+
if out is not None:
400+
if not isinstance(out, dpt.usm_ndarray):
401+
raise TypeError(
402+
f"output array must be of usm_ndarray type, got {type(out)}"
403+
)
404+
405+
if out.shape != o1_shape or out.shape != o2_shape:
406+
raise TypeError(
407+
"The shape of input and output arrays are inconsistent."
408+
f"Expected output shape is {o1_shape}, got {out.shape}"
409+
)
410+
411+
if ti._array_overlap(o1, out) or ti._array_overlap(o2, out):
412+
raise TypeError("Input and output arrays have memory overlap")
413+
414+
if (
415+
dpctl.utils.get_execution_queue(
416+
(o1.sycl_queue, o2.sycl_queue, out.sycl_queue)
417+
)
418+
is None
419+
):
420+
raise TypeError(
421+
"Input and output allocation queues are not compatible"
422+
)
423+
361424
if isinstance(o1, dpt.usm_ndarray):
362425
src1 = o1
363426
else:
@@ -368,37 +431,45 @@ def __call__(self, o1, o2, order="K"):
368431
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
369432

370433
if buf1_dt is None and buf2_dt is None:
371-
if order == "K":
372-
r = _empty_like_pair_orderK(
373-
src1, src2, res_dt, res_usm_type, exec_q
374-
)
375-
else:
376-
if order == "A":
377-
order = (
378-
"F"
379-
if all(
380-
arr.flags.f_contiguous
381-
for arr in (
382-
src1,
383-
src2,
434+
if out is None:
435+
if order == "K":
436+
out = _empty_like_pair_orderK(
437+
src1, src2, res_dt, res_usm_type, exec_q
438+
)
439+
else:
440+
if order == "A":
441+
order = (
442+
"F"
443+
if all(
444+
arr.flags.f_contiguous
445+
for arr in (
446+
src1,
447+
src2,
448+
)
384449
)
450+
else "C"
385451
)
386-
else "C"
452+
out = dpt.empty(
453+
res_shape,
454+
dtype=res_dt,
455+
usm_type=res_usm_type,
456+
sycl_queue=exec_q,
457+
order=order,
387458
)
388-
r = dpt.empty(
389-
res_shape,
390-
dtype=res_dt,
391-
usm_type=res_usm_type,
392-
sycl_queue=exec_q,
393-
order=order,
394-
)
459+
else:
460+
if res_dt != out.dtype:
461+
raise TypeError(
462+
f"Output array of type {res_dt} is needed,"
463+
f"got {out.dtype}"
464+
)
465+
395466
src1 = dpt.broadcast_to(src1, res_shape)
396467
src2 = dpt.broadcast_to(src2, res_shape)
397468
ht_, _ = self.binary_fn_(
398-
src1=src1, src2=src2, dst=r, sycl_queue=exec_q
469+
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
399470
)
400471
ht_.wait()
401-
return r
472+
return out
402473
elif buf1_dt is None:
403474
if order == "K":
404475
buf2 = _empty_like_orderK(src2, buf2_dt)
@@ -409,30 +480,38 @@ def __call__(self, o1, o2, order="K"):
409480
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
410481
src=src2, dst=buf2, sycl_queue=exec_q
411482
)
412-
if order == "K":
413-
r = _empty_like_pair_orderK(
414-
src1, buf2, res_dt, res_usm_type, exec_q
415-
)
483+
if out is None:
484+
if order == "K":
485+
out = _empty_like_pair_orderK(
486+
src1, buf2, res_dt, res_usm_type, exec_q
487+
)
488+
else:
489+
out = dpt.empty(
490+
res_shape,
491+
dtype=res_dt,
492+
usm_type=res_usm_type,
493+
sycl_queue=exec_q,
494+
order=order,
495+
)
416496
else:
417-
r = dpt.empty(
418-
res_shape,
419-
dtype=res_dt,
420-
usm_type=res_usm_type,
421-
sycl_queue=exec_q,
422-
order=order,
423-
)
497+
if res_dt != out.dtype:
498+
raise TypeError(
499+
f"Output array of type {res_dt} is needed,"
500+
f"got {out.dtype}"
501+
)
502+
424503
src1 = dpt.broadcast_to(src1, res_shape)
425504
buf2 = dpt.broadcast_to(buf2, res_shape)
426505
ht_, _ = self.binary_fn_(
427506
src1=src1,
428507
src2=buf2,
429-
dst=r,
508+
dst=out,
430509
sycl_queue=exec_q,
431510
depends=[copy_ev],
432511
)
433512
ht_copy_ev.wait()
434513
ht_.wait()
435-
return r
514+
return out
436515
elif buf2_dt is None:
437516
if order == "K":
438517
buf1 = _empty_like_orderK(src1, buf1_dt)
@@ -443,30 +522,38 @@ def __call__(self, o1, o2, order="K"):
443522
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
444523
src=src1, dst=buf1, sycl_queue=exec_q
445524
)
446-
if order == "K":
447-
r = _empty_like_pair_orderK(
448-
buf1, src2, res_dt, res_usm_type, exec_q
449-
)
525+
if out is None:
526+
if order == "K":
527+
out = _empty_like_pair_orderK(
528+
buf1, src2, res_dt, res_usm_type, exec_q
529+
)
530+
else:
531+
out = dpt.empty(
532+
res_shape,
533+
dtype=res_dt,
534+
usm_type=res_usm_type,
535+
sycl_queue=exec_q,
536+
order=order,
537+
)
450538
else:
451-
r = dpt.empty(
452-
res_shape,
453-
dtype=res_dt,
454-
usm_type=res_usm_type,
455-
sycl_queue=exec_q,
456-
order=order,
457-
)
539+
if res_dt != out.dtype:
540+
raise TypeError(
541+
f"Output array of type {res_dt} is needed,"
542+
f"got {out.dtype}"
543+
)
544+
458545
buf1 = dpt.broadcast_to(buf1, res_shape)
459546
src2 = dpt.broadcast_to(src2, res_shape)
460547
ht_, _ = self.binary_fn_(
461548
src1=buf1,
462549
src2=src2,
463-
dst=r,
550+
dst=out,
464551
sycl_queue=exec_q,
465552
depends=[copy_ev],
466553
)
467554
ht_copy_ev.wait()
468555
ht_.wait()
469-
return r
556+
return out
470557

471558
if order in ["K", "A"]:
472559
if src1.flags.f_contiguous and src2.flags.f_contiguous:
@@ -489,26 +576,33 @@ def __call__(self, o1, o2, order="K"):
489576
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
490577
src=src2, dst=buf2, sycl_queue=exec_q
491578
)
492-
if order == "K":
493-
r = _empty_like_pair_orderK(
494-
buf1, buf2, res_dt, res_usm_type, exec_q
495-
)
579+
if out is None:
580+
if order == "K":
581+
out = _empty_like_pair_orderK(
582+
buf1, buf2, res_dt, res_usm_type, exec_q
583+
)
584+
else:
585+
out = dpt.empty(
586+
res_shape,
587+
dtype=res_dt,
588+
usm_type=res_usm_type,
589+
sycl_queue=exec_q,
590+
order=order,
591+
)
496592
else:
497-
r = dpt.empty(
498-
res_shape,
499-
dtype=res_dt,
500-
usm_type=res_usm_type,
501-
sycl_queue=exec_q,
502-
order=order,
503-
)
593+
if res_dt != out.dtype:
594+
raise TypeError(
595+
f"Output array of type {res_dt} is needed, got {out.dtype}"
596+
)
597+
504598
buf1 = dpt.broadcast_to(buf1, res_shape)
505599
buf2 = dpt.broadcast_to(buf2, res_shape)
506600
ht_, _ = self.binary_fn_(
507601
src1=buf1,
508602
src2=buf2,
509-
dst=r,
603+
dst=out,
510604
sycl_queue=exec_q,
511605
depends=[copy1_ev, copy2_ev],
512606
)
513607
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
514-
return r
608+
return out

dpctl/tests/elementwise/test_abs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,20 @@ def test_abs_complex(dtype):
8989
np.testing.assert_allclose(
9090
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
9191
)
92+
93+
94+
@pytest.mark.parametrize("dtype", _all_dtypes[:-2])
95+
def test_abs_out_keyword(dtype):
96+
q = get_queue_or_skip()
97+
skip_if_dtype_not_supported(dtype, q)
98+
99+
arg_dt = np.dtype(dtype)
100+
input_shape = (10, 10, 10, 10)
101+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
102+
X[..., 0::2] = 1
103+
X[..., 1::2] = 0
104+
Y = dpt.empty_like(X, dtype=arg_dt)
105+
dpt.abs(X, Y)
106+
107+
expected_Y = dpt.asnumpy(X)
108+
assert np.allclose(dpt.asnumpy(Y), expected_Y)

0 commit comments

Comments
 (0)