Skip to content

Commit 0b6ddcd

Browse files
committed
Merge branch 'main' into fft
2 parents f228b58 + 74b7b79 commit 0b6ddcd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1375
-349
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
name: Array API Tests (Dask)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-dask:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: dask
10+
module-name: dask.array
11+
extra-requires: numpy
12+
pytest-extra-args: --disable-deadline --max-examples=5
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: Array API Tests (NumPy dev)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-numpy-dev:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: numpy
10+
extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
11+
xfails-file-extra: '-dev'

.github/workflows/array-api-tests.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ on:
66
package-name:
77
required: true
88
type: string
9+
module-name:
10+
required: false
11+
type: string
12+
extra-requires:
13+
required: false
14+
type: string
915
package-version:
1016
required: false
1117
type: string
@@ -24,7 +30,7 @@ on:
2430

2531

2632
env:
27-
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }}"
33+
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
2834

2935
jobs:
3036
tests:
@@ -49,17 +55,18 @@ jobs:
4955
with:
5056
python-version: ${{ matrix.python-version }}
5157
- name: Install dependencies
52-
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
53-
# to put this in the numpy 1.21 config file.
54-
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
58+
# NumPy 1.21 doesn't support Python 3.11. NumPy 2.0 doesn't support
59+
# Python 3.8. There doesn't seem to be a way to put this in the numpy
60+
# 1.21 config file.
61+
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
5562
run: |
5663
python -m pip install --upgrade pip
57-
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}'
64+
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
5865
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
5966
- name: Run the array API testsuite (${{ inputs.package-name }})
60-
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
67+
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
6168
env:
62-
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }}
69+
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
6370
# This enables the NEP 50 type promotion behavior (without it a lot of
6471
# tests fail on bad scalar type promotion behavior)
6572
NPY_PROMOTION_STATE: weak

.github/workflows/ruff.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: CI
2+
on: [push, pull_request]
3+
jobs:
4+
check-ruff:
5+
runs-on: ubuntu-latest
6+
continue-on-error: true
7+
steps:
8+
- uses: actions/checkout@v4
9+
- name: Install Python
10+
uses: actions/setup-python@v5
11+
with:
12+
python-version: "3.11"
13+
- name: Install dependencies
14+
run: |
15+
python -m pip install --upgrade pip
16+
pip install ruff
17+
# Update output format to enable automatic inline annotations.
18+
- name: Run Ruff
19+
run: ruff check --output-format=github .

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Install Dependencies
1616
run: |
1717
python -m pip install --upgrade pip
18-
python -m pip install pytest numpy torch
18+
python -m pip install pytest numpy torch dask[array] jax[cpu]
1919
2020
- name: Run Tests
2121
run: |

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# macOS specific iles
132+
.DS_Store

CHANGELOG.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
# 1.4.1 (2024-01-18)
2+
3+
## Minor Changes
4+
5+
- Add support for the upcoming NumPy 2.0 release.
6+
7+
- Added a torch wrapper for `trace` (`torch.trace` doesn't support the
8+
`offset` argument or stacking)
9+
10+
- Wrap numpy, cupy, and torch `nonzero` to raise an error for zero-dimensional
11+
input arrays.
12+
13+
- Add torch wrapper for `newaxis`.
14+
15+
- Improve error message for `array_namespace`
16+
17+
- Fix linalg.cholesky returning the conjugate of the expected upper
18+
decomposition for numpy and cupy.
19+
120
# 1.4 (2023-09-13)
221

322
## Major Changes

README.md

