@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
47
47
self .unary_fn_ = unary_dp_impl_fn
48
48
self .__doc__ = docs
49
49
50
- def __call__ (self , x , order = "K" ):
50
+ def __call__ (self , x , out = None , order = "K" ):
51
51
if not isinstance (x , dpt .usm_ndarray ):
52
52
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
53
+
54
+ if out is not None :
55
+ if not isinstance (out , dpt .usm_ndarray ):
56
+ raise TypeError (
57
+ f"output array must be of usm_ndarray type, got { type (out )} "
58
+ )
59
+
60
+ if out .shape != x .shape :
61
+ raise TypeError (
62
+ "The shape of input and output arrays are inconsistent."
63
+ f"Expected output shape is { x .shape } , got { out .shape } "
64
+ )
65
+
66
+ if ti ._array_overlap (x , out ):
67
+ raise TypeError ("Input and output arrays have memory overlap" )
68
+
69
+ if (
70
+ dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
71
+ is None
72
+ ):
73
+ raise TypeError (
74
+ "Input and output allocation queues are not compatible"
75
+ )
76
+
53
77
if order not in ["C" , "F" , "K" , "A" ]:
54
78
order = "K"
55
79
buf_dt , res_dt = _find_buf_dtype (
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
59
83
raise RuntimeError
60
84
exec_q = x .sycl_queue
61
85
if buf_dt is None :
62
- if order == "K" :
63
- r = _empty_like_orderK (x , res_dt )
86
+ if out is None :
87
+ if order == "K" :
88
+ out = _empty_like_orderK (x , res_dt )
89
+ else :
90
+ if order == "A" :
91
+ order = "F" if x .flags .f_contiguous else "C"
92
+ out = dpt .empty_like (x , dtype = res_dt , order = order )
64
93
else :
65
- if order == "A" :
66
- order = "F" if x .flags .f_contiguous else "C"
67
- r = dpt .empty_like (x , dtype = res_dt , order = order )
94
+ if res_dt != out .dtype :
95
+ raise TypeError (
96
+ f"Expected output array of type { res_dt } is supported"
97
+ f", got { out .dtype } "
98
+ )
68
99
69
- ht , _ = self .unary_fn_ (x , r , sycl_queue = exec_q )
100
+ ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
70
101
ht .wait ()
71
102
72
- return r
103
+ return out
73
104
if order == "K" :
74
105
buf = _empty_like_orderK (x , buf_dt )
75
106
else :
@@ -80,16 +111,23 @@ def __call__(self, x, order="K"):
80
111
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
81
112
src = x , dst = buf , sycl_queue = exec_q
82
113
)
83
- if order == "K" :
84
- r = _empty_like_orderK (buf , res_dt )
114
+ if out is None :
115
+ if order == "K" :
116
+ out = _empty_like_orderK (buf , res_dt )
117
+ else :
118
+ out = dpt .empty_like (buf , dtype = res_dt , order = order )
85
119
else :
86
- r = dpt .empty_like (buf , dtype = res_dt , order = order )
120
+ if buf_dt != out .dtype :
121
+ raise TypeError (
122
+ f"Expected output array of type { buf_dt } is supported,"
123
+ f"got { out .dtype } "
124
+ )
87
125
88
- ht , _ = self .unary_fn_ (buf , r , sycl_queue = exec_q , depends = [copy_ev ])
126
+ ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
89
127
ht_copy_ev .wait ()
90
128
ht .wait ()
91
129
92
- return r
130
+ return out
93
131
94
132
95
133
def _get_queue_usm_type (o ):
@@ -281,7 +319,7 @@ def __str__(self):
281
319
def __repr__ (self ):
282
320
return f"<BinaryElementwiseFunc '{ self .name_ } '>"
283
321
284
- def __call__ (self , o1 , o2 , order = "K" ):
322
+ def __call__ (self , o1 , o2 , out = None , order = "K" ):
285
323
if order not in ["K" , "C" , "F" , "A" ]:
286
324
order = "K"
287
325
q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -358,6 +396,31 @@ def __call__(self, o1, o2, order="K"):
358
396
"supported types according to the casting rule ''safe''."
359
397
)
360
398
399
+ if out is not None :
400
+ if not isinstance (out , dpt .usm_ndarray ):
401
+ raise TypeError (
402
+ f"output array must be of usm_ndarray type, got { type (out )} "
403
+ )
404
+
405
+ if out .shape != o1_shape or out .shape != o2_shape :
406
+ raise TypeError (
407
+ "The shape of input and output arrays are inconsistent."
408
+ f"Expected output shape is { o1_shape } , got { out .shape } "
409
+ )
410
+
411
+ if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
412
+ raise TypeError ("Input and output arrays have memory overlap" )
413
+
414
+ if (
415
+ dpctl .utils .get_execution_queue (
416
+ (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
417
+ )
418
+ is None
419
+ ):
420
+ raise TypeError (
421
+ "Input and output allocation queues are not compatible"
422
+ )
423
+
361
424
if isinstance (o1 , dpt .usm_ndarray ):
362
425
src1 = o1
363
426
else :
@@ -368,37 +431,45 @@ def __call__(self, o1, o2, order="K"):
368
431
src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
369
432
370
433
if buf1_dt is None and buf2_dt is None :
371
- if order == "K" :
372
- r = _empty_like_pair_orderK (
373
- src1 , src2 , res_dt , res_usm_type , exec_q
374
- )
375
- else :
376
- if order == "A" :
377
- order = (
378
- "F"
379
- if all (
380
- arr .flags .f_contiguous
381
- for arr in (
382
- src1 ,
383
- src2 ,
434
+ if out is None :
435
+ if order == "K" :
436
+ out = _empty_like_pair_orderK (
437
+ src1 , src2 , res_dt , res_usm_type , exec_q
438
+ )
439
+ else :
440
+ if order == "A" :
441
+ order = (
442
+ "F"
443
+ if all (
444
+ arr .flags .f_contiguous
445
+ for arr in (
446
+ src1 ,
447
+ src2 ,
448
+ )
384
449
)
450
+ else "C"
385
451
)
386
- else "C"
452
+ out = dpt .empty (
453
+ res_shape ,
454
+ dtype = res_dt ,
455
+ usm_type = res_usm_type ,
456
+ sycl_queue = exec_q ,
457
+ order = order ,
387
458
)
388
- r = dpt . empty (
389
- res_shape ,
390
- dtype = res_dt ,
391
- usm_type = res_usm_type ,
392
- sycl_queue = exec_q ,
393
- order = order ,
394
- )
459
+ else :
460
+ if res_dt != out . dtype :
461
+ raise TypeError (
462
+ f"Output array of type { res_dt } is needed,"
463
+ f"got { out . dtype } "
464
+ )
465
+
395
466
src1 = dpt .broadcast_to (src1 , res_shape )
396
467
src2 = dpt .broadcast_to (src2 , res_shape )
397
468
ht_ , _ = self .binary_fn_ (
398
- src1 = src1 , src2 = src2 , dst = r , sycl_queue = exec_q
469
+ src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
399
470
)
400
471
ht_ .wait ()
401
- return r
472
+ return out
402
473
elif buf1_dt is None :
403
474
if order == "K" :
404
475
buf2 = _empty_like_orderK (src2 , buf2_dt )
@@ -409,30 +480,38 @@ def __call__(self, o1, o2, order="K"):
409
480
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
410
481
src = src2 , dst = buf2 , sycl_queue = exec_q
411
482
)
412
- if order == "K" :
413
- r = _empty_like_pair_orderK (
414
- src1 , buf2 , res_dt , res_usm_type , exec_q
415
- )
483
+ if out is None :
484
+ if order == "K" :
485
+ out = _empty_like_pair_orderK (
486
+ src1 , buf2 , res_dt , res_usm_type , exec_q
487
+ )
488
+ else :
489
+ out = dpt .empty (
490
+ res_shape ,
491
+ dtype = res_dt ,
492
+ usm_type = res_usm_type ,
493
+ sycl_queue = exec_q ,
494
+ order = order ,
495
+ )
416
496
else :
417
- r = dpt .empty (
418
- res_shape ,
419
- dtype = res_dt ,
420
- usm_type = res_usm_type ,
421
- sycl_queue = exec_q ,
422
- order = order ,
423
- )
497
+ if res_dt != out .dtype :
498
+ raise TypeError (
499
+ f"Output array of type { res_dt } is needed,"
500
+ f"got { out .dtype } "
501
+ )
502
+
424
503
src1 = dpt .broadcast_to (src1 , res_shape )
425
504
buf2 = dpt .broadcast_to (buf2 , res_shape )
426
505
ht_ , _ = self .binary_fn_ (
427
506
src1 = src1 ,
428
507
src2 = buf2 ,
429
- dst = r ,
508
+ dst = out ,
430
509
sycl_queue = exec_q ,
431
510
depends = [copy_ev ],
432
511
)
433
512
ht_copy_ev .wait ()
434
513
ht_ .wait ()
435
- return r
514
+ return out
436
515
elif buf2_dt is None :
437
516
if order == "K" :
438
517
buf1 = _empty_like_orderK (src1 , buf1_dt )
@@ -443,30 +522,38 @@ def __call__(self, o1, o2, order="K"):
443
522
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
444
523
src = src1 , dst = buf1 , sycl_queue = exec_q
445
524
)
446
- if order == "K" :
447
- r = _empty_like_pair_orderK (
448
- buf1 , src2 , res_dt , res_usm_type , exec_q
449
- )
525
+ if out is None :
526
+ if order == "K" :
527
+ out = _empty_like_pair_orderK (
528
+ buf1 , src2 , res_dt , res_usm_type , exec_q
529
+ )
530
+ else :
531
+ out = dpt .empty (
532
+ res_shape ,
533
+ dtype = res_dt ,
534
+ usm_type = res_usm_type ,
535
+ sycl_queue = exec_q ,
536
+ order = order ,
537
+ )
450
538
else :
451
- r = dpt .empty (
452
- res_shape ,
453
- dtype = res_dt ,
454
- usm_type = res_usm_type ,
455
- sycl_queue = exec_q ,
456
- order = order ,
457
- )
539
+ if res_dt != out .dtype :
540
+ raise TypeError (
541
+ f"Output array of type { res_dt } is needed,"
542
+ f"got { out .dtype } "
543
+ )
544
+
458
545
buf1 = dpt .broadcast_to (buf1 , res_shape )
459
546
src2 = dpt .broadcast_to (src2 , res_shape )
460
547
ht_ , _ = self .binary_fn_ (
461
548
src1 = buf1 ,
462
549
src2 = src2 ,
463
- dst = r ,
550
+ dst = out ,
464
551
sycl_queue = exec_q ,
465
552
depends = [copy_ev ],
466
553
)
467
554
ht_copy_ev .wait ()
468
555
ht_ .wait ()
469
- return r
556
+ return out
470
557
471
558
if order in ["K" , "A" ]:
472
559
if src1 .flags .f_contiguous and src2 .flags .f_contiguous :
@@ -489,26 +576,33 @@ def __call__(self, o1, o2, order="K"):
489
576
ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
490
577
src = src2 , dst = buf2 , sycl_queue = exec_q
491
578
)
492
- if order == "K" :
493
- r = _empty_like_pair_orderK (
494
- buf1 , buf2 , res_dt , res_usm_type , exec_q
495
- )
579
+ if out is None :
580
+ if order == "K" :
581
+ out = _empty_like_pair_orderK (
582
+ buf1 , buf2 , res_dt , res_usm_type , exec_q
583
+ )
584
+ else :
585
+ out = dpt .empty (
586
+ res_shape ,
587
+ dtype = res_dt ,
588
+ usm_type = res_usm_type ,
589
+ sycl_queue = exec_q ,
590
+ order = order ,
591
+ )
496
592
else :
497
- r = dpt .empty (
498
- res_shape ,
499
- dtype = res_dt ,
500
- usm_type = res_usm_type ,
501
- sycl_queue = exec_q ,
502
- order = order ,
503
- )
593
+ if res_dt != out .dtype :
594
+ raise TypeError (
595
+ f"Output array of type { res_dt } is needed, got { out .dtype } "
596
+ )
597
+
504
598
buf1 = dpt .broadcast_to (buf1 , res_shape )
505
599
buf2 = dpt .broadcast_to (buf2 , res_shape )
506
600
ht_ , _ = self .binary_fn_ (
507
601
src1 = buf1 ,
508
602
src2 = buf2 ,
509
- dst = r ,
603
+ dst = out ,
510
604
sycl_queue = exec_q ,
511
605
depends = [copy1_ev , copy2_ev ],
512
606
)
513
607
dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
514
- return r
608
+ return out
0 commit comments