Skip to content

finufft: update to 2.4 #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .devcontainer/Dockerfile
9 changes: 9 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"name": "CUDA",
"runArgs": [
"--gpus=all"
],
"build": {
"dockerfile": "Dockerfile"
}
}
3 changes: 0 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ if(JAX_FINUFFT_USE_CUDA)

message(STATUS "jax_finufft: CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

# Propagate to finufft, because it doesn't look at CMAKE_CUDA_ARCHITECTURES by default
set(FINUFFT_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})

# This needs to be run after the CMAKE_CUDA_ARCHITECTURES check, otherwise
# it will set it to the compiler default
enable_language(CUDA)
Expand Down
4 changes: 2 additions & 2 deletions ci/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04
FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu24.04

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y \
Expand All @@ -8,4 +8,4 @@ RUN apt-get update && \
curl \
libfftw3-dev

COPY --from=ghcr.io/astral-sh/uv:0.6.17 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.7.11 /uv /uvx /bin/
4 changes: 2 additions & 2 deletions ci/Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pipeline {
}
steps {
sh '''
uv run --extra test pytest -n 8 tests/
uv run --extra test pytest -n 8
'''
}
}
Expand All @@ -34,7 +34,7 @@ pipeline {
steps {
// TODO: add "-n 8", but GPU kernels don't seem to be thread-safe
sh '''
uv run --extra test --extra cuda12 pytest tests/
uv run --extra test --extra cuda12-local pytest
'''
}
}
Expand Down
9 changes: 4 additions & 5 deletions lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,13 @@ NB_MODULE(jax_finufft_cpu, m) {

nb::class_<finufft_opts> opts(m, "FinufftOpts");
opts.def("__init__",
[](finufft_opts *self, bool modeord, bool chkbnds, int debug, int spread_debug,
bool showwarn, int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth,
bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize,
int spread_nthr_atomic, int spread_max_sp_size) {
[](finufft_opts *self, bool modeord, int debug, int spread_debug, bool showwarn,
int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth, bool spread_kerpad,
double upsampfac, int spread_thread, int maxbatchsize, int spread_nthr_atomic,
int spread_max_sp_size) {
new (self) finufft_opts;
default_opts<double>(self);
self->modeord = int(modeord);
self->chkbnds = int(chkbnds);
self->debug = debug;
self->spread_debug = spread_debug;
self->showwarn = int(showwarn);
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dynamic = ["version"]
[project.optional-dependencies]
test = ["pytest", "pytest-xdist", "absl-py"]
cuda12 = ["jax[cuda12]"]
cuda12-local = ["jax[cuda12-local]"]

[tool.scikit-build]
metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"
Expand Down
2 changes: 0 additions & 2 deletions src/jax_finufft/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class Opts:
# These correspond to the default cufinufft options
# set in vendor/finufft/src/cuda/cufinufft.cu
modeord: bool = False
chkbnds: bool = True
debug: DebugLevel = DebugLevel.Silent
spread_debug: DebugLevel = DebugLevel.Silent
showwarn: bool = False
Expand Down Expand Up @@ -77,7 +76,6 @@ def to_finufft_opts(self):
compiled_with_omp = jax_finufft_cpu._omp_compile_check()
return jax_finufft_cpu.FinufftOpts(
self.modeord,
self.chkbnds,
int(self.debug),
int(self.spread_debug),
self.showwarn,
Expand Down
2 changes: 1 addition & 1 deletion vendor/finufft
Submodule finufft updated 401 files