Skip to content

Commit a504e96

Browse files
authored
Merge pull request #1147 from IntelPython/where-impl-gh-1120
Implements dpctl.tensor.where
2 parents 7bbfce1 + db43042 commit a504e96

File tree

10 files changed

+1298
-0
lines changed

10 files changed

+1298
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pybind11_add_module(${python_module_name} MODULE
4343
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
4444
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
4545
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
4647
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
4748
)
4849
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
usm_ndarray_str,
8787
)
8888
from dpctl.tensor._reshape import reshape
89+
from dpctl.tensor._search_functions import where
8990
from dpctl.tensor._usmarray import usm_ndarray
9091

9192
from ._constants import e, inf, nan, newaxis, pi
@@ -128,6 +129,7 @@
128129
"from_dlpack",
129130
"tril",
130131
"triu",
132+
"where",
131133
"dtype",
132134
"isdtype",
133135
"bool",

dpctl/tensor/_search_functions.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl
18+
import dpctl.tensor as dpt
19+
import dpctl.tensor._tensor_impl as ti
20+
from dpctl.tensor._manipulation_functions import _broadcast_shapes
21+
22+
from ._type_utils import _all_data_types, _can_cast
23+
24+
25+
def _where_result_type(dt1, dt2, dev):
26+
res_dtype = dpt.result_type(dt1, dt2)
27+
fp16 = dev.has_aspect_fp16
28+
fp64 = dev.has_aspect_fp64
29+
30+
all_dts = _all_data_types(fp16, fp64)
31+
if res_dtype in all_dts:
32+
return res_dtype
33+
else:
34+
for res_dtype_ in all_dts:
35+
if _can_cast(dt1, res_dtype_, fp16, fp64) and _can_cast(
36+
dt2, res_dtype_, fp16, fp64
37+
):
38+
return res_dtype_
39+
return None
40+
41+
42+
def where(condition, x1, x2):
43+
"""where(condition, x1, x2)
44+
45+
Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen
46+
from `x1` or `x2` depending on `condition`.
47+
48+
Args:
49+
condition (usm_ndarray): When True yields from `x1`,
50+
and otherwise yields from `x2`.
51+
Must be compatible with `x1` and `x2` according
52+
to broadcasting rules.
53+
x1 (usm_ndarray): Array from which values are chosen when
54+
`condition` is True.
55+
Must be compatible with `condition` and `x2` according
56+
to broadcasting rules.
57+
x2 (usm_ndarray): Array from which values are chosen when
58+
`condition` is not True.
59+
Must be compatible with `condition` and `x2` according
60+
to broadcasting rules.
61+
62+
Returns:
63+
usm_ndarray:
64+
An array with elements from `x1` where `condition` is True,
65+
and elements from `x2` elsewhere.
66+
67+
The data type of the returned array is determined by applying
68+
the Type Promotion Rules to `x1` and `x2`.
69+
70+
The memory layout of the returned array is
71+
F-contiguous (column-major) when all inputs are F-contiguous,
72+
and C-contiguous (row-major) otherwise.
73+
"""
74+
if not isinstance(condition, dpt.usm_ndarray):
75+
raise TypeError(
76+
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}"
77+
)
78+
if not isinstance(x1, dpt.usm_ndarray):
79+
raise TypeError(
80+
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}"
81+
)
82+
if not isinstance(x2, dpt.usm_ndarray):
83+
raise TypeError(
84+
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}"
85+
)
86+
exec_q = dpctl.utils.get_execution_queue(
87+
(
88+
condition.sycl_queue,
89+
x1.sycl_queue,
90+
x2.sycl_queue,
91+
)
92+
)
93+
if exec_q is None:
94+
raise dpctl.utils.ExecutionPlacementError
95+
dst_usm_type = dpctl.utils.get_coerced_usm_type(
96+
(
97+
condition.usm_type,
98+
x1.usm_type,
99+
x2.usm_type,
100+
)
101+
)
102+
103+
x1_dtype = x1.dtype
104+
x2_dtype = x2.dtype
105+
dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
106+
if dst_dtype is None:
107+
raise TypeError(
108+
"function 'where' does not support input "
109+
f"types ({x1_dtype}, {x2_dtype}), "
110+
"and the inputs could not be safely coerced "
111+
"to any supported types according to the casting rule ''safe''."
112+
)
113+
114+
res_shape = _broadcast_shapes(condition, x1, x2)
115+
116+
if condition.size == 0:
117+
return dpt.empty(
118+
res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q
119+
)
120+
121+
deps = []
122+
wait_list = []
123+
if x1_dtype != dst_dtype:
124+
_x1 = dpt.empty_like(x1, dtype=dst_dtype)
125+
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
126+
src=x1, dst=_x1, sycl_queue=exec_q
127+
)
128+
x1 = _x1
129+
deps.append(copy1_ev)
130+
wait_list.append(ht_copy1_ev)
131+
132+
if x2_dtype != dst_dtype:
133+
_x2 = dpt.empty_like(x2, dtype=dst_dtype)
134+
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
135+
src=x2, dst=_x2, sycl_queue=exec_q
136+
)
137+
x2 = _x2
138+
deps.append(copy2_ev)
139+
wait_list.append(ht_copy2_ev)
140+
141+
condition = dpt.broadcast_to(condition, res_shape)
142+
x1 = dpt.broadcast_to(x1, res_shape)
143+
x2 = dpt.broadcast_to(x2, res_shape)
144+
145+
# dst is F-contiguous when all inputs are F contiguous
146+
# otherwise, defaults to C-contiguous
147+
if all(
148+
(
149+
condition.flags.fnc,
150+
x1.flags.fnc,
151+
x2.flags.fnc,
152+
)
153+
):
154+
order = "F"
155+
else:
156+
order = "C"
157+
158+
dst = dpt.empty(
159+
res_shape,
160+
dtype=dst_dtype,
161+
order=order,
162+
usm_type=dst_usm_type,
163+
sycl_queue=exec_q,
164+
)
165+
166+
hev, _ = ti._where(
167+
condition=condition,
168+
x1=x1,
169+
x2=x2,
170+
dst=dst,
171+
sycl_queue=exec_q,
172+
depends=deps,
173+
)
174+
dpctl.SyclEvent.wait_for(wait_list)
175+
hev.wait()
176+
177+
return dst

