From e06611d1b0cab03d8848e2d896a0fe1149f81a1e Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 13 Dec 2024 10:07:31 +0000 Subject: [PATCH 1/3] WIP: ForwardDiff extension --- ext/FFTWForwardDiffExt.jl | 18 ++++++++++++++++++ test/fftwforwarddiff.jl | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 ext/FFTWForwardDiffExt.jl create mode 100644 test/fftwforwarddiff.jl diff --git a/ext/FFTWForwardDiffExt.jl b/ext/FFTWForwardDiffExt.jl new file mode 100644 index 0000000..e01d800 --- /dev/null +++ b/ext/FFTWForwardDiffExt.jl @@ -0,0 +1,18 @@ +module FFTWForwardDiffExt +# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im) + + +plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) +plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) + +for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? + @eval begin + $plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) + $plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims) + end +end + +r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x +r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x + +end #module \ No newline at end of file diff --git a/test/fftwforwarddiff.jl b/test/fftwforwarddiff.jl new file mode 100644 index 0000000..c0e2601 --- /dev/null +++ b/test/fftwforwarddiff.jl @@ -0,0 +1,18 @@ +@testset "r2r" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + t = FFTW.r2r(x1, FFTW.R2HC) + + @test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC) + + t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC) + @test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) + + f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] + @test derivative(f, 0.1) ≡ 1.0 + + @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) +end \ No newline at end of file From c182bc6f516fed1e4b6d76bb96a644e30c12004d Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 17 Dec 2024 22:11:27 +0000 Subject: [PATCH 2/3] tests pass --- Project.toml | 9 ++++++++- ext/FFTWForwardDiffExt.jl | 17 +++++++---------- test/Project.toml | 2 ++ test/fftwforwarddiff.jl | 5 ++++- test/runtests.jl | 2 ++ 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index b648109..9277108 100644 --- a/Project.toml +++ b/Project.toml @@ -10,9 +10,16 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +FFTWForwardDiffExt = "ForwardDiff" + [compat] -AbstractFFTs = "1.5" +AbstractFFTs = "1.6" FFTW_jll = "3.3.9" +ForwardDiff = "0.10" LinearAlgebra = "<0.0.1, 1" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023, 2024" Preferences = "1.2" diff --git a/ext/FFTWForwardDiffExt.jl b/ext/FFTWForwardDiffExt.jl index e01d800..386b8f8 100644 --- a/ext/FFTWForwardDiffExt.jl +++ b/ext/FFTWForwardDiffExt.jl @@ -1,16 +1,13 @@ module FFTWForwardDiffExt -# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im) +using FFTW +using ForwardDiff +import FFTW: plan_r2r, r2r +import FFTW.AbstractFFTs: dualplan, dual2array +import ForwardDiff: Dual +plan_r2r(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims)) +plan_r2r(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims)) -plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) -plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) - -for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? - @eval begin - $plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) - $plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims) - end -end r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x diff --git a/test/Project.toml b/test/Project.toml index 895c586..f0b0722 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,9 +5,11 @@ [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" +ForwardDiff = "0.10" Test = "<0.0.1, 1" diff --git a/test/fftwforwarddiff.jl b/test/fftwforwarddiff.jl index c0e2601..55cc7a0 100644 --- a/test/fftwforwarddiff.jl +++ b/test/fftwforwarddiff.jl @@ -1,3 +1,6 @@ +using FFTW, ForwardDiff, Test +using ForwardDiff: Dual, value, partials + @testset "r2r" begin x1 = Dual.(1:4.0, 2:5, 3:6) t = FFTW.r2r(x1, FFTW.R2HC) @@ -12,7 +15,7 @@ @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] - @test derivative(f, 0.1) ≡ 1.0 + @test ForwardDiff.derivative(f, 0.1) ≡ 1.0 @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6b158ac..1dda166 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -606,3 +606,5 @@ end AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) end end + +include("fftwforwarddiff.jl") \ No newline at end of file From d7b17e1bc3294a116ad183b0e36c0d0f07f5a827 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 19 Dec 2024 21:51:34 +0000 Subject: [PATCH 3/3] dct, dct! and r2r! --- ext/FFTWForwardDiffExt.jl | 35 ++++++++++++++++++++++++++----- test/fftwforwarddiff.jl | 43 +++++++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/ext/FFTWForwardDiffExt.jl b/ext/FFTWForwardDiffExt.jl index 386b8f8..4e57fb3 100644 --- a/ext/FFTWForwardDiffExt.jl +++ b/ext/FFTWForwardDiffExt.jl @@ -1,15 +1,40 @@ module FFTWForwardDiffExt using FFTW using ForwardDiff -import FFTW: plan_r2r, r2r +import FFTW: plan_r2r, plan_r2r!, plan_dct, plan_dct!, plan_idct, plan_idct!, r2r, r2r!, dct, dct!, idct, idct!, fftwReal, REDFT10, REDFT01 import FFTW.AbstractFFTs: dualplan, dual2array import ForwardDiff: Dual -plan_r2r(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims)) -plan_r2r(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims)) +for plan in (:plan_r2r, :plan_r2r!) + @eval begin + $plan(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, $plan(dual2array(x), FLAG, 1 .+ dims)) + $plan(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, $plan(dual2array(x), FLAG, 1 .+ dims)) + end +end -r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x -r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x +for f in (:r2r, :r2r!) + pf = Symbol("plan_", f) + @eval begin + $f(x::AbstractArray{<:Dual}, kinds, region...) = $pf(x, kinds, region...) * x + $f(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = $pf(x, kinds, region...) * x + end +end + + +for f in (:dct, :dct!, :idct, :idct!) + pf = Symbol("plan_", f) + @eval begin + $f(x::AbstractArray{<:Dual}) = $pf(x) * x + $f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x + end +end + +for plan in (:plan_dct, :plan_dct!, :plan_idct, :plan_idct!) + @eval begin + $plan(x::AbstractArray{D}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, $plan(dual2array(x), 1 .+ dims; kwds...)) + $plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, $plan(dual2array(x), 1 .+ dims; kwds...)) + end +end end #module \ No newline at end of file diff --git a/test/fftwforwarddiff.jl b/test/fftwforwarddiff.jl index 55cc7a0..5f14906 100644 --- a/test/fftwforwarddiff.jl +++ b/test/fftwforwarddiff.jl @@ -1,21 +1,38 @@ using FFTW, ForwardDiff, Test using ForwardDiff: Dual, value, partials -@testset "r2r" begin - x1 = Dual.(1:4.0, 2:5, 3:6) - t = FFTW.r2r(x1, FFTW.R2HC) +@testset "ForwardDiff extension" begin + @testset "r2r" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + t = FFTW.r2r(x1, FFTW.R2HC) - @test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC) - @test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC) - @test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC) + @test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC) - t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC) - @test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC) - @test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC) - @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) + t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC) + @test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) - f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] - @test ForwardDiff.derivative(f, 0.1) ≡ 1.0 + f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] + @test ForwardDiff.derivative(f, 0.1) ≡ 1.0 - @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) + @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) + + x = [Dual(1.0,2,3), Dual(4,5,6)] + a = FFTW.r2r(x, FFTW.REDFT00) + b = FFTW.r2r!(x, FFTW.REDFT00) + @test a == b == x + end + + @testset "dct" begin + x = [Dual(1.0,2,3), Dual(4,5,6)] + a = dct(x) + b = dct!(x) + @test a == b == x + + c = x -> dct([x; 0; 0])[1] + @test ForwardDiff.derivative(c,0.1) ≈ 1/sqrt(3) + end end \ No newline at end of file