
ENH: add quantile #341

Draft
wants to merge 22 commits into
base: main
Changes from 12 commits
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ exclude: ^.cruft.json|.copier-answers.yml$

repos:
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.18.0"
rev: "1.19.1"
hooks:
- id: blacken-docs
additional_dependencies: [black==24.*]
Expand Down Expand Up @@ -35,21 +35,21 @@ repos:
- id: rst-inline-touching-normal

- repo: https://github.com/rbubley/mirrors-prettier
rev: "v3.4.2"
rev: "v3.6.2"
hooks:
- id: prettier
types_or: [yaml, markdown, html, css, scss, javascript, json]
args: [--prose-wrap=always]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.8.2"
rev: "v0.12.1"
hooks:
- id: ruff-format
- id: ruff
args: ["--fix", "--show-fixes"]

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
rev: "v2.4.1"
hooks:
- id: codespell
exclude: pixi.lock
Expand All @@ -68,17 +68,17 @@ repos:
exclude: .pre-commit-config.yaml

- repo: https://github.com/abravalheri/validate-pyproject
rev: "v0.23"
rev: "v0.24.1"
hooks:
- id: validate-pyproject
additional_dependencies: ["validate-pyproject-schema-store[all]"]

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: "0.30.0"
rev: "0.33.1"
hooks:
- id: check-github-workflows

- repo: https://github.com/numpy/numpydoc
rev: "v1.8.0"
rev: "v1.9.0"
hooks:
- id: numpydoc-validation
1 change: 1 addition & 0 deletions docs/api-reference.md
Expand Up @@ -19,6 +19,7 @@
nunique
one_hot
pad
quantile
setdiff1d
sinc
```
248 changes: 179 additions & 69 deletions pixi.lock

Large diffs are not rendered by default.

49 changes: 25 additions & 24 deletions pyproject.toml
Expand Up @@ -65,17 +65,17 @@ array-api-strict = ">=2.3.1"
numpy = ">=2.1.3"
pytest = ">=8.4.0"
hypothesis = ">=6.131.28"
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
# NOTE: don't add cupy, jax, pytorch, or sparse here,
# as they slow down mypy and are not portable across target OSs

[tool.pixi.feature.lint.tasks]
pre-commit-install = { cmd = "pre-commit install", description = "Install pre-commit"}
pre-commit = { cmd = "pre-commit run --all-files", description = "Run pre-commit"}
mypy = { cmd = "mypy", description="Type check with mypy"}
pylint = { cmd = "pylint array_api_extra", cwd = "src" , description = "Lint using pylint"}
pyright = { cmd = "basedpyright", description = "Type check with basedpyright"}
lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] , description = "Run pre-commit, pylint, mypy, and pyright"}
pre-commit-install = { cmd = "pre-commit install", description = "Install pre-commit" }
pre-commit = { cmd = "pre-commit run --all-files", description = "Run pre-commit" }
mypy = { cmd = "mypy", description = "Type check with mypy" }
pylint = { cmd = "pylint array_api_extra", cwd = "src", description = "Lint using pylint" }
pyright = { cmd = "basedpyright", description = "Type check with basedpyright" }
lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"], description = "Run pre-commit, pylint, mypy, and pyright" }

[tool.pixi.feature.tests.dependencies]
pytest = ">=8.4.0"
Expand All @@ -85,18 +85,18 @@ array-api-strict = ">=2.3.1"
numpy = ">=1.22.0"

[tool.pixi.feature.tests.tasks]
tests = { cmd = "pytest -v", description = "Run tests"}
tests-cov = { cmd="pytest -v -ra --cov --cov-report=xml --cov-report=term --durations=20", description = "Run tests with coverage"}
tests = { cmd = "pytest -v", description = "Run tests" }
tests-cov = { cmd = "pytest -v -ra --cov --cov-report=xml --cov-report=term --durations=20", description = "Run tests with coverage" }

clean-vendor-compat = { cmd = "rm -rf vendor_tests/array_api_compat", description = "Delete the existing vendored version of array-api-compat"}
clean-vendor-extra = { cmd = "rm -rf vendor_tests/array_api_extra" , description = "Delete the existing vendored version of array-api-extra"}
copy-vendor-compat = { cmd = "cp -r $(python -c 'import site; print(site.getsitepackages()[0])')/array_api_compat vendor_tests/", depends-on = ["clean-vendor-compat"] , description = "Vendor a clean copy of array-api-compat"}
copy-vendor-extra = { cmd = "cp -r src/array_api_extra vendor_tests/", depends-on = ["clean-vendor-extra"] , description = "Vendor a clean copy of array-api-extra"}
tests-vendor = { cmd = "pytest -v vendor_tests", depends-on = ["copy-vendor-compat", "copy-vendor-extra"] , description = "Check that array-api-extra and array-api-compat can be vendored together" }
clean-vendor-compat = { cmd = "rm -rf vendor_tests/array_api_compat", description = "Delete the existing vendored version of array-api-compat" }
clean-vendor-extra = { cmd = "rm -rf vendor_tests/array_api_extra", description = "Delete the existing vendored version of array-api-extra" }
copy-vendor-compat = { cmd = "cp -r $(python -c 'import site; print(site.getsitepackages()[0])')/array_api_compat vendor_tests/", depends-on = ["clean-vendor-compat"], description = "Vendor a clean copy of array-api-compat" }
copy-vendor-extra = { cmd = "cp -r src/array_api_extra vendor_tests/", depends-on = ["clean-vendor-extra"], description = "Vendor a clean copy of array-api-extra" }
tests-vendor = { cmd = "pytest -v vendor_tests", depends-on = ["copy-vendor-compat", "copy-vendor-extra"], description = "Check that array-api-extra and array-api-compat can be vendored together" }

tests-ci = { depends-on = ["tests-cov", "tests-vendor"] , description = "Run tests with coverage and vendor tests"}
coverage = { cmd = "coverage html", depends-on = ["tests-cov"], description = "Generate test coverage html report"}
open-coverage = { cmd = "open htmlcov/index.html", depends-on = ["coverage"] , description = "Open test coverage report"}
tests-ci = { depends-on = ["tests-cov", "tests-vendor"], description = "Run tests with coverage and vendor tests" }
coverage = { cmd = "coverage html", depends-on = ["tests-cov"], description = "Generate test coverage html report" }
open-coverage = { cmd = "open htmlcov/index.html", depends-on = ["coverage"], description = "Open test coverage report" }

[tool.pixi.feature.docs.dependencies]
sphinx = ">=7.4.7"
Expand All @@ -105,20 +105,20 @@ myst-parser = ">=4.0.1"
sphinx-copybutton = ">=0.5.2"
sphinx-autodoc-typehints = ">=1.25.3"
# Needed to import parsed modules with autodoc
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
pytest = ">=8.4.0"
typing-extensions = ">=4.14.0"
numpy = ">=2.1.3"

[tool.pixi.feature.docs.tasks]
docs = { cmd = "sphinx-build -E -W . build/", cwd = "docs" , description = "Build docs"}
open-docs = { cmd = "open build/index.html", cwd = "docs", depends-on = ["docs"] , description = "Open the generated docs"}
docs = { cmd = "sphinx-build -E -W . build/", cwd = "docs", description = "Build docs" }
open-docs = { cmd = "open build/index.html", cwd = "docs", depends-on = ["docs"], description = "Open the generated docs" }

[tool.pixi.feature.dev.dependencies]
ipython = ">=7.33.0"

[tool.pixi.feature.dev.tasks]
ipython = { cmd = "ipython" , description = "Launch ipython"}
ipython = { cmd = "ipython", description = "Launch ipython" }

[tool.pixi.feature.py310.dependencies]
python = "~=3.10.0"
Expand All @@ -135,7 +135,7 @@ numpy = "=1.22.0"
# Note: JAX and PyTorch will install CPU variants.
[tool.pixi.feature.backends.dependencies]
pytorch = ">=2.7.0"
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
sparse = ">=0.17.0"

[tool.pixi.feature.backends.target.linux-64.dependencies]
Expand Down Expand Up @@ -184,7 +184,7 @@ python-freethreading = "~=3.13.0"
pytest-run-parallel = ">=0.4.4"
numpy = ">=2.3.0"
# pytorch = "*" # Not available on Python 3.13t yet
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
# sparse = "*" # numba not available on Python 3.13t yet
# jax = "*" # ml_dtypes not available on Python 3.13t yet

Expand Down Expand Up @@ -245,7 +245,7 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["tests/*"]
disable_error_code = ["no-untyped-def"] # test(...) without -> None
disable_error_code = ["no-untyped-def"] # test(...) without -> None

# pyright

Expand Down Expand Up @@ -322,6 +322,7 @@ ignore = [
"N801", # Class name should use CapWords convention
"N802", # Function name should be lowercase
"N806", # Variable in function should be lowercase
"PLC0415", # `import` should be at the top-level of a file
]


Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, one_hot, pad
from ._delegation import isclose, one_hot, pad, quantile
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand Down Expand Up @@ -36,6 +36,7 @@
"nunique",
"one_hot",
"pad",
"quantile",
"setdiff1d",
"sinc",
]
95 changes: 94 additions & 1 deletion src/array_api_extra/_delegation.py
Expand Up @@ -5,6 +5,7 @@
from typing import Literal

from ._lib import _funcs
from ._lib._quantile import quantile as _quantile
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
Expand All @@ -18,7 +19,7 @@
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "one_hot", "pad"]
__all__ = ["isclose", "one_hot", "pad", "quantile"]


def isclose(
Expand Down Expand Up @@ -247,3 +248,95 @@ def pad(
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def quantile(
x: Array,
q: Array | float,
/,
*,
axis: int | None = None,
keepdims: bool | None = None,
method: str = "linear",
xp: ModuleType | None = None,
) -> Array:
"""
Compute the q-th quantile(s) of the data along the specified axis.

Parameters
----------
x : array of real numbers
Data array.
q : array of float
Probability or sequence of probabilities of the quantiles to compute.
Values must be between 0 and 1 (inclusive). Must have length 1 along
`axis` unless ``keepdims=True``.
axis : int or None, default: None
Axis along which the quantiles are computed. ``None`` ravels both `x`
and `q` before performing the calculation.
keepdims : bool or None, default: None
By default, the axis will be reduced away if possible
(i.e. if there is exactly one element of `q` per axis-slice of `x`).
If `keepdims` is set to True, the axes which are reduced are left in the
result as dimensions with size one. With this option, the result will
broadcast correctly against the original array `x`.
If `keepdims` is set to False, the axis will be reduced away if possible,
and an error will be raised otherwise.
method : str, default: 'linear'
The method to use for estimating the quantile. The available options are:
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
'median_unbiased', 'normal_unbiased'.
xp : array_namespace, optional
The standard-compatible namespace for `x` and `q`. Default: infer.

Returns
-------
array
An array with the quantiles of the data.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]])
>>> xpx.quantile(x, 0.5, axis=-1)
Array([7., 2.], dtype=array_api_strict.float64)
>>> xpx.quantile(x, [0.25, 0.75], axis=-1)
Array([[5., 8.],
[1., 3.]], dtype=array_api_strict.float64)
"""
# scipy.stats.quantile accepts more methods than are supported here,
# so `method` must be validated before delegating.
methods = {
"inverted_cdf",
"averaged_inverted_cdf",
"closest_observation",
"hazen",
"interpolated_inverted_cdf",
"linear",
"median_unbiased",
"normal_unbiased",
"weibull",
}
if method not in methods:
raise ValueError(f"`method` must be one of {methods}") # noqa: EM102

xp = array_namespace(x, q) if xp is None else xp

Review comment from a project member:

scikit-learn/scikit-learn#31671 (comment) suggests that delegation to some existing array libraries may be desirable here
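
A minimal sketch of what such delegation could look like, offered as an illustration and not as the PR's implementation. NumPy inputs are handed to `numpy.quantile`; everything else goes through a small generic fallback. The names `quantile_with_delegation` and `_linear_quantile` are hypothetical, the fallback handles only 1-D sequences with the 'linear' method, and NumPy is treated as an optional dependency.

```python
from __future__ import annotations

import math
from typing import Any


def _linear_quantile(values: list[float], q: float) -> float:
    """Hypothetical generic fallback: 'linear' method on a non-empty 1-D list."""
    data = sorted(values)
    # Fractional position of the quantile in the sorted data.
    pos = q * (len(data) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(data) - 1)
    # Interpolate linearly between the two neighbouring order statistics.
    return data[lo] + (pos - lo) * (data[hi] - data[lo])


def quantile_with_delegation(x: Any, q: float) -> float:
    """Delegate to NumPy for NumPy arrays, otherwise use the fallback."""
    try:
        import numpy as np  # optional dependency (assumption of this sketch)
    except ImportError:
        np = None

    if np is not None and isinstance(x, np.ndarray):
        # The `method` keyword requires NumPy >= 1.22.
        return float(np.quantile(x, q, method="linear"))
    return _linear_quantile(list(x), q)


median = quantile_with_delegation([10, 8, 7, 5, 4], 0.5)
lower_quartile = quantile_with_delegation([0, 1, 2, 3, 5], 0.25)
```

The inputs are the two rows from the docstring example, so `median` and `lower_quartile` can be compared against the values shown there. A real delegation layer would dispatch on the array namespace and would also have to reconcile `axis` and `keepdims` semantics between backends, which this sketch leaves out.
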

try:
import scipy # type: ignore[import-untyped]
from packaging import version

# The quantile function in scipy 1.16 supports array API directly, no need
# to delegate
if version.parse(scipy.__version__) >= version.parse("1.16"): # pyright: ignore[reportUnknownArgumentType]
from scipy.stats import ( # type: ignore[import-untyped]
quantile as scipy_quantile,
)

return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method)
except (ImportError, AttributeError):
pass

return _quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)
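
The version gate above imports `packaging` inside the `try` block. A sketch of an alternative that uses only the standard library follows, as one possible design and not what the PR does. `version_at_least` and `scipy_has_array_api_quantile` are hypothetical helpers. The comparison looks only at the major and minor numbers, so pre-release suffixes such as `rc1` are ignored, which differs from `packaging.version.parse`.

```python
import re
from importlib import metadata


def version_at_least(installed: str, minimum: tuple[int, int]) -> bool:
    """Compare the leading 'major.minor' of a version string numerically."""
    match = re.match(r"(\d+)\.(\d+)", installed)
    if match is None:
        # Unparseable version string: treat as too old.
        return False
    return (int(match.group(1)), int(match.group(2))) >= minimum


def scipy_has_array_api_quantile() -> bool:
    """Return True if an installed SciPy reports version 1.16 or later."""
    try:
        installed = metadata.version("scipy")
    except metadata.PackageNotFoundError:
        # SciPy is not installed, so the caller should use the fallback.
        return False
    return version_at_least(installed, (1, 16))
```

Comparing integer tuples avoids the string-comparison pitfall where `"1.9"` sorts after `"1.16"`.
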
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
for axis in range(-ndim, 0):
sizes = {shape[axis] for shape in shapes if axis >= -len(shape)}
# Dask uses NaN for unknown shape, which predates the Array API spec for None
none_size = None in sizes or math.nan in sizes
none_size = None in sizes or math.nan in sizes # noqa: PLW0177
sizes -= {1, None, math.nan}
if len(sizes) > 1:
msg = (
Expand Down
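
The `noqa` on the changed line silences ruff's PLW0177 (NaN comparison) check. A short sketch of why the membership test can still work, assuming the unknown size is the `math.nan` object itself: set lookup checks identity before equality, so `math.nan` is found even though `nan != nan`. A NaN created elsewhere is a different object and would not match.

```python
import math

sizes = {3, math.nan}

# NaN never compares equal to itself...
nan_is_equal = math.nan == math.nan

# ...but set membership tries identity first, so the same object is found.
found_same_object = math.nan in sizes

# A separately created NaN is a different object and is not found.
other_nan = float("nan")
found_other_object = other_nan in sizes

# Set difference uses the same lookup, so math.nan is removed here too.
remaining = sizes - {1, None, math.nan}
```
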