Skip to content

Commit 5bd43fc

Browse files
committed
define subtypes of ProjectionStyle
1 parent 430e44a commit 5bd43fc

File tree

3 files changed

+62
-11
lines changed

3 files changed

+62
-11
lines changed

src/dct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,4 @@ end
172172

173173
*(p::DCTPlan{T,K,true}, x::StridedArray{T}) where {T,K} = mul!(x, p, x)
174174

175-
AbstractFFTs.ProjectionStyle(::DCTPlan) = AbstractFFTs.NoProjectionStyle()
175+
AbstractFFTs.ProjectionStyle(::DCTPlan) = AbstractFFTs.UnitaryProjectionStyle()

src/fft.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,47 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K}
10501050
return x
10511051
end
10521052

1053+
#######################################################################
1054+
1055+
"""
1056+
UnitaryProjectionStyle()
1057+
1058+
Projection style for transforms that are unitary
1059+
"""
1060+
struct UnitaryProjectionStyle <: AbstractFFTs.ProjectionStyle end
1061+
1062+
"""
1063+
R2RProjectionStyle(kinds)
1064+
1065+
Projection style for real to real transforms
1066+
"""
1067+
struct R2RProjectionStyle{K} <: AbstractFFTs.ProjectionStyle
1068+
kinds::K
1069+
end
1070+
10531071
AbstractFFTs.ProjectionStyle(::cFFTWPlan) = AbstractFFTs.NoProjectionStyle()
10541072
AbstractFFTs.ProjectionStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RealProjectionStyle()
1055-
AbstractFFTs.ProjectionStyle(p::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.RealInverseProjectionStyle(p.osz[first(p.region)])
1056-
AbstractFFTs.ProjectionStyle(::r2rFFTWPlan) = AbstractFFTs.NoProjectionStyle()
1073+
AbstractFFTs.ProjectionStyle(P::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.RealInverseProjectionStyle(p.osz[first(P.region)])
1074+
AbstractFFTs.ProjectionStyle(P::r2rFFTWPlan) = AbstractFFTs.R2RProjectionStyle(P.kinds)
1075+
1076+
AbstractFFTs._output_size(p::Plan, ::UnitaryProjectionStyle) = size(p)
1077+
AbstractFFTs._output_size(p::Plan, ::R2RProjectionStyle) = size(p)
1078+
1079+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::UnitaryProjectionStyle) where {T}
1080+
return p.p \ x
1081+
end
1082+
1083+
function _mul(p::AdjointPlan{T}, x::AbstractArray, PS::R2RProjectionStyle) where {T}
1084+
kinds = PS.kinds
1085+
dims = fftdims(p)
1086+
1087+
N = 1
1088+
# Normalization: P \ P * u = K * u
1089+
# 0, 1 , 2 : NoProjectionStyle() type: K = prod(size(N))
1090+
# 3, 6 : DCTs: K = 2 * (N - 1)
1091+
# 4, 5 : DCTs: K = 2 * N
1092+
# 7, 10: DSTs: K = 2 * (N + 1)
1093+
# 8, 9 : DSTs: K = 2 * N
1094+
return (p.p \ x) / N
1095+
end
1096+

test/runtests.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -556,14 +556,6 @@ end
556556
@test FFTW.get_num_threads() == 2 # Unchanged
557557
end
558558

559-
@testset "AbstractFFTs upstream tests" begin
560-
# A necessary hack since AbstractFFTsTestUtils is not its own registered package yet.
561-
# See https://github.com/JuliaMath/AbstractFFTs.jl/pull/78
562-
include(joinpath(pathof(AbstractFFTs), "..", "..", "test", "testbackend.jl"))
563-
using .AbstractFFTsTestUtils
564-
test_fft_backend()
565-
end
566-
567559
@testset "type-inference in r2r plans" begin
568560
# Compare with definition
569561
function testr2r(::Type{T}) where {T}
@@ -585,3 +577,22 @@ end
585577
end
586578
end
587579
end
580+
581+
@testset "adjoint" begin
582+
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5))
583+
y = randn(size(x))
584+
N = ndims(x)
585+
for dims in unique((1, 1:N, N))
586+
for P in (plan_dct(x, dims), plan_idct(x, dims))
587+
@test (P')' * x == P * x
588+
@test size(P') == AbstractFFTs.output_size(P)
589+
@test dot(y, P * x) dot(P' * y, x)
590+
@test dot(y, P \ x) dot(P' \ y, x)
591+
end # P
592+
end # dims
593+
end # x
594+
end
595+
596+
@testset "ChainRules" begin
597+
598+
end

0 commit comments

Comments
 (0)