Skip to content

Commit be4aa9b

Browse files
Eagerly invert plan on formation of AdjointPlan: correct eltype and remove output_size (#113)
* Eagerly compute inv(p) in adjoint plans, and use it to fix eltype and remove output_size * Remove output_size tests, replace with size test for all plans in TestUtils * Ensure <= 1 pass over array in adjoint application * Try to fix type stability issues * Fix another type instability * Replace division with multiplication in adjoint loop * Switch isapprox to equality Co-authored-by: Steven G. Johnson <stevenj@mit.edu> * Lift out multiply-by-2 --------- Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
1 parent fae1170 commit be4aa9b

File tree

5 files changed

+36
-67
lines changed

5 files changed

+36
-67
lines changed

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ It is also relevant to implementers of FFT plans that wish to support adjoints.
3838
```@docs
3939
Base.adjoint
4040
AbstractFFTs.AdjointStyle
41-
AbstractFFTs.output_size
4241
AbstractFFTs.adjoint_mul
4342
AbstractFFTs.FFTAdjointStyle
4443
AbstractFFTs.RFFTAdjointStyle

docs/src/implementations.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ To define a new FFT implementation in your own module, you should
3535

3636
* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref).
3737
`AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
38-
To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref).
38+
To define a new adjoint style, define the method [`AbstractFFTs.adjoint_mul`](@ref).
3939

4040
The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
4141
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

