Skip to content

Commit 3e5d6e2

Browse files
authored
Add scalar support to binary element-wise ops (#686)
* Add scalar support to binary element-wise ops * Update `where` to use _promote_scalars utility
1 parent b62d9ea commit 3e5d6e2

File tree

4 files changed

+59
-10
lines changed

4 files changed

+59
-10
lines changed

cubed/array_api/dtypes.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@
9090

9191

9292
# A Cubed-specific utility.
93-
def _upcast_integral_dtypes(x, dtype=None, *, allowed_dtypes=("numeric",), fname=None, device=None):
93+
def _upcast_integral_dtypes(
94+
x, dtype=None, *, allowed_dtypes=("numeric",), fname=None, device=None
95+
):
9496
"""Ensure the input dtype is allowed. If it's None, provide a good default dtype."""
9597
dtypes = __array_namespace_info__().default_dtypes(device=device)
9698

@@ -116,3 +118,16 @@ def _upcast_integral_dtypes(x, dtype=None, *, allowed_dtypes=("numeric",), fname
116118
dtype = x.dtype
117119

118120
return dtype
121+
122+
123+
def _promote_scalars(x1, x2, op):
124+
"""Promote at most one of x1 or x2 to an array from a Python scalar"""
125+
x1_is_scalar = isinstance(x1, (int, float, complex, bool))
126+
x2_is_scalar = isinstance(x2, (int, float, complex, bool))
127+
if x1_is_scalar and x2_is_scalar:
128+
raise TypeError(f"At least one of x1 and x2 must be an array in {op}")
129+
elif x1_is_scalar:
130+
x1 = x2._promote_scalar(x1)
131+
elif x2_is_scalar:
132+
x2 = x1._promote_scalar(x2)
133+
return x1, x2

cubed/array_api/elementwise_functions.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
_integer_dtypes,
99
_integer_or_boolean_dtypes,
1010
_numeric_dtypes,
11+
_promote_scalars,
1112
_real_floating_dtypes,
1213
_real_numeric_dtypes,
1314
complex64,
@@ -44,6 +45,7 @@ def acosh(x, /):
4445

4546

4647
def add(x1, x2, /):
48+
x1, x2 = _promote_scalars(x1, x2, "add")
4749
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
4850
raise TypeError("Only numeric dtypes are allowed in add")
4951
return elemwise(nxp.add, x1, x2, dtype=result_type(x1, x2))
@@ -68,6 +70,7 @@ def atan(x, /):
6870

6971

7072
def atan2(x1, x2, /):
73+
x1, x2 = _promote_scalars(x1, x2, "atan2")
7174
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
7275
raise TypeError("Only real floating-point dtypes are allowed in atan2")
7376
return elemwise(nxp.atan2, x1, x2, dtype=result_type(x1, x2))
@@ -80,6 +83,7 @@ def atanh(x, /):
8083

8184

8285
def bitwise_and(x1, x2, /):
86+
x1, x2 = _promote_scalars(x1, x2, "bitwise_and")
8387
if (
8488
x1.dtype not in _integer_or_boolean_dtypes
8589
or x2.dtype not in _integer_or_boolean_dtypes
@@ -95,12 +99,14 @@ def bitwise_invert(x, /):
9599

96100

97101
def bitwise_left_shift(x1, x2, /):
102+
x1, x2 = _promote_scalars(x1, x2, "bitwise_left_shift")
98103
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
99104
raise TypeError("Only integer dtypes are allowed in bitwise_left_shift")
100105
return elemwise(nxp.bitwise_left_shift, x1, x2, dtype=result_type(x1, x2))
101106

102107

103108
def bitwise_or(x1, x2, /):
109+
x1, x2 = _promote_scalars(x1, x2, "bitwise_or")
104110
if (
105111
x1.dtype not in _integer_or_boolean_dtypes
106112
or x2.dtype not in _integer_or_boolean_dtypes
@@ -110,12 +116,14 @@ def bitwise_or(x1, x2, /):
110116

111117

112118
def bitwise_right_shift(x1, x2, /):
119+
x1, x2 = _promote_scalars(x1, x2, "bitwise_right_shift")
113120
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
114121
raise TypeError("Only integer dtypes are allowed in bitwise_right_shift")
115122
return elemwise(nxp.bitwise_right_shift, x1, x2, dtype=result_type(x1, x2))
116123

117124

118125
def bitwise_xor(x1, x2, /):
126+
x1, x2 = _promote_scalars(x1, x2, "bitwise_xor")
119127
if (
120128
x1.dtype not in _integer_or_boolean_dtypes
121129
or x2.dtype not in _integer_or_boolean_dtypes
@@ -172,6 +180,7 @@ def conj(x, /):
172180

173181

174182
def copysign(x1, x2, /):
183+
x1, x2 = _promote_scalars(x1, x2, "copysign")
175184
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
176185
raise TypeError("Only real numeric dtypes are allowed in copysign")
177186
return elemwise(nxp.copysign, x1, x2, dtype=result_type(x1, x2))
@@ -190,6 +199,7 @@ def cosh(x, /):
190199

191200

192201
def divide(x1, x2, /):
202+
x1, x2 = _promote_scalars(x1, x2, "divide")
193203
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
194204
raise TypeError("Only floating-point dtypes are allowed in divide")
195205
return elemwise(nxp.divide, x1, x2, dtype=result_type(x1, x2))
@@ -208,6 +218,7 @@ def expm1(x, /):
208218

209219

210220
def equal(x1, x2, /):
221+
x1, x2 = _promote_scalars(x1, x2, "equal")
211222
return elemwise(nxp.equal, x1, x2, dtype=nxp.bool)
212223

213224

@@ -221,20 +232,24 @@ def floor(x, /):
221232

222233

223234
def floor_divide(x1, x2, /):
235+
x1, x2 = _promote_scalars(x1, x2, "floor_divide")
224236
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
225237
raise TypeError("Only real numeric dtypes are allowed in floor_divide")
226238
return elemwise(nxp.floor_divide, x1, x2, dtype=result_type(x1, x2))
227239

228240

229241
def greater(x1, x2, /):
242+
x1, x2 = _promote_scalars(x1, x2, "greater")
230243
return elemwise(nxp.greater, x1, x2, dtype=nxp.bool)
231244

232245

233246
def greater_equal(x1, x2, /):
247+
x1, x2 = _promote_scalars(x1, x2, "greater_equal")
234248
return elemwise(nxp.greater_equal, x1, x2, dtype=nxp.bool)
235249

236250

237251
def hypot(x1, x2, /):
252+
x1, x2 = _promote_scalars(x1, x2, "hypot")
238253
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
239254
raise TypeError("Only real numeric dtypes are allowed in hypot")
240255
return elemwise(nxp.hypot, x1, x2, dtype=result_type(x1, x2))
@@ -269,10 +284,12 @@ def isnan(x, /):
269284

270285

271286
def less(x1, x2, /):
287+
x1, x2 = _promote_scalars(x1, x2, "less")
272288
return elemwise(nxp.less, x1, x2, dtype=nxp.bool)
273289

274290

275291
def less_equal(x1, x2, /):
292+
x1, x2 = _promote_scalars(x1, x2, "less_equal")
276293
return elemwise(nxp.less_equal, x1, x2, dtype=nxp.bool)
277294

278295

@@ -301,12 +318,14 @@ def log10(x, /):
301318

302319

303320
def logaddexp(x1, x2, /):
321+
x1, x2 = _promote_scalars(x1, x2, "logaddexp")
304322
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
305323
raise TypeError("Only real floating-point dtypes are allowed in logaddexp")
306324
return elemwise(nxp.logaddexp, x1, x2, dtype=result_type(x1, x2))
307325

308326

309327
def logical_and(x1, x2, /):
328+
x1, x2 = _promote_scalars(x1, x2, "logical_and")
310329
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
311330
raise TypeError("Only boolean dtypes are allowed in logical_and")
312331
return elemwise(nxp.logical_and, x1, x2, dtype=nxp.bool)
@@ -319,30 +338,35 @@ def logical_not(x, /):
319338

320339

321340
def logical_or(x1, x2, /):
341+
x1, x2 = _promote_scalars(x1, x2, "logical_or")
322342
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
323343
raise TypeError("Only boolean dtypes are allowed in logical_or")
324344
return elemwise(nxp.logical_or, x1, x2, dtype=nxp.bool)
325345

326346

327347
def logical_xor(x1, x2, /):
348+
x1, x2 = _promote_scalars(x1, x2, "logical_xor")
328349
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
329350
raise TypeError("Only boolean dtypes are allowed in logical_xor")
330351
return elemwise(nxp.logical_xor, x1, x2, dtype=nxp.bool)
331352

332353

333354
def maximum(x1, x2, /):
355+
x1, x2 = _promote_scalars(x1, x2, "maximum")
334356
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
335357
raise TypeError("Only real numeric dtypes are allowed in maximum")
336358
return elemwise(nxp.maximum, x1, x2, dtype=result_type(x1, x2))
337359

338360

339361
def minimum(x1, x2, /):
362+
x1, x2 = _promote_scalars(x1, x2, "minimum")
340363
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
341364
raise TypeError("Only real numeric dtypes are allowed in minimum")
342365
return elemwise(nxp.minimum, x1, x2, dtype=result_type(x1, x2))
343366

344367

345368
def multiply(x1, x2, /):
369+
x1, x2 = _promote_scalars(x1, x2, "multiply")
346370
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
347371
raise TypeError("Only numeric dtypes are allowed in multiply")
348372
return elemwise(nxp.multiply, x1, x2, dtype=result_type(x1, x2))
@@ -355,6 +379,7 @@ def negative(x, /):
355379

356380

357381
def not_equal(x1, x2, /):
382+
x1, x2 = _promote_scalars(x1, x2, "not_equal")
358383
return elemwise(nxp.not_equal, x1, x2, dtype=nxp.bool)
359384

360385

@@ -365,6 +390,7 @@ def positive(x, /):
365390

366391

367392
def pow(x1, x2, /):
393+
x1, x2 = _promote_scalars(x1, x2, "pow")
368394
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
369395
raise TypeError("Only numeric dtypes are allowed in pow")
370396
return elemwise(nxp.pow, x1, x2, dtype=result_type(x1, x2))
@@ -381,6 +407,7 @@ def real(x, /):
381407

382408

383409
def remainder(x1, x2, /):
410+
x1, x2 = _promote_scalars(x1, x2, "remainder")
384411
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
385412
raise TypeError("Only real numeric dtypes are allowed in remainder")
386413
return elemwise(nxp.remainder, x1, x2, dtype=result_type(x1, x2))
@@ -429,6 +456,7 @@ def square(x, /):
429456

430457

431458
def subtract(x1, x2, /):
459+
x1, x2 = _promote_scalars(x1, x2, "subtract")
432460
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
433461
raise TypeError("Only numeric dtypes are allowed in subtract")
434462
return elemwise(nxp.subtract, x1, x2, dtype=result_type(x1, x2))

cubed/array_api/searching_functions.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from cubed.array_api.creation_functions import asarray, zeros_like
22
from cubed.array_api.data_type_functions import result_type
3-
from cubed.array_api.dtypes import _real_numeric_dtypes
3+
from cubed.array_api.dtypes import _promote_scalars, _real_numeric_dtypes
44
from cubed.array_api.manipulation_functions import reshape
55
from cubed.array_api.statistical_functions import max
66
from cubed.backend_array_api import namespace as nxp
@@ -88,13 +88,6 @@ def _searchsorted(x, y, side):
8888

8989

9090
def where(condition, x1, x2, /):
91-
x1_is_scalar = isinstance(x1, (int, float, complex, bool))
92-
x2_is_scalar = isinstance(x2, (int, float, complex, bool))
93-
if x1_is_scalar and x2_is_scalar:
94-
raise TypeError("At least one of x1 and x2 must be an array in where")
95-
elif x1_is_scalar:
96-
x1 = x2._promote_scalar(x1)
97-
elif x2_is_scalar:
98-
x2 = x1._promote_scalar(x2)
91+
x1, x2 = _promote_scalars(x1, x2, "where")
9992
dtype = result_type(x1, x2)
10093
return elemwise(nxp.where, condition, x1, x2, dtype=dtype)

cubed/tests/test_array_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ def test_add_different_chunks_fail(spec, executor):
194194
assert_array_equal(c.compute(executor=executor), np.ones((10,)) + np.ones((10,)))
195195

196196

197+
def test_add_scalars():
198+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
199+
200+
b = xp.add(a, 1)
201+
assert_array_equal(b.compute(), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]))
202+
203+
c = xp.add(2, a)
204+
assert_array_equal(c.compute(), np.array([[3, 4, 5], [6, 7, 8], [9, 10, 11]]))
205+
206+
with pytest.raises(TypeError):
207+
xp.add(1, 2)
208+
209+
197210
@pytest.mark.parametrize(
198211
"min, max",
199212
[

0 commit comments

Comments
 (0)