8
8
_integer_dtypes ,
9
9
_integer_or_boolean_dtypes ,
10
10
_numeric_dtypes ,
11
+ _promote_scalars ,
11
12
_real_floating_dtypes ,
12
13
_real_numeric_dtypes ,
13
14
complex64 ,
@@ -44,6 +45,7 @@ def acosh(x, /):
44
45
45
46
46
47
def add (x1 , x2 , / ):
48
+ x1 , x2 = _promote_scalars (x1 , x2 , "add" )
47
49
if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
48
50
raise TypeError ("Only numeric dtypes are allowed in add" )
49
51
return elemwise (nxp .add , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -68,6 +70,7 @@ def atan(x, /):
68
70
69
71
70
72
def atan2 (x1 , x2 , / ):
73
+ x1 , x2 = _promote_scalars (x1 , x2 , "atan2" )
71
74
if x1 .dtype not in _real_floating_dtypes or x2 .dtype not in _real_floating_dtypes :
72
75
raise TypeError ("Only real floating-point dtypes are allowed in atan2" )
73
76
return elemwise (nxp .atan2 , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -80,6 +83,7 @@ def atanh(x, /):
80
83
81
84
82
85
def bitwise_and (x1 , x2 , / ):
86
+ x1 , x2 = _promote_scalars (x1 , x2 , "bitwise_and" )
83
87
if (
84
88
x1 .dtype not in _integer_or_boolean_dtypes
85
89
or x2 .dtype not in _integer_or_boolean_dtypes
@@ -95,12 +99,14 @@ def bitwise_invert(x, /):
95
99
96
100
97
101
def bitwise_left_shift (x1 , x2 , / ):
102
+ x1 , x2 = _promote_scalars (x1 , x2 , "bitwise_left_shift" )
98
103
if x1 .dtype not in _integer_dtypes or x2 .dtype not in _integer_dtypes :
99
104
raise TypeError ("Only integer dtypes are allowed in bitwise_left_shift" )
100
105
return elemwise (nxp .bitwise_left_shift , x1 , x2 , dtype = result_type (x1 , x2 ))
101
106
102
107
103
108
def bitwise_or (x1 , x2 , / ):
109
+ x1 , x2 = _promote_scalars (x1 , x2 , "bitwise_or" )
104
110
if (
105
111
x1 .dtype not in _integer_or_boolean_dtypes
106
112
or x2 .dtype not in _integer_or_boolean_dtypes
@@ -110,12 +116,14 @@ def bitwise_or(x1, x2, /):
110
116
111
117
112
118
def bitwise_right_shift (x1 , x2 , / ):
119
+ x1 , x2 = _promote_scalars (x1 , x2 , "bitwise_right_shift" )
113
120
if x1 .dtype not in _integer_dtypes or x2 .dtype not in _integer_dtypes :
114
121
raise TypeError ("Only integer dtypes are allowed in bitwise_right_shift" )
115
122
return elemwise (nxp .bitwise_right_shift , x1 , x2 , dtype = result_type (x1 , x2 ))
116
123
117
124
118
125
def bitwise_xor (x1 , x2 , / ):
126
+ x1 , x2 = _promote_scalars (x1 , x2 , "bitwise_xor" )
119
127
if (
120
128
x1 .dtype not in _integer_or_boolean_dtypes
121
129
or x2 .dtype not in _integer_or_boolean_dtypes
@@ -172,6 +180,7 @@ def conj(x, /):
172
180
173
181
174
182
def copysign (x1 , x2 , / ):
183
+ x1 , x2 = _promote_scalars (x1 , x2 , "copysign" )
175
184
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
176
185
raise TypeError ("Only real numeric dtypes are allowed in copysign" )
177
186
return elemwise (nxp .copysign , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -190,6 +199,7 @@ def cosh(x, /):
190
199
191
200
192
201
def divide (x1 , x2 , / ):
202
+ x1 , x2 = _promote_scalars (x1 , x2 , "divide" )
193
203
if x1 .dtype not in _floating_dtypes or x2 .dtype not in _floating_dtypes :
194
204
raise TypeError ("Only floating-point dtypes are allowed in divide" )
195
205
return elemwise (nxp .divide , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -208,6 +218,7 @@ def expm1(x, /):
208
218
209
219
210
220
def equal (x1 , x2 , / ):
221
+ x1 , x2 = _promote_scalars (x1 , x2 , "equal" )
211
222
return elemwise (nxp .equal , x1 , x2 , dtype = nxp .bool )
212
223
213
224
@@ -221,20 +232,24 @@ def floor(x, /):
221
232
222
233
223
234
def floor_divide (x1 , x2 , / ):
235
+ x1 , x2 = _promote_scalars (x1 , x2 , "floor_divide" )
224
236
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
225
237
raise TypeError ("Only real numeric dtypes are allowed in floor_divide" )
226
238
return elemwise (nxp .floor_divide , x1 , x2 , dtype = result_type (x1 , x2 ))
227
239
228
240
229
241
def greater (x1 , x2 , / ):
242
+ x1 , x2 = _promote_scalars (x1 , x2 , "greater" )
230
243
return elemwise (nxp .greater , x1 , x2 , dtype = nxp .bool )
231
244
232
245
233
246
def greater_equal (x1 , x2 , / ):
247
+ x1 , x2 = _promote_scalars (x1 , x2 , "greater_equal" )
234
248
return elemwise (nxp .greater_equal , x1 , x2 , dtype = nxp .bool )
235
249
236
250
237
251
def hypot (x1 , x2 , / ):
252
+ x1 , x2 = _promote_scalars (x1 , x2 , "hypot" )
238
253
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
239
254
raise TypeError ("Only real numeric dtypes are allowed in hypot" )
240
255
return elemwise (nxp .hypot , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -269,10 +284,12 @@ def isnan(x, /):
269
284
270
285
271
286
def less (x1 , x2 , / ):
287
+ x1 , x2 = _promote_scalars (x1 , x2 , "less" )
272
288
return elemwise (nxp .less , x1 , x2 , dtype = nxp .bool )
273
289
274
290
275
291
def less_equal (x1 , x2 , / ):
292
+ x1 , x2 = _promote_scalars (x1 , x2 , "less_equal" )
276
293
return elemwise (nxp .less_equal , x1 , x2 , dtype = nxp .bool )
277
294
278
295
@@ -301,12 +318,14 @@ def log10(x, /):
301
318
302
319
303
320
def logaddexp (x1 , x2 , / ):
321
+ x1 , x2 = _promote_scalars (x1 , x2 , "logaddexp" )
304
322
if x1 .dtype not in _real_floating_dtypes or x2 .dtype not in _real_floating_dtypes :
305
323
raise TypeError ("Only real floating-point dtypes are allowed in logaddexp" )
306
324
return elemwise (nxp .logaddexp , x1 , x2 , dtype = result_type (x1 , x2 ))
307
325
308
326
309
327
def logical_and (x1 , x2 , / ):
328
+ x1 , x2 = _promote_scalars (x1 , x2 , "logical_and" )
310
329
if x1 .dtype not in _boolean_dtypes or x2 .dtype not in _boolean_dtypes :
311
330
raise TypeError ("Only boolean dtypes are allowed in logical_and" )
312
331
return elemwise (nxp .logical_and , x1 , x2 , dtype = nxp .bool )
@@ -319,30 +338,35 @@ def logical_not(x, /):
319
338
320
339
321
340
def logical_or (x1 , x2 , / ):
341
+ x1 , x2 = _promote_scalars (x1 , x2 , "logical_or" )
322
342
if x1 .dtype not in _boolean_dtypes or x2 .dtype not in _boolean_dtypes :
323
343
raise TypeError ("Only boolean dtypes are allowed in logical_or" )
324
344
return elemwise (nxp .logical_or , x1 , x2 , dtype = nxp .bool )
325
345
326
346
327
347
def logical_xor (x1 , x2 , / ):
348
+ x1 , x2 = _promote_scalars (x1 , x2 , "logical_xor" )
328
349
if x1 .dtype not in _boolean_dtypes or x2 .dtype not in _boolean_dtypes :
329
350
raise TypeError ("Only boolean dtypes are allowed in logical_xor" )
330
351
return elemwise (nxp .logical_xor , x1 , x2 , dtype = nxp .bool )
331
352
332
353
333
354
def maximum (x1 , x2 , / ):
355
+ x1 , x2 = _promote_scalars (x1 , x2 , "maximum" )
334
356
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
335
357
raise TypeError ("Only real numeric dtypes are allowed in maximum" )
336
358
return elemwise (nxp .maximum , x1 , x2 , dtype = result_type (x1 , x2 ))
337
359
338
360
339
361
def minimum (x1 , x2 , / ):
362
+ x1 , x2 = _promote_scalars (x1 , x2 , "minimum" )
340
363
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
341
364
raise TypeError ("Only real numeric dtypes are allowed in minimum" )
342
365
return elemwise (nxp .minimum , x1 , x2 , dtype = result_type (x1 , x2 ))
343
366
344
367
345
368
def multiply (x1 , x2 , / ):
369
+ x1 , x2 = _promote_scalars (x1 , x2 , "multiply" )
346
370
if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
347
371
raise TypeError ("Only numeric dtypes are allowed in multiply" )
348
372
return elemwise (nxp .multiply , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -355,6 +379,7 @@ def negative(x, /):
355
379
356
380
357
381
def not_equal (x1 , x2 , / ):
382
+ x1 , x2 = _promote_scalars (x1 , x2 , "not_equal" )
358
383
return elemwise (nxp .not_equal , x1 , x2 , dtype = nxp .bool )
359
384
360
385
@@ -365,6 +390,7 @@ def positive(x, /):
365
390
366
391
367
392
def pow (x1 , x2 , / ):
393
+ x1 , x2 = _promote_scalars (x1 , x2 , "pow" )
368
394
if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
369
395
raise TypeError ("Only numeric dtypes are allowed in pow" )
370
396
return elemwise (nxp .pow , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -381,6 +407,7 @@ def real(x, /):
381
407
382
408
383
409
def remainder (x1 , x2 , / ):
410
+ x1 , x2 = _promote_scalars (x1 , x2 , "remainder" )
384
411
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
385
412
raise TypeError ("Only real numeric dtypes are allowed in remainder" )
386
413
return elemwise (nxp .remainder , x1 , x2 , dtype = result_type (x1 , x2 ))
@@ -429,6 +456,7 @@ def square(x, /):
429
456
430
457
431
458
def subtract (x1 , x2 , / ):
459
+ x1 , x2 = _promote_scalars (x1 , x2 , "subtract" )
432
460
if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
433
461
raise TypeError ("Only numeric dtypes are allowed in subtract" )
434
462
return elemwise (nxp .subtract , x1 , x2 , dtype = result_type (x1 , x2 ))
0 commit comments