Skip to content

Commit cde757c

Browse files
Added implementation of dpctl.tensor.equal
``` In [4]: if dpt.all(dpt.equal( dpt.arange(30), dpt.arange(50)[:30])): print("Equal") Equal ```
1 parent 61100e4 commit cde757c

File tree

4 files changed

+1178
-170
lines changed

4 files changed

+1178
-170
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
add,
9797
cos,
9898
divide,
99+
equal,
99100
isfinite,
100101
isinf,
101102
isnan,
@@ -185,4 +186,5 @@
185186
"isfinite",
186187
"sqrt",
187188
"divide",
189+
"equal",
188190
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 186 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22

33
from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
44

5-
# ABS
5+
# U01: ==== ABS (x)
66
_abs_docstring_ = """
77
Calculate the absolute value element-wise.
88
"""
99

1010
abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
1111

12-
# ADD
12+
# U02: ==== ACOS (x)
13+
# FIXME: implement U02
14+
15+
# U03: ===== ACOSH (x)
16+
# FIXME: implement U03
17+
18+
# B01: ===== ADD (x1, x2)
1319

1420
_add_docstring_ = """
1521
add(x1, x2, order='K')
@@ -31,8 +37,58 @@
3137
"add", ti._add_result_type, ti._add, _add_docstring_
3238
)
3339

34-
# DIVIDE
40+
# U04: ===== ASIN (x)
41+
# FIXME: implement U04
42+
43+
# U05: ===== ASINH (x)
44+
# FIXME: implement U05
45+
46+
# U06: ===== ATAN (x)
47+
# FIXME: implement U06
48+
49+
# B02: ===== ATAN2 (x1, x2)
50+
# FIXME: implemetn B02
51+
52+
# U07: ===== ATANH (x)
53+
# FIXME: implemetn U07
54+
55+
# B03: ===== BITWISE_AND (x1, x2)
56+
# FIXME: implemetn B03
57+
58+
# B04: ===== BITWISE_LEFT_SHIFT (x1, x2)
59+
# FIXME: implement B04
60+
61+
# U08: ===== BITWISE_INVERT (x)
62+
# FIXME: implement U08
63+
64+
# B05: ===== BITWISE_OR (x1, x2)
65+
# FIXME: implement B05
66+
67+
# B06: ===== BITWISE_RIGHT_SHIFT (x1, x2)
68+
# FIXME: implement B06
69+
70+
# B07: ===== BITWISE_XOR (x1, x2)
71+
# FIXME: implement B07
72+
73+
# U09: ==== CEIL (x)
74+
# FIXME: implement U09
75+
76+
# U10: ==== CONJ (x)
77+
# FIXME: implement U10
78+
79+
# U11: ==== COS (x)
80+
_cos_docstring = """
81+
cos(x, order='K')
82+
83+
Computes cosine for each element `x_i` for input array `x`.
84+
"""
85+
86+
cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring)
87+
88+
# U12: ==== COSH (x)
89+
# FIXME: implement U12
3590

91+
# B08: ==== DIVIDE (x1, x2)
3692
_divide_docstring_ = """
3793
divide(x1, x2, order='K')
3894
@@ -49,23 +105,56 @@
49105
an array containing the result of element-wise division. The data type
50106
of the returned array is determined by the Type Promotion Rules.
51107
"""
108+
52109
divide = BinaryElementwiseFunc(
53110
"divide", ti._divide_result_type, ti._divide, _divide_docstring_
54111
)
55112

113+
# B09: ==== EQUAL (x1, x2)
114+
_equal_docstring_ = """
115+
equal(x1, x2, order='K')
56116
57-
# COS
58-
59-
_cos_docstring = """
60-
cos(x, order='K')
117+
Calculates equality test results for each element `x1_i` of the input array `x1`
118+
with the respective element `x2_i` of the input array `x2`.
61119
62-
Computes cosine for each element `x_i` for input array `x`.
120+
Args:
121+
x1 (usm_ndarray):
122+
First input array, expected to have numeric data type.
123+
x2 (usm_ndarray):
124+
Second input array, also expected to have numeric data type.
125+
Returns:
126+
usm_narray:
127+
an array containing the result of element-wise equality comparison.
128+
The data type of the returned array is determined by the
129+
Type Promotion Rules.
63130
"""
64131

65-
cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring)
132+
equal = BinaryElementwiseFunc(
133+
"equal", ti._equal_result_type, ti._equal, _equal_docstring_
134+
)
135+
136+
# U13: ==== EXP (x)
137+
# FIXME: implement U13
138+
139+
# U14: ==== EXPM1 (x)
140+
# FIXME: implement U14
141+
142+
# U15: ==== FLOOR (x)
143+
# FIXME: implement U15
66144