dpctl/tensor/_type_utils.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl.tensor as dpt
18+
19+
20+
def _all_data_types(_fp16, _fp64):
21+
if _fp64:
22+
if _fp16:
23+
return [
24+
dpt.bool,
25+
dpt.int8,
26+
dpt.uint8,
27+
dpt.int16,
28+
dpt.uint16,
29+
dpt.int32,
30+
dpt.uint32,
31+
dpt.int64,
32+
dpt.uint64,
33+
dpt.float16,
34+
dpt.float32,
35+
dpt.float64,
36+
dpt.complex64,
37+
dpt.complex128,
38+
]
39+
else:
40+
return [
41+
dpt.bool,
42+
dpt.int8,
43+
dpt.uint8,
44+
dpt.int16,
45+
dpt.uint16,
46+
dpt.int32,
47+
dpt.uint32,
48+
dpt.int64,
49+
dpt.uint64,
50+
dpt.float32,
51+
dpt.float64,
52+
dpt.complex64,
53+
dpt.complex128,
54+
]
55+
else:
56+
if _fp16:
57+
return [
58+
dpt.bool,
59+
dpt.int8,
60+
dpt.uint8,
61+
dpt.int16,
62+
dpt.uint16,
63+
dpt.int32,
64+
dpt.uint32,
65+
dpt.int64,
66+
dpt.uint64,
67+
dpt.float16,
68+
dpt.float32,
69+
dpt.complex64,
70+
]
71+
else:
72+
return [
73+
dpt.bool,
74+
dpt.int8,
75+
dpt.uint8,
76+
dpt.int16,
77+
dpt.uint16,
78+
dpt.int32,
79+
dpt.uint32,
80+
dpt.int64,
81+
dpt.uint64,
82+
dpt.float32,
83+
dpt.complex64,
84+
]
85+
86+
87+
def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
88+
"""
89+
Return True if data type `dt` is the
90+
maximal size inexact data type
91+
"""
92+
if _fp64:
93+
return dt in [dpt.float64, dpt.complex128]
94+
return dt in [dpt.float32, dpt.complex64]
95+
96+
97+
def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool):
98+
"""
99+
Can `from_` be cast to `to_` safely on a device with
100+
fp16 and fp64 aspects as given?
101+
"""
102+
can_cast_v = dpt.can_cast(from_, to_) # ask NumPy
103+
if _fp16 and _fp64:
104+
return can_cast_v
105+
if not can_cast_v:
106+
if (
107+
from_.kind in "biu"
108+
and to_.kind in "fc"
109+
and _is_maximal_inexact_type(to_, _fp16, _fp64)
110+
):
111+
return True
112+
113+
return can_cast_v

0 commit comments

Comments
 (0)