Skip to content

Commit 7f88718

Browse files
committed
Implement AdjointPlans
1 parent b848c54 commit 7f88718

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

src/definitions.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
1212

1313
# size(p) should return the size of the input array for p
1414
size(p::Plan, d) = size(p)[d]
15+
output_size(p::Plan, d) = output_size(p)[d]
1516
ndims(p::Plan) = length(size(p))
1617
length(p::Plan) = prod(size(p))::Int
1718

@@ -254,6 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
254255
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
255256

256257
size(p::ScaledPlan) = size(p.p)
258+
output_size(p::ScaledPlan) = output_size(p.p)
257259

258260
fftdims(p::ScaledPlan) = fftdims(p.p)
259261

@@ -575,3 +577,67 @@ Pre-plan an optimized real-input unnormalized transform, similar to
575577
the same as for [`brfft`](@ref).
576578
"""
577579
plan_brfft
580+
581+
##############################################################################
582+
583+
struct NoProjectionStyle end
584+
struct RealProjectionStyle end
585+
struct RealInverseProjectionStyle end
586+
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
587+
588+
function irfft_dim end
589+
590+
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
591+
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
592+
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))
593+
_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p))
594+
595+
mutable struct AdjointPlan{T,P} <: Plan{T}
596+
p::P
597+
pinv::Plan
598+
AdjointPlan{T,P}(p) where {T,P} = new(p)
599+
end
600+
601+
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
602+
Base.adjoint(p::AdjointPlan{T}) where {T} = p.p
603+
# always have AdjointPlan inside ScaledPlan.
604+
Base.adjoint(p::ScaledPlan{T}) where {T} = ScaledPlan{T}(p.p', p.scale)
605+
606+
size(p::AdjointPlan) = output_size(p.p)
607+
output_size(p::AdjointPlan) = size(p.p)
608+
609+
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
610+
611+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
612+
dims = region(p.p)
613+
N = normalization(T, size(p.p), dims)
614+
return 1/N * (p.p \ x)
615+
end
616+
617+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T}
618+
dims = region(p.p)
619+
N = normalization(T, size(p.p), dims)
620+
halfdim = first(dims)
621+
d = size(p.p, halfdim)
622+
n = output_size(p.p, halfdim)
623+
scale = reshape(
624+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
625+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
626+
)
627+
return 1/N * (p.p \ (x ./ scale))
628+
end
629+
630+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
631+
dims = region(p.p)
632+
N = normalization(real(T), output_size(p.p), dims)
633+
halfdim = first(dims)
634+
n = size(p.p, halfdim)
635+
d = output_size(p.p, halfdim)
636+
scale = reshape(
637+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
638+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
639+
)
640+
return 1/N * scale .* (p.p \ x)
641+
end
642+
643+
plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p))

0 commit comments

Comments
 (0)