67-
# ISFINITE
145+
# B10: ==== FLOOR_DIVIDE (x1, x2)
146+
# FIXME: implement B10
68147

148+
# B11: ==== GREATER (x1, x2)
149+
# FIXME: implement B11
150+
151+
# B12: ==== GREATER_EQUAL (x1, x2)
152+
# FIXME: implement B12
153+
154+
# U16: ==== IMAG (x)
155+
# FIXME: implement U16
156+
157+
# U17: ==== ISFINITE (x)
69158
_isfinite_docstring_ = """
70159
Computes if every element of input array is a finite number.
71160
"""
@@ -74,8 +163,16 @@
74163
"isfinite", ti._isfinite_result_type, ti._isfinite, _isfinite_docstring_
75164
)
76165

77-
# ISNAN
166+
# U18: ==== ISINF (x)
167+
_isinf_docstring_ = """
168+
Computes if every element of input array is an infinity.
169+
"""
170+
171+
isinf = UnaryElementwiseFunc(
172+
"isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_
173+
)
78174

175+
# U19: ==== ISNAN (x)
79176
_isnan_docstring_ = """
80177
Computes if every element of input array is a NaN.
81178
"""
@@ -84,22 +181,92 @@
84181
"isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring_
85182
)
86183

87-
# ISINF
184+
# B13: ==== LESS (x1, x2)
185+
# FIXME: implement B13
88186

89-
_isinf_docstring_ = """
90-
Computes if every element of input array is an infinity.
91-
"""
187+
# B14: ==== LESS_EQUAL (x1, x2)
188+
# FIXME: implement B14
92189

93-
isinf = UnaryElementwiseFunc(
94-
"isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_
95-
)
190+
# U20: ==== LOG (x)
191+
# FIXME: implement U20
192+
193+
# U21: ==== LOG1P (x)
194+
# FIXME: implement U21
195+
196+
# U22: ==== LOG2 (x)
197+
# FIXME: implement U22
198+
199+
# U23: ==== LOG10 (x)
200+
# FIXME: implement U23
201+
202+
# B15: ==== LOGADDEXP (x1, x2)
203+
# FIXME: implement B15
96204

97-
# SQRT
205+
# B16: ==== LOGICAL_AND (x1, x2)
206+
# FIXME: implement B16
98207

208+
# U24: ==== LOGICAL_NOT (x)
209+
# FIXME: implement U24
210+
211+
# B17: ==== LOGICAL_OR (x1, x2)
212+
# FIXME: implement B17
213+
214+
# B18: ==== LOGICAL_XOR (x1, x2)
215+
# FIXME: implement B18
216+
217+
# B19: ==== MULTIPLY (x1, x2)
218+
# FIXME: implement B19
219+
220+
# U25: ==== NEGATIVE (x)
221+
# FIXME: implement U25
222+
223+
# B20: ==== NOT_EQUAL (x1, x2)
224+
# FIXME: implement B20
225+
226+
# U26: ==== POSITIVE (x)
227+
# FIXME: implement U26
228+
229+
# B21: ==== POW (x1, x2)
230+
# FIXME: implement B21
231+
232+
# U27: ==== REAL (x)
233+
# FIXME: implement U27
234+
235+
# B22: ==== REMAINDER (x1, x2)
236+
# FIXME: implement B22
237+
238+
# U28: ==== ROUND (x)
239+
# FIXME: implement U28
240+
241+
# U29: ==== SIGN (x)
242+
# FIXME: implement U29
243+
244+
# U30: ==== SIN (x)
245+
# FIXME: implement U30
246+
247+
# U31: ==== SINH (x)
248+
# FIXME: implement U31
249+
250+
# U32: ==== SQUARE (x)
251+
# FIXME: implement U32
252+
253+
# U33: ==== SQRT (x)
99254
_sqrt_docstring_ = """
100255
Computes sqrt for each element `x_i` for input array `x`.
101256
"""
102257

103258
sqrt = UnaryElementwiseFunc(
104259
"sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_
105260
)
261+
262+
# B23: ==== SUBTRACT (x1, x2)
263+
# FIXME: implement B23
264+
265+
# U34: ==== TAN (x)
266+
# FIXME: implement U34
267+
268+
# U35: ==== TANH (x)
269+
# FIXME: implement U35
270+
271+
# U36: ==== TRUNC (x)
272+
# FIXME: implement U36

0 commit comments

Comments
 (0)