Skip to content

Commit ef8fc5b

Browse files
authored
Support new AbstractFFTs.AdjointStyle trait for FFT and DCT plans (#249)
* Support new AdjointStyle trait * Remove extraneous test deps * Add AbstractFFTs compat * Only run new adjoint tests on FFTW * Revert test/Project.toml change * Remove R2RAdjointStyle.(accidentally included) * Use small p for consistency * Fix typo
1 parent 008bc5b commit ef8fc5b

File tree

4 files changed

+33
-1
lines changed

4 files changed

+33
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1111
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1212

1313
[compat]
14-
AbstractFFTs = "1.0"
14+
AbstractFFTs = "1.5"
1515
FFTW_jll = "3.3.9"
1616
MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023"
1717
Preferences = "1.2"

src/dct.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,5 @@ end
171171
mul!(Array{T}(undef, p.plan.osz), p, copy(x)) # need copy to preserve input
172172

173173
*(p::DCTPlan{T,K,true}, x::StridedArray{T}) where {T,K} = mul!(x, p, x)
174+
175+
AbstractFFTs.AdjointStyle(::DCTPlan) = AbstractFFTs.UnitaryAdjointStyle()

src/fft.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,3 +1049,9 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K}
10491049
unsafe_execute!(p, x, x)
10501050
return x
10511051
end
1052+
1053+
#######################################################################
1054+
1055+
AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle()
1056+
AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle()
1057+
AbstractFFTs.AdjointStyle(p::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)])

test/runtests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,27 @@ end
577577
end
578578
end
579579
end
580+
581+
@testset "DCT adjoints" begin
582+
# only test on FFTW because MKL is missing functionality
583+
if FFTW.get_provider() == "fftw"
584+
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5))
585+
y = randn(size(x))
586+
N = ndims(x)
587+
for dims in unique((1, 1:N, N))
588+
for P in (plan_dct(x, dims), plan_idct(x, dims))
589+
AbstractFFTs.TestUtils.test_plan_adjoint(P, x)
590+
end
591+
end
592+
end
593+
end
594+
end
595+
596+
@testset "AbstractFFTs FFT backend tests" begin
597+
# note this also tests adjoint functionality for FFT plans
598+
# only test on FFTW because MKL is missing functionality
599+
if FFTW.get_provider() == "fftw"
600+
AbstractFFTs.TestUtils.test_complex_ffts(Array)
601+
AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true)
602+
end
603+
end

0 commit comments

Comments
 (0)