Skip to content

Commit ec2aef1

Browse files
gaurav-aryaGaurav Arya
authored andcommitted
Implement AdjointPlan
1 parent c05abee commit ec2aef1

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

src/definitions.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
1212

1313
# size(p) should return the size of the input array for p
1414
size(p::Plan, d) = size(p)[d]
15+
output_size(p::Plan, d) = output_size(p)[d]
1516
ndims(p::Plan) = length(size(p))
1617
length(p::Plan) = prod(size(p))::Int
1718

@@ -254,6 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
254255
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
255256

256257
size(p::ScaledPlan) = size(p.p)
258+
output_size(p::ScaledPlan) = size(p)
257259

258260
region(p::ScaledPlan) = region(p.p)
259261

@@ -301,9 +303,12 @@ for f in (:brfft, :irfft)
301303
end
302304

303305
for f in (:brfft, :irfft)
306+
pf = Symbol("plan_", f)
304307
@eval begin
305308
$f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
309+
$pf(x::AbstractArray{<:Real}, d::Integer, region; kws...) = $pf(complexfloat(x), d, region; kws...)
306310
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
311+
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region; kws...) = $pf(complexfloat(x), d, region; kws...)
307312
end
308313
end
309314

@@ -343,6 +348,16 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
343348
return ntuple(i -> i == d1 ? d : sz[i], Val(N))
344349
end
345350

351+
function output_size(p::Plan)
352+
if projection_style(p) == :none
353+
return size(p)
354+
elseif projection_style(p) == :real
355+
return rfft_output_size(size(p), region(p))
356+
elseif projection_style(p) == :real_inv
357+
return brfft_output_size(size(p), irfft_dim(p), region(p))
358+
end
359+
end
360+
346361
plan_irfft(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
347362
ScaledPlan(plan_brfft(x, d, region; kws...),
348363
normalization(T, brfft_output_size(x, d, region), region))
@@ -575,3 +590,58 @@ Pre-plan an optimized real-input unnormalized transform, similar to
575590
the same as for [`brfft`](@ref).
576591
"""
577592
plan_brfft
593+
594+
##############################################################################
595+
596+
region(p::Plan) = p.region
597+
region(p::ScaledPlan) = region(p.p)
598+
599+
# Projection style (:none, :real, or :real_inv) to handle real FFTs
600+
function projection_style end
601+
# Length of halved dimension, needed only for irfft
602+
function irfft_dim end
603+
604+
mutable struct AdjointPlan{T,P} <: Plan{T}
605+
p::P
606+
pinv::Plan
607+
AdjointPlan{T,P}(p) where {T,P} = new(p)
608+
# always have adjoint inside scaled
609+
AdjointPlan{T,P}(p::P) where {T,P<:ScaledPlan{T}} = ScaledPlan{T}(AdjointPlan{T}(p.p), p.scale)
610+
AdjointPlan{T,P}(p::AdjointPlan{T}) where {T,P} = new(p.p)
611+
end
612+
613+
AdjointPlan{T}(p::P) where {T,P} = AdjointPlan{T,P}(p)
614+
AdjointPlan(p::Plan{T}) where {T} = AdjointPlan{T}(p)
615+
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T}(p)
616+
617+
size(p::AdjointPlan) = output_size(p)
618+
output_size(p::AdjointPlan) = size(p)
619+
620+
function Base.:*(p::AdjointPlan{T}, x::AbstractArray) where {T}
621+
dims = region(p.p)
622+
halfdim = first(dims)
623+
d = size(p.p, halfdim)
624+
n = output_size(p.p, halfdim)
625+
if projection_style(p.p) == :none
626+
N = normalization(T, size(p.p), dims)
627+
return 1/N * (p.p \ x)
628+
elseif projection_style(p.p) == :real
629+
N = normalization(T, size(p.p), dims)
630+
scale = reshape(
631+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
632+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
633+
)
634+
return 1/N * (p.p \ (x ./ scale))
635+
elseif projection_style(p.p) == :real_inv
636+
N = normalization(real(T), output_size(p.p), dims)
637+
scale = reshape(
638+
[(i == 1 || (i == d && 2 * (i - 1)) == n) ? 1 : 2 for i in 1:d],
639+
ntuple(i -> i == first(dims) ? d : 1, Val(ndims(x)))
640+
)
641+
return 1/N * scale .* (p.p \ x)
642+
else
643+
error("plan must define a valid projection style")
644+
end
645+
end
646+
647+
plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p))

0 commit comments

Comments
 (0)