Skip to content

Commit cf63bdd

Browse files
Apply suggestions from adjoint plan code review
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 6c81dfd commit cf63bdd

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/definitions.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -592,16 +592,16 @@ _output_size(p::Plan, ::NoProjectionStyle) = size(p)
592592
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))
593593
_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p))
594594

595-
mutable struct AdjointPlan{T,P} <: Plan{T}
595+
mutable struct AdjointPlan{T,P<:Plan} <: Plan{T}
596596
p::P
597597
pinv::Plan
598598
AdjointPlan{T,P}(p) where {T,P} = new(p)
599599
end
600600

601601
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
602-
Base.adjoint(p::AdjointPlan{T}) where {T} = p.p
602+
Base.adjoint(p::AdjointPlan) = p.p
603603
# always have AdjointPlan inside ScaledPlan.
604-
Base.adjoint(p::ScaledPlan{T}) where {T} = ScaledPlan{T}(p.p', p.scale)
604+
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
605605

606606
size(p::AdjointPlan) = output_size(p.p)
607607
output_size(p::AdjointPlan) = size(p.p)
@@ -611,7 +611,7 @@ Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
611611
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
612612
dims = region(p.p)
613613
N = normalization(T, size(p.p), dims)
614-
return 1/N * (p.p \ x)
614+
return (p.p \ x) / N
615615
end
616616

617617
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T}
@@ -621,10 +621,10 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where
621621
d = size(p.p, halfdim)
622622
n = output_size(p.p, halfdim)
623623
scale = reshape(
624-
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
625-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
624+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
625+
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
626626
)
627-
return 1/N * (p.p \ (x ./ scale))
627+
return p.p \ (x ./ scale)
628628
end
629629

630630
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
@@ -635,9 +635,9 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
635635
d = output_size(p.p, halfdim)
636636
scale = reshape(
637637
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
638-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
638+
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
639639
)
640-
return 1/N * scale .* (p.p \ x)
640+
return scale ./ N .* (p.p \ x)
641641
end
642642

643-
plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p))
643+
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))

0 commit comments

Comments
 (0)