From 3c727ab53b105a4d1397ba8ed97fbb433ad1935c Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 21:17:07 -0400 Subject: [PATCH 1/8] Support new AdjointStyle trait --- Project.toml | 2 +- src/dct.jl | 2 ++ src/fft.jl | 15 +++++++++++++++ test/Project.toml | 8 ++++---- test/runtests.jl | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 79902d9..b3cd54e 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] -AbstractFFTs = "1.0" +AbstractFFTs = "1.4" FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023" Preferences = "1.2" diff --git a/src/dct.jl b/src/dct.jl index cd3ec60..a921731 100644 --- a/src/dct.jl +++ b/src/dct.jl @@ -171,3 +171,5 @@ end mul!(Array{T}(undef, p.plan.osz), p, copy(x)) # need copy to preserve input *(p::DCTPlan{T,K,true}, x::StridedArray{T}) where {T,K} = mul!(x, p, x) + +AbstractFFTs.AdjointStyle(::DCTPlan) = AbstractFFTs.UnitaryAdjointStyle() diff --git a/src/fft.jl b/src/fft.jl index daa5866..db88044 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1049,3 +1049,18 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K} unsafe_execute!(p, x, x) return x end + +####################################################################### + +""" + R2RAdjointStyle(kinds) + +Projection style for real to real transforms +""" +struct R2RAdjointStyle{K} <: AbstractFFTs.AdjointStyle + kinds::K +end + +AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(P::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(P.osz[first(P.region)]) diff --git a/test/Project.toml b/test/Project.toml index c46e7ba..3ac006b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,8 @@ -# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around -# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences -# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500 - [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/test/runtests.jl b/test/runtests.jl index 301194d..9f0e6de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -577,3 +577,21 @@ end end end end + +@testset "DCT adjoints" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) + y = randn(size(x)) + N = ndims(x) + for dims in unique((1, 1:N, N)) + for P in (plan_dct(x, dims), plan_idct(x, dims)) + AbstractFFTs.TestUtils.test_plan_adjoint(P, x) + end + end + end +end + +@testset "AbstractFFTs FFT backend tests" begin + # note this also tests adjoint functionality for FFT plans + AbstractFFTs.TestUtils.test_complex_ffts(Array) + AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) +end From 16624894dc7228f62b94cb007f822ae4637a58ae Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 21:22:30 -0400 Subject: [PATCH 2/8] Remove extraneous test deps --- test/Project.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 3ac006b..f5407aa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,4 @@ [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" From b700223a99b2caaf1c3934597622534674bfca6f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 21:22:59 -0400 Subject: [PATCH 3/8] Add AbstractFFTs compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b3cd54e..e53ebb6 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] -AbstractFFTs = "1.4" +AbstractFFTs = "1.5" FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023" Preferences = "1.2" From 6a90c92c121bc42bb4392c379afcd6c8f7bfaf6e Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 21:34:05 -0400 Subject: [PATCH 4/8] Only run new adjoint tests on FFTW --- test/runtests.jl | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9f0e6de..981f3f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -579,12 +579,15 @@ end end @testset "DCT adjoints" begin - for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) - y = randn(size(x)) - N = ndims(x) - for dims in unique((1, 1:N, N)) - for P in (plan_dct(x, dims), plan_idct(x, dims)) - AbstractFFTs.TestUtils.test_plan_adjoint(P, x) + # only test on FFTW because MKL is missing functionality + if FFTW.get_provider() == "fftw" + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) + y = randn(size(x)) + N = ndims(x) + for dims in unique((1, 1:N, N)) + for P in (plan_dct(x, dims), plan_idct(x, dims)) + AbstractFFTs.TestUtils.test_plan_adjoint(P, x) + end end end end @@ -592,6 +595,9 @@ end @testset "AbstractFFTs FFT backend tests" begin # note this also tests adjoint functionality for FFT plans - AbstractFFTs.TestUtils.test_complex_ffts(Array) - AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) + # only test on FFTW because MKL is missing functionality + if FFTW.get_provider() == "fftw" + AbstractFFTs.TestUtils.test_complex_ffts(Array) + AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) + end end From 6da19395de1ab2b2b272ae8274f58cd4de0ccf01 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 2 Aug 2023 17:21:26 -0400 Subject: [PATCH 5/8] Revert test/Project.toml change --- test/Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index f5407aa..c46e7ba 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,3 +1,7 @@ +# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around +# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences +# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500 + [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 4662374a788598b9c64ed7d234f23dc92a3cb4eb Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 2 Aug 2023 17:23:09 -0400 Subject: [PATCH 6/8] Remove R2RAdjointStyle.(accidentally included) --- src/fft.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/fft.jl b/src/fft.jl index db88044..fc9ac71 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1052,15 +1052,6 @@ end ####################################################################### -""" - R2RAdjointStyle(kinds) - -Projection style for real to real transforms -""" -struct R2RAdjointStyle{K} <: AbstractFFTs.AdjointStyle - kinds::K -end - AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle() AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() AbstractFFTs.AdjointStyle(P::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(P.osz[first(P.region)]) From 38f72b491822a85bff67714d5871aeabcd878e23 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 3 Sep 2023 22:04:41 -0400 Subject: [PATCH 7/8] Use small p for consistency --- src/fft.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft.jl b/src/fft.jl index fc9ac71..bd9e604 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1054,4 +1054,4 @@ end AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle() AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() -AbstractFFTs.AdjointStyle(P::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(P.osz[first(P.region)]) +AbstractFFTs.AdjointStyle(p::rFFTWplan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)]) From 969315a6dabeac25715a6b608b58782e6ff6b648 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 3 Sep 2023 22:23:11 -0400 Subject: [PATCH 8/8] Fix typo --- src/fft.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft.jl b/src/fft.jl index bd9e604..4831a18 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1054,4 +1054,4 @@ end AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle() AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() -AbstractFFTs.AdjointStyle(p::rFFTWplan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)]) +AbstractFFTs.AdjointStyle(p::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)])