Skip to content

ENH: add quantile #341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
nunique
one_hot
pad
quantile
setdiff1d
sinc
```
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, one_hot, pad
from ._delegation import isclose, one_hot, pad, quantile
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand Down Expand Up @@ -36,6 +36,7 @@
"nunique",
"one_hot",
"pad",
"quantile",
"setdiff1d",
"sinc",
]
72 changes: 71 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "one_hot", "pad"]
__all__ = ["isclose", "one_hot", "pad", "quantile"]


def isclose(
Expand Down Expand Up @@ -247,3 +247,73 @@ def pad(
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def quantile(
    x: Array,
    q: Array | float,
    /,
    *,
    axis: int | None = None,
    keepdims: bool = False,
    method: str = "linear",
    xp: ModuleType | None = None,
) -> Array:
    """
    Compute the q-th quantile(s) of the data along the specified axis.

    Parameters
    ----------
    x : array of real numbers
        Data array.
    q : float or array of float
        Probability or sequence of probabilities of the quantiles to compute.
        Values must be between 0 and 1 (inclusive). Must have length 1 along
        `axis` unless ``keepdims=True``.
    axis : int or None, default: None
        Axis along which the quantiles are computed. ``None`` ravels both `x`
        and `q` before performing the calculation.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result will
        broadcast correctly against the original array `x`.
    method : str, default: 'linear'
        The method to use for estimating the quantile. The available options are:
        'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
        'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
        'median_unbiased', 'normal_unbiased', 'harrell-davis'.
    xp : array_namespace, optional
        The standard-compatible namespace for `x` and `q`. Default: infer.

    Returns
    -------
    array
        An array with the quantiles of the data.

    Examples
    --------
    >>> import array_api_strict as xp
    >>> import array_api_extra as xpx
    >>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]])
    >>> xpx.quantile(x, 0.5, axis=-1)
    Array([7., 2.], dtype=array_api_strict.float64)
    >>> xpx.quantile(x, [0.25, 0.75], axis=-1)
    Array([[5., 8.],
           [1., 3.]], dtype=array_api_strict.float64)
    """
    xp = array_namespace(x, q) if xp is None else xp

    # NOTE(review, Member): scikit-learn/scikit-learn#31671 (comment) suggests
    # that delegation to some existing array libraries may be desirable here.

    # Probe for a SciPy that ships an array-API-aware `scipy.stats.quantile`
    # (1.16+). Only the import / version probe lives inside the `try`: an
    # ImportError or AttributeError raised by the SciPy computation itself
    # must reach the caller instead of silently triggering the fallback.
    scipy_quantile = None
    try:
        import scipy
        from packaging import version

        if version.parse(scipy.__version__) >= version.parse("1.16"):
            from scipy.stats import quantile as scipy_quantile
    except (ImportError, AttributeError):
        # SciPy or `packaging` missing, or SciPy exposes no `__version__`:
        # use the array-agnostic implementation below.
        scipy_quantile = None

    if scipy_quantile is not None:
        return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method)

    return _funcs.quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)
179 changes: 179 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"kron",
"nunique",
"pad",
"quantile",
"setdiff1d",
"sinc",
]
Expand Down Expand Up @@ -988,3 +989,181 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def quantile(
    x: Array,
    q: Array | float,
    /,
    *,
    axis: int | None = None,
    keepdims: bool = False,
    method: str = "linear",
    xp: ModuleType | None = None,
) -> Array:
    """See docstring in `array_api_extra._delegation.py`.

    Array-agnostic implementation: validates the inputs, sorts `x` along
    `axis`, moves that axis last and hands the sorted data to the
    method-specific helper. Entries whose probability is NaN or lies outside
    ``[0, 1]`` are returned as NaN.

    Raises
    ------
    ValueError
        If `x` is not of real dtype, `q` is not real floating, `axis` is not
        an in-range integer (or None), `method` is unknown, or `keepdims` is
        not one of None/True/False.
    """
    if xp is None:
        xp = array_namespace(x, q)

    # A Python scalar carries no dtype/device; turn it into a 0-d array on
    # the same device as `x`.
    if isinstance(q, int | float):
        q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))

    # Validate inputs
    if not xp.isdtype(x.dtype, ("integral", "real floating")):
        raise ValueError("`x` must have real dtype.")  # noqa: EM101
    if not xp.isdtype(q.dtype, "real floating"):
        raise ValueError("`q` must have real floating dtype.")  # noqa: EM101

    # All arithmetic below is carried out in float64.
    x = xp.astype(x, xp.float64)
    q = xp.astype(q, xp.float64)
    q = xp.asarray(q, device=_compat.device(x))

    dtype = x.dtype
    axis_none = axis is None
    ndim = max(x.ndim, q.ndim)

    if axis_none:
        x = xp.reshape(x, (-1,))
        q = xp.reshape(q, (-1,))
        axis = 0
    elif not isinstance(axis, int):
        raise ValueError("`axis` must be an integer or None.")  # noqa: EM101
    elif axis >= ndim or axis < -ndim:
        raise ValueError("`axis` is not compatible with the shapes of the inputs.")  # noqa: EM101
    else:
        axis = int(axis)
        # Left-pad both operands with length-1 axes up to the common rank so
        # that `axis` designates the same dimension in `x` and in `q`. This
        # also turns a scalar `q` into an all-ones-shaped array, which the
        # index gather in the helpers needs (indices must match the data's
        # rank).
        x = xp.reshape(x, (1,) * (ndim - x.ndim) + tuple(x.shape))
        q = xp.reshape(q, (1,) * (ndim - q.ndim) + tuple(q.shape))

    # Validate method
    methods = {
        "inverted_cdf",
        "averaged_inverted_cdf",
        "closest_observation",
        "hazen",
        "interpolated_inverted_cdf",
        "linear",
        "median_unbiased",
        "normal_unbiased",
        "weibull",
        "harrell-davis",
    }
    if method not in methods:
        raise ValueError(f"`method` must be one of {methods}")  # noqa: EM102

    # Handle keepdims parameter
    if keepdims not in {None, True, False}:
        raise ValueError("If specified, `keepdims` must be True or False.")  # noqa: EM101

    # An empty reduction axis behaves like a single NaN observation, so the
    # result is NaN without special-casing the rest of the computation.
    if x.shape[axis] == 0:
        shape = list(x.shape)
        shape[axis] = 1
        x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x))

    # Sort along the reduction axis, then move that axis last in both the
    # data and the probabilities.
    y = xp.moveaxis(xp.sort(x, axis=axis), axis, -1)
    q = xp.moveaxis(q, axis, -1)

    # Number of observations along the reduction axis, as a 0-d float array.
    n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))

    # `method` was validated above, so everything except Harrell-Davis is
    # one of the Hyndman-Fan estimators.
    if method == "harrell-davis":
        res = _quantile_hd(y, q, n, xp)
    else:
        res = _quantile_hf(y, q, n, method, xp)

    # NaN out results for invalid probabilities. Done unconditionally with
    # `where` rather than behind `if xp.any(...)`, which would force a
    # Python bool out of a possibly lazy array.
    invalid = (q > 1) | (q < 0) | xp.isnan(q)
    nan_fill = xp.asarray(xp.nan, dtype=res.dtype, device=_compat.device(res))
    res = xp.where(invalid, nan_fill, res)

    # Reshape per axis/keepdims: with `axis=None` the inputs were ravelled,
    # so restore the original rank with leading length-1 axes.
    if axis_none and keepdims:
        shape = (1,) * (ndim - 1) + tuple(res.shape)
        res = xp.reshape(res, shape)
        axis = -1

    # Move the quantile axis back to where the reduction axis was.
    res = xp.moveaxis(res, -1, axis)

    # Drop the reduced axis when it has length one and was not asked for.
    if not keepdims and res.shape[axis] == 1:
        res = xp.squeeze(res, axis=axis)

    return res


def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array:
"""Helper function for Hyndman-Fan quantile methods."""
ms = {
"inverted_cdf": 0,
"averaged_inverted_cdf": 0,
"closest_observation": -0.5,
"interpolated_inverted_cdf": 0,
"hazen": 0.5,
"weibull": p,
"linear": 1 - p,
"median_unbiased": p / 3 + 1 / 3,
"normal_unbiased": p / 4 + 3 / 8,
}
m = ms[method]

jg = p * n + m - 1
j = xp.astype(jg // 1, xp.int64) # Convert to integer
g = jg % 1

if method == "inverted_cdf":
g = xp.astype((g > 0), jg.dtype)
elif method == "averaged_inverted_cdf":
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
elif method == "closest_observation":
g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)
if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}:
g = xp.asarray(g)
g = at(g, jg < 0).set(0)
g = at(g, j < 0).set(0)
j = xp.clip(j, 0, n - 1)
jp1 = xp.clip(j + 1, 0, n - 1)

# Broadcast indices to match y shape except for the last axis
if y.ndim > 1:
# Create broadcast shape for indices
broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005
j = xp.broadcast_to(j, broadcast_shape)
jp1 = xp.broadcast_to(jp1, broadcast_shape)
g = xp.broadcast_to(g, broadcast_shape)

return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
y, jp1, axis=-1
)


def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array:
    """Helper function for Harrell-Davis quantile method.

    The Harrell-Davis estimator is not implemented here because `betainc`
    is not available in the array API standard. The call falls back to the
    ``"linear"`` Hyndman-Fan estimator and emits a `RuntimeWarning`, so
    callers are not handed a different estimator without notice.

    Parameters
    ----------
    y, p, n, xp
        Forwarded unchanged to `_quantile_hf`.

    Returns
    -------
    array
        The result of ``_quantile_hf(y, p, n, "linear", xp)``.
    """
    # Function-scope import keeps this helper independent of the module's
    # import block.
    import warnings

    warnings.warn(
        "The Harrell-Davis estimator is not implemented in the array-agnostic "
        "fallback; using `method='linear'` instead.",
        RuntimeWarning,
        stacklevel=2,
    )
    return _quantile_hf(y, p, n, "linear", xp)
69 changes: 69 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
nunique,
one_hot,
pad,
quantile,
setdiff1d,
sinc,
)
Expand All @@ -43,6 +44,7 @@
lazy_xp_function(nunique)
lazy_xp_function(one_hot)
lazy_xp_function(pad)
lazy_xp_function(quantile)
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False)
lazy_xp_function(sinc)
Expand Down Expand Up @@ -1162,3 +1164,70 @@ def test_device(self, xp: ModuleType, device: Device):

def test_xp(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestQuantile:
    """Behavioural checks for `quantile` across array backends."""

    def test_basic(self, xp: ModuleType):
        # Median of five consecutive integers is the middle one.
        data = xp.asarray([1, 2, 3, 4, 5])
        xp_assert_close(quantile(data, 0.5), xp.asarray(3.0))

    def test_multiple_quantiles(self, xp: ModuleType):
        # Several probabilities supplied at once as an array.
        data = xp.asarray([1, 2, 3, 4, 5])
        probs = xp.asarray([0.25, 0.5, 0.75])
        xp_assert_close(quantile(data, probs), xp.asarray([2.0, 3.0, 4.0]))

    def test_2d_axis(self, xp: ModuleType):
        # Reducing along axis 0 yields one value per column.
        grid = xp.asarray([[1, 2, 3], [4, 5, 6]])
        got = quantile(grid, 0.5, axis=0)
        xp_assert_close(got, xp.asarray([2.5, 3.5, 4.5]))

    def test_2d_axis_keepdims(self, xp: ModuleType):
        # Same reduction, but the reduced axis is kept with length one.
        grid = xp.asarray([[1, 2, 3], [4, 5, 6]])
        got = quantile(grid, 0.5, axis=0, keepdims=True)
        xp_assert_close(got, xp.asarray([[2.5, 3.5, 4.5]]))

    def test_methods(self, xp: ModuleType):
        data = xp.asarray([1, 2, 3, 4, 5])
        for name in ("linear", "hazen", "weibull"):
            got = quantile(data, 0.5, method=name)
            # All methods should give reasonable results
            assert 2.5 <= float(got) <= 3.5

    def test_edge_cases(self, xp: ModuleType):
        data = xp.asarray([1, 2, 3, 4, 5])
        # q = 0 should give the minimum, q = 1 the maximum.
        for prob, extreme in ((0.0, 1.0), (1.0, 5.0)):
            got = quantile(data, prob)
            xp_assert_close(got, xp.asarray(extreme))

    def test_invalid_q(self, xp: ModuleType):
        data = xp.asarray([1, 2, 3, 4, 5])
        # Probabilities above 1 or below 0 should return NaN.
        for prob in (1.5, -0.5):
            got = quantile(data, prob)
            assert xp.isnan(got)

    def test_device(self, xp: ModuleType, device: Device):
        # The result lives on the same device as the input data.
        data = xp.asarray([1, 2, 3, 4, 5], device=device)
        got = quantile(data, 0.5)
        assert get_device(got) == device

    def test_xp(self, xp: ModuleType):
        # Passing the namespace explicitly gives the same answer.
        data = xp.asarray([1, 2, 3, 4, 5])
        got = quantile(data, 0.5, xp=xp)
        xp_assert_close(got, xp.asarray(3.0))
Loading