Skip to content

Commit 73d16d0

Browse files
committed
Update dtype strictness for array_api statistical functions
Original NumPy Commit: 315b0d0db60977be164a251f55a25b64497d3db9
1 parent c1f27b5 commit 73d16d0

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

array_api_strict/_statistical_functions.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
from ._dtypes import (
4-
_floating_dtypes,
4+
_real_floating_dtypes,
5+
_real_numeric_dtypes,
56
_numeric_dtypes,
67
)
78
from ._array_object import Array
8-
from ._creation_functions import asarray
99
from ._dtypes import float32, float64, complex64, complex128
1010

1111
from typing import TYPE_CHECKING, Optional, Tuple, Union
@@ -23,8 +23,8 @@ def max(
2323
axis: Optional[Union[int, Tuple[int, ...]]] = None,
2424
keepdims: bool = False,
2525
) -> Array:
26-
if x.dtype not in _numeric_dtypes:
27-
raise TypeError("Only numeric dtypes are allowed in max")
26+
if x.dtype not in _real_numeric_dtypes:
27+
raise TypeError("Only real numeric dtypes are allowed in max")
2828
return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
2929

3030

@@ -35,8 +35,8 @@ def mean(
3535
axis: Optional[Union[int, Tuple[int, ...]]] = None,
3636
keepdims: bool = False,
3737
) -> Array:
38-
if x.dtype not in _floating_dtypes:
39-
raise TypeError("Only floating-point dtypes are allowed in mean")
38+
if x.dtype not in _real_floating_dtypes:
39+
raise TypeError("Only real floating-point dtypes are allowed in mean")
4040
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
4141

4242

@@ -47,8 +47,8 @@ def min(
4747
axis: Optional[Union[int, Tuple[int, ...]]] = None,
4848
keepdims: bool = False,
4949
) -> Array:
50-
if x.dtype not in _numeric_dtypes:
51-
raise TypeError("Only numeric dtypes are allowed in min")
50+
if x.dtype not in _real_numeric_dtypes:
51+
raise TypeError("Only real numeric dtypes are allowed in min")
5252
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
5353

5454

@@ -82,8 +82,8 @@ def std(
8282
keepdims: bool = False,
8383
) -> Array:
8484
# Note: the keyword argument correction is different here
85-
if x.dtype not in _floating_dtypes:
86-
raise TypeError("Only floating-point dtypes are allowed in std")
85+
if x.dtype not in _real_floating_dtypes:
86+
raise TypeError("Only real floating-point dtypes are allowed in std")
8787
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
8888

8989

@@ -117,6 +117,6 @@ def var(
117117
keepdims: bool = False,
118118
) -> Array:
119119
# Note: the keyword argument correction is different here
120-
if x.dtype not in _floating_dtypes:
121-
raise TypeError("Only floating-point dtypes are allowed in var")
120+
if x.dtype not in _real_floating_dtypes:
121+
raise TypeError("Only real floating-point dtypes are allowed in var")
122122
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))

0 commit comments

Comments
 (0)