Skip to content

Commit 615a572

Browse files
author
Gaurav Arya
committed
Dispatch on ProjectionStyle trait for adjoints
1 parent 840bdf8 commit 615a572

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

src/definitions.jl

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -345,16 +345,6 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
345345
return ntuple(i -> i == d1 ? d : sz[i], Val(N))
346346
end
347347

348-
function output_size(p::Plan)
349-
if projection_style(p) == :none
350-
return size(p)
351-
elseif projection_style(p) == :real
352-
return rfft_output_size(size(p), region(p))
353-
elseif projection_style(p) == :real_inv
354-
return brfft_output_size(size(p), irfft_dim(p), region(p))
355-
end
356-
end
357-
358348
plan_irfft(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
359349
ScaledPlan(plan_brfft(x, d, region; kws...),
360350
normalization(T, brfft_output_size(x, d, region), region))
@@ -590,11 +580,19 @@ plan_brfft
590580

591581
##############################################################################
592582

593-
# Projection style (:none, :real, or :real_inv) to handle real FFTs
594-
function projection_style end
595-
# Length of halved dimension, needed only for irfft
583+
struct NoProjectionStyle end
584+
struct RealProjectionStyle end
585+
struct RealInverseProjectionStyle end
586+
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
587+
596588
function irfft_dim end
597589

590+
ProjectionStyle(p::Plan) = error("No projection style defined for plan")
591+
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
592+
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
593+
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))
594+
_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p))
595+
598596
mutable struct AdjointPlan{T,P} <: Plan{T}
599597
p::P
600598
pinv::Plan
@@ -611,31 +609,38 @@ Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T}(p)
611609
size(p::AdjointPlan) = output_size(p)
612610
output_size(p::AdjointPlan) = size(p)
613611

614-
function Base.:*(p::AdjointPlan{T}, x::AbstractArray) where {T}
612+
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
613+
614+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
615+
dims = region(p.p)
616+
N = normalization(T, size(p.p), dims)
617+
return 1/N * (p.p \ x)
618+
end
619+
620+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T}
615621
dims = region(p.p)
622+
N = normalization(T, size(p.p), dims)
616623
halfdim = first(dims)
617624
d = size(p.p, halfdim)
618625
n = output_size(p.p, halfdim)
619-
if projection_style(p.p) == :none
620-
N = normalization(T, size(p.p), dims)
621-
return 1/N * (p.p \ x)
622-
elseif projection_style(p.p) == :real
623-
N = normalization(T, size(p.p), dims)
624-
scale = reshape(
625-
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
626-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
627-
)
628-
return 1/N * (p.p \ (x ./ scale))
629-
elseif projection_style(p.p) == :real_inv
630-
N = normalization(real(T), output_size(p.p), dims)
631-
scale = reshape(
632-
[(i == 1 || (i == d && 2 * (i - 1)) == n) ? 1 : 2 for i in 1:d],
633-
ntuple(i -> i == first(dims) ? d : 1, Val(ndims(x)))
634-
)
635-
return 1/N * scale .* (p.p \ x)
636-
else
637-
error("plan must define a valid projection style")
638-
end
626+
scale = reshape(
627+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
628+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
629+
)
630+
return 1/N * (p.p \ (x ./ scale))
631+
end
632+
633+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
634+
dims = region(p.p)
635+
N = normalization(real(T), output_size(p.p), dims)
636+
halfdim = first(dims)
637+
n = size(p.p, halfdim)
638+
d = output_size(p.p, halfdim)
639+
scale = reshape(
640+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
641+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
642+
)
643+
return 1/N * scale .* (p.p \ x)
639644
end
640645

641646
plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p))

test/testplans.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
2121
Base.size(p::InverseTestPlan) = p.sz
2222
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
2323

24-
AbstractFFTs.projection_style(::TestPlan) = :none
25-
AbstractFFTs.projection_style(::InverseTestPlan) = :none
24+
AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
25+
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
2626

2727
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
2828
return TestPlan{T}(region, size(x))
@@ -111,8 +111,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{T}
111111
end
112112
end
113113

114-
AbstractFFTs.projection_style(::TestRPlan) = :real
115-
AbstractFFTs.projection_style(::InverseTestRPlan) = :real_inv
114+
AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
115+
AbstractFFTs.ProjectionStyle(::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle()
116116
AbstractFFTs.irfft_dim(p::InverseTestRPlan) = p.d
117117

118118
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T}

0 commit comments

Comments
 (0)