ext/AbstractFFTsTestExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ const TEST_CASES = (
5454

5555
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
5656
_copy = copy_input ? copy : identity
57+
@test size(P) == size(x)
5758
if !inplace_plan
5859
@test P * _copy(x) x_transformed
5960
@test P \ (P * _copy(x)) x
@@ -74,9 +75,9 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
7475
_copy = copy_input ? copy : identity
7576
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
7677
# test basic properties
77-
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
78+
@test eltype(P') === eltype(y)
7879
@test (P')' === P # test adjoint of adjoint
79-
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
80+
@test size(P') == size(y) # test size of adjoint
8081
# test correctness of adjoint and its inverse via the dot test
8182
if !real_plan
8283
@test dot(y, P * _copy(x)) dot(P' * _copy(y), x)

src/definitions.jl

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
259259
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
260260

261261
size(p::ScaledPlan) = size(p.p)
262-
output_size(p::ScaledPlan) = output_size(p.p)
263262

264263
fftdims(p::ScaledPlan) = fftdims(p.p)
265264

@@ -640,20 +639,6 @@ Adjoint style for unitary transforms, whose adjoint equals their inverse.
640639
"""
641640
struct UnitaryAdjointStyle <: AdjointStyle end
642641

643-
"""
644-
output_size(p::Plan, [dim])
645-
646-
Return the size of the output of a plan `p`, optionally at a specified dimension `dim`.
647-
648-
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`.
649-
"""
650-
output_size(p::Plan) = output_size(p, AdjointStyle(p))
651-
output_size(p::Plan, dim) = output_size(p)[dim]
652-
output_size(p::Plan, ::FFTAdjointStyle) = size(p)
653-
output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
654-
output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
655-
output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)
656-
657642
struct AdjointPlan{T,P<:Plan} <: Plan{T}
658643
p::P
659644
AdjointPlan{T,P}(p) where {T,P} = new(p)
@@ -669,13 +654,15 @@ Return a plan that performs the adjoint operation of the original plan.
669654
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
670655
coverage of `Base.adjoint` in downstream implementations may be limited.
671656
"""
672-
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
657+
# We eagerly form the plan inverse in the adjoint(p) call, which will be cached for subsequent calls.
658+
# This is reasonable, as inv(p) would do the same, and necessary in order to compute the correct input
659+
# type for the adjoint plan and encode it in its type.
660+
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{eltype(inv(p)), typeof(p)}(p)
673661
Base.adjoint(p::AdjointPlan) = p.p
674662
# always have AdjointPlan inside ScaledPlan.
675663
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
676664

677-
size(p::AdjointPlan) = output_size(p.p)
678-
output_size(p::AdjointPlan) = size(p.p)
665+
size(p::AdjointPlan) = size(inv(p.p))
679666
fftdims(p::AdjointPlan) = fftdims(p.p)
680667

681668
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)
@@ -693,40 +680,57 @@ adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))
693680
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
694681
dims = fftdims(p)
695682
N = normalization(T, size(p), dims)
696-
return (p \ x) / N
683+
pinv = inv(p)
684+
# Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x.
685+
# Even if not, ensure that we do only one pass by combining the normalization with the plan.
686+
if pinv isa ScaledPlan && pinv.scale == N
687+
return pinv.p * x
688+
else
689+
return (1/N * pinv) * x
690+
end
697691
end
698692

699693
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
700694
dims = fftdims(p)
701695
N = normalization(T, size(p), dims)
702696
halfdim = first(dims)
703697
d = size(p, halfdim)
704-
n = output_size(p, halfdim)
698+
pinv = inv(p)
699+
n = size(pinv, halfdim)
700+
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
701+
scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N
702+
twoscale = 2 * scale
703+
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
705704
y = map(x, CartesianIndices(x)) do xj, j
706705
i = j[halfdim]
707706
yj = if i == 1 || (i == n && 2 * (i - 1) == d)
708-
xj / N
707+
xj * twoscale
709708
else
710-
xj / (2 * N)
709+
xj * scale
711710
end
712711
return yj
713712
end
714-
return p \ y
713+
return unscaled_pinv * y
715714
end
716715

717716
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
718717
dims = fftdims(p)
719-
N = normalization(real(T), output_size(p), dims)
718+
N = normalization(real(T), size(inv(p)), dims)
720719
halfdim = first(dims)
721720
n = size(p, halfdim)
722-
d = output_size(p, halfdim)
723-
y = p \ x
721+
pinv = inv(p)
722+
d = size(pinv, halfdim)
723+
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
724+
scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N
725+
twoscale = 2 * scale
726+
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
727+
y = unscaled_pinv * x
724728
z = map(y, CartesianIndices(y)) do yj, j
725729
i = j[halfdim]
726730
zj = if i == 1 || (i == n && 2 * (i - 1) == d)
727-
yj / N
731+
yj * scale
728732
else
729-
2 * yj / N
733+
yj * twoscale
730734
end
731735
return zj
732736
end

test/runtests.jl

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -118,41 +118,6 @@ end
118118
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
119119
end
120120

121-
@testset "output size" begin
122-
@testset "complex fft output size" begin
123-
for x_shape in ((3,), (3, 4), (3, 4, 5))
124-
N = length(x_shape)
125-
real_x = randn(x_shape)
126-
complex_x = randn(ComplexF64, x_shape)
127-
for x in (real_x, complex_x)
128-
for dims in unique((1, 1:N, N))
129-
P = plan_fft(x, dims)
130-
@test @inferred(AbstractFFTs.output_size(P)) == size(x)
131-
@test AbstractFFTs.output_size(P') == size(x)
132-
Pinv = plan_ifft(x)
133-
@test AbstractFFTs.output_size(Pinv) == size(x)
134-
@test AbstractFFTs.output_size(Pinv') == size(x)
135-
end
136-
end
137-
end
138-
end
139-
@testset "real fft output size" begin
140-
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
141-
N = ndims(x)
142-
for dims in unique((1, 1:N, N))
143-
P = plan_rfft(x, dims)
144-
Px_sz = size(P * x)
145-
@test AbstractFFTs.output_size(P) == Px_sz
146-
@test AbstractFFTs.output_size(P') == size(x)
147-
y = randn(ComplexF64, Px_sz)
148-
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
149-
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
150-
@test AbstractFFTs.output_size(Pinv') == size(y)
151-
end
152-
end
153-
end
154-
end
155-
156121
# Test that dims defaults to 1:ndims for fft-like functions
157122
@testset "Default dims" begin
158123
for x in (randn(3), randn(3, 4), randn(3, 4, 5))

0 commit comments

Comments
 (0)