diff --git a/.github/workflows/gpu-build.yml b/.github/workflows/gpu-build.yml index 00fb6d8..afdd82b 100644 --- a/.github/workflows/gpu-build.yml +++ b/.github/workflows/gpu-build.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12 - name: Install Python dependencies run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 21ceee7..14c82dd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,10 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - jax-version: ["jax[cpu]"] - include: - - os: ubuntu-latest - jax-version: "'jax[cpu]==0.4.20' 'numpy<2.0'" + jax-version: ["jax"] steps: - uses: actions/checkout@v4 @@ -28,7 +25,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.9" + python-version: "3.12" - name: Install fftw on ubuntu if: ${{ matrix.os == 'ubuntu-latest' }} diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 36899a7..e89127a 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -35,7 +35,7 @@ jobs: - uses: actions/setup-python@v5 name: Install Python with: - python-version: "3.9" + python-version: "3.12" - name: Build sdist run: | python -m pip install -U pip diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd3d6ac..235156e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,20 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.6.0" + rev: "v5.0.0" hooks: - id: trailing-whitespace - id: end-of-file-fixer exclude_types: [json, binary] - repo: https://github.com/psf/black - rev: "24.8.0" + rev: "24.10.0" hooks: - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.6.8" + rev: "v0.6.9" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: "v19.1.0" + rev: "v19.1.1" hooks: - id: clang-format diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index 545f56e..127389e 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -3,6 +3,7 @@ from functools import partial, reduce import numpy as np +import jax from jax import core from jax import jit from jax import numpy as jnp @@ -123,7 +124,12 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): ) output_tangents += [s * output_tangent[:, :, n] for n, s in enumerate(scales)] - return output, reduce(ad.add_tangents, output_tangents, ad.Zero.from_value(output)) + if jax.version.__version_info__ < (0, 4, 34): + zero = ad.Zero.from_value(output) + else: + zero = ad.Zero.from_primal_value(output) + + return output, reduce(ad.add_tangents, output_tangents, zero) def transpose(doutput, source, *points, output_shape, eps, iflag, opts):