Lines changed: 107 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
This is a small wrapper around common array libraries that is compatible with
44
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5-
NumPy, CuPy, and PyTorch are supported. If you want support for other array
5+
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
66
libraries, or if you encounter any issues, please [open an
77
issue](https://github.com/data-apis/array-api-compat/issues).
88

@@ -56,28 +56,39 @@ import array_api_compat.cupy as cp
5656
import array_api_compat.torch as torch
5757
```
5858

59-
Each will include all the functions from the normal NumPy/CuPy/PyTorch
59+
```py
60+
import array_api_compat.dask as da
61+
```
62+
63+
> [!NOTE]
64+
> There is no `array_api_compat.jax` submodule. JAX support is contained
65+
> in JAX itself in the `jax.experimental.array_api` module. array-api-compat simply
66+
> wraps that submodule. The main JAX support in this module consists of
67+
> supporting it in the [helper functions](#helper-functions) defined below.
68+
69+
Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array
6070
namespace, except that functions that are part of the array API are wrapped so
6171
that they have the correct array API behavior. In each case, the array object
6272
used will be the same array object from the wrapped library.
6373

64-
## Difference between `array_api_compat` and `numpy.array_api`
74+
## Difference between `array_api_compat` and `array_api_strict`
6575

66-
`numpy.array_api` is a strict minimal implementation of the Array API (see
76+
`array_api_strict` is a strict minimal implementation of the array API standard, formerly
77+
known as `numpy.array_api` (see
6778
[NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html)). For
68-
example, `numpy.array_api` does not include any functions that are not part of
79+
example, `array_api_strict` does not include any functions that are not part of
6980
the array API specification, and will explicitly disallow behaviors that are
7081
not required by the spec (e.g., [cross-kind type
7182
promotions](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)).
72-
(`cupy.array_api` is similar to `numpy.array_api`)
83+
(`cupy.array_api` is similar to `array_api_strict`)
7384

7485
`array_api_compat`, on the other hand, is just an extension of the
7586
corresponding array library namespaces with changes needed to be compliant
7687
with the array API. It includes all additional library functions not mentioned
7788
in the spec, and allows any library behaviors not explicitly disallowed by it,
7889
such as cross-kind casting.
7990

80-
In particular, unlike `numpy.array_api`, this package does not use a separate
91+
In particular, unlike `array_api_strict`, this package does not use a separate
8192
`Array` object, but rather just uses the corresponding array library array
8293
objects (`numpy.ndarray`, `cupy.ndarray`, `torch.Tensor`, etc.) directly. This
8394
is because those are the objects that are going to be passed as inputs to
@@ -86,7 +97,7 @@ functions by end users. This does mean that a few behaviors cannot be wrapped
8697
most things.
8798

8899
Array consuming library authors coding against the array API may wish to test
89-
against `numpy.array_api` to ensure they are not using functionality outside
100+
against `array_api_strict` to ensure they are not using functionality outside
90101
of the standard, but prefer this implementation for the default behavior for
91102
end-users.
92103

@@ -99,6 +110,11 @@ part of the specification but which are useful for using the array API:
99110
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
100111
object.
101112

113+
- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`,
114+
`is_dask_array(x)`, `is_jax_array(x)`: return `True` if `x` is an array from
115+
the corresponding library. These functions do not import the underlying
116+
library if it has not already been imported, so they are cheap to use.
117+
102118
- `array_namespace(*xs)`: Get the corresponding array API namespace for the
103119
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
104120
namespace will be `array_api_compat.numpy`. Note that this function will
@@ -110,11 +126,11 @@ part of the specification but which are useful for using the array API:
110126
[`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html)
111127
in the array API specification. Included because `numpy.ndarray` does not
112128
include the `device` attribute and this library does not wrap or extend the
113-
array object. Note that for NumPy, `device(x)` is always `"cpu"`.
129+
array object. Note that for NumPy and dask, `device(x)` is always `"cpu"`.
114130

115131
- `to_device(x, device, /, *, stream=None)`: Equivalent to
116132
[`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html).
117-
Included because neither NumPy's, CuPy's, nor PyTorch's array objects
133+
Included because neither NumPy's, CuPy's, Dask's, nor PyTorch's array objects
118134
include this method. For NumPy, this function effectively does nothing since
119135
the only supported device is the CPU, but for CuPy, this method supports
120136
CuPy CUDA
@@ -219,6 +235,36 @@ version.
219235

220236
The minimum supported PyTorch version is 1.13.
221237

238+
### JAX
239+
240+
Unlike the other libraries supported here, JAX array API support is contained
241+
entirely in the JAX library. The JAX array API support is tracked at
242+
https://github.com/google/jax/issues/18353.
243+
244+
## Dask
245+
246+
If you're using dask with numpy, many of the same limitations that apply to numpy
247+
will also apply to dask. Besides those differences, other limitations include missing
248+
sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg`
249+
and `fft` extensions.
250+
251+
In particular, the `fft` namespace is not compliant with the array API spec. Any functions
252+
that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk!
253+
254+
For `linalg`, several methods are missing, for example:
255+
- `cross`
256+
- `det`
257+
- `eigh`
258+
- `eigvalsh`
259+
- `matrix_power`
260+
- `pinv`
261+
- `slogdet`
262+
- `matrix_norm`
263+
- `matrix_rank`
264+
Other methods may only be partially implemented or return incorrect results at times.
265+
266+
The minimum supported Dask version is 2023.12.0.
267+
222268
## Vendoring
223269

224270
This library supports vendoring as an installation method. To vendor the
@@ -300,3 +346,54 @@ corresponding document does not yet exist for PyTorch, but you can examine the
300346
various comments in the
301347
[implementation](https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/torch/_aliases.py)
302348
to see what functions and behaviors have been wrapped.
349+
350+
351+
## Releasing
352+
353+
To release, first note that CuPy must be tested manually (it isn't tested on
354+
CI). Use the script
355+
356+
```
357+
./test_cupy.sh
358+
```
359+
360+
on a machine with a CUDA GPU.
361+
362+
Once you are ready to release, create a PR with a release branch, so that you
363+
can verify that CI is passing. You must edit
364+
365+
```
366+
array_api_compat/__init__.py
367+
```
368+
369+
and update the version (the version is not computed from the tag because that
370+
would break vendorability). You should also edit
371+
372+
```
373+
CHANGELOG.md
374+
```
375+
376+
with the changes for the release.
377+
378+
Then create a tag
379+
380+
```
381+
git tag -a <version>
382+
```
383+
384+
and push it to GitHub
385+
386+
```
387+
git push origin <version>
388+
```
389+
390+
Check that the `publish distributions` action works. Note that this action
391+
will run even if the other CI fails, so you must make sure that CI is passing
392+
*before* tagging.
393+
394+
This does mean you can ignore CI failures, but ideally you should fix any
395+
failures or update the `*-xfails.txt` files before tagging, so that CI and the
396+
cupy tests pass. Otherwise it will be hard to tell what things are breaking in
397+
the future. It's also a good idea to remove any xpasses from those files (but
398+
be aware that some xfails are from flaky failures, so unless you know the
399+
underlying issue has been fixed, a xpass test is probably still xfail).

array_api_compat/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
66
https://numpy.org/neps/nep-0047-array-api-standard.html.
77
8-
Unlike numpy.array_api, this is not a strict minimal implementation of the
8+
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with
1010
changes needed to be compliant with the Array API. See
1111
https://numpy.org/doc/stable/reference/array_api.html for a full list of
12-
changes. In particular, unlike numpy.array_api, this package does not use a
12+
changes. In particular, unlike array_api_strict, this package does not use a
1313
separate Array object, but rather just uses numpy.ndarray directly.
1414
15-
Library authors using the Array API may wish to test against numpy.array_api
15+
Library authors using the Array API may wish to test against array_api_strict
1616
to ensure they are not using functionality outside of the standard, but prefer
1717
this implementation for the default when working with NumPy arrays.
1818
1919
"""
20-
__version__ = '1.4'
20+
__version__ = '1.4.1'
2121

22-
from .common import *
22+
from .common import * # noqa: F401, F403

array_api_compat/_internal.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@ def func(x, /, xp, kwarg=None):
2121
arguments.
2222
2323
"""
24+
2425
def inner(f):
2526
@wraps(f)
2627
def wrapped_f(*args, **kwargs):
2728
return f(*args, xp=xp, **kwargs)
2829

2930
sig = signature(f)
30-
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
31+
new_sig = sig.replace(
32+
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
33+
)
3134

3235
if wrapped_f.__doc__ is None:
3336
wrapped_f.__doc__ = f"""\

0 commit comments

Comments
 (0)