Skip to content

Commit 77db7a6

Browse files
author
jax authors
committed
Merge pull request #20637 from jakevdp:array-api-scalar
PiperOrigin-RevId: 623184036
2 parents f5cc272 + c19c1a7 commit 77db7a6

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

jax/experimental/array_api/_data_type_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
15+
import builtins
1616
import functools
1717
from typing import NamedTuple
1818
import jax
@@ -183,6 +183,8 @@ def isdtype(dtype, kind):
183183
def result_type(*arrays_and_dtypes):
184184
dtypes = []
185185
for val in arrays_and_dtypes:
186+
if isinstance(val, (builtins.bool, int, float, complex)):
187+
val = jax.numpy.array(val)
186188
if isinstance(val, jax.Array):
187189
val = val.dtype
188190
if _is_valid_dtype(val):

jax/experimental/array_api/_elementwise_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222

2323
def _promote_dtypes(name, *args):
2424
assert isinstance(name, str)
25-
if not all(isinstance(arg, jax.Array) for arg in args):
25+
if not all(isinstance(arg, (bool, int, float, complex, jax.Array))
26+
for arg in args):
2627
raise ValueError(f"{name}: inputs must be arrays; got types {[type(arg) for arg in args]}")
2728
dtype = _result_type(*args)
28-
return [arg.astype(dtype) for arg in args]
29+
return [jax.numpy.asarray(arg).astype(dtype) for arg in args]
2930

3031

3132
def abs(x, /):

0 commit comments

Comments
 (0)