Skip to content

Commit bfd3133

Browse files
committed
Rename ProjectionStyle's -> AdjointStyles and improve docs
1 parent d53f57d commit bfd3133

File tree

4 files changed

+54
-27
lines changed

4 files changed

+54
-27
lines changed

docs/src/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft
2121
AbstractFFTs.plan_irfft
2222
AbstractFFTs.fftdims
2323
Base.adjoint
24+
AbstractFFTs.FFTAdjointStyle
25+
AbstractFFTs.RFFTAdjointStyle
26+
AbstractFFTs.BRFFTAdjointStyle
27+
AbstractFFTs.UnitaryAdjointStyle
2428
AbstractFFTs.fftshift
2529
AbstractFFTs.fftshift!
2630
AbstractFFTs.ifftshift

docs/src/implementations.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,9 @@ To define a new FFT implementation in your own module, you should
3232

3333
* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
3434

35-
* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
36-
* `AbstractFFTs.NoProjectionStyle()`,
37-
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
38-
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
35+
* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref),
36+
(which `AbstractFFTs` uses to implement reverse-mode differentiation rules). To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::AdjointPlan, ::AbstractArray, ::AS)`. `AbstractFFTs` pre-implements [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref),
37+
[`AbstractFFTs.BRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
3938

4039
The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
4140
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

src/definitions.jl

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -583,35 +583,57 @@ plan_brfft
583583

584584
##############################################################################
585585

586-
abstract type ProjectionStyle end
586+
abstract type AdjointStyle end
587587

588588
"""
589-
NoProjectionStyle()
589+
FFTAdjointStyle()
590590
591-
Projection style for complex to complex discrete Fourier transform
591+
Projection style for complex to complex discrete Fourier transforms.
592+
593+
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
594+
the transform's inverse with an appropriate scaling.
592595
"""
593-
struct NoProjectionStyle <: ProjectionStyle end
596+
struct FFTAdjointStyle <: AdjointStyle end
594597

595598
"""
596-
RealProjectionStyle()
599+
RFFTAdjointStyle()
597600
598-
Projection style for complex to real discrete Fourier transform
601+
Projection style for real to complex discrete Fourier transforms, for plans that
602+
halve one of the output's dimensions analogously to [`rfft`](@ref).
603+
604+
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
605+
inverse, but with additional logic to handle the fact that the output is projected
606+
to exploit its conjugate symmetry (see [`rfft`](@ref)).
599607
"""
600-
struct RealProjectionStyle <: ProjectionStyle end
608+
struct RFFTAdjointStyle <: AdjointStyle end
601609

602610
"""
603-
RealInverseProjectionStyle()
611+
BRFFTAdjointStyle(d::Dim)
604612
605-
Projection style for inverse of complex to real discrete Fourier transform
613+
Projection style for complex to real discrete Fourier transforms, for plans that
614+
expect an input with a halved dimension analogously to [`irfft`](@ref), where `d`
615+
is the original length of the dimension.
616+
617+
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
618+
inverse, but with additional logic to handle the fact that the input is projected
619+
to exploit its conjugate symmetry (see [`irfft`](@ref)).
606620
"""
607-
struct RealInverseProjectionStyle <: ProjectionStyle
621+
struct BRFFTAdjointStyle <: AdjointStyle
608622
dim::Int
609623
end
610624

611-
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
612-
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
613-
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
614-
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
625+
"""
626+
UnitaryAdjointStyle()
627+
628+
Projection style for unitary transforms, whose adjoint equals their inverse.
629+
"""
630+
struct UnitaryAdjointStyle <: AdjointStyle end
631+
632+
output_size(p::Plan) = _output_size(p, AdjointStyle(p))
633+
_output_size(p::Plan, ::FFTAdjointStyle) = size(p)
634+
_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
635+
_output_size(p::Plan, s::BRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
636+
_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)
615637

616638
struct AdjointPlan{T,P<:Plan} <: Plan{T}
617639
p::P
@@ -638,15 +660,15 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
638660
size(p::AdjointPlan) = output_size(p.p)
639661
output_size(p::AdjointPlan) = size(p.p)
640662

641-
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
663+
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p.p))
642664

643-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
665+
function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
644666
dims = fftdims(p.p)
645667
N = normalization(T, size(p.p), dims)
646668
return (p.p \ x) / N
647669
end
648670

649-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
671+
function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
650672
dims = fftdims(p.p)
651673
N = normalization(T, size(p.p), dims)
652674
halfdim = first(dims)
@@ -659,7 +681,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where
659681
return p.p \ (x ./ convert(typeof(x), scale))
660682
end
661683

662-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
684+
function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::BRFFTAdjointStyle) where {T}
663685
dims = fftdims(p.p)
664686
N = normalization(real(T), output_size(p.p), dims)
665687
halfdim = first(dims)
@@ -672,6 +694,8 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
672694
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
673695
end
674696

697+
adjoint_mul(p::AdjointPlan, x::AbstractArray, ::UnitaryAdjointStyle) = p.p \ x
698+
675699
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
676700
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
677701
inv(p::AdjointPlan) = adjoint(inv(p.p))

test/testplans.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
2121
Base.size(p::InverseTestPlan) = p.sz
2222
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
2323

24-
AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
25-
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
24+
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
25+
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()
2626

2727
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
2828
return TestPlan{T}(region, size(x))
@@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
110110
end
111111
end
112112

113-
AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
114-
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
113+
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
114+
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.BRFFTAdjointStyle(p.d)
115115

116116
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
117117
return TestRPlan{T}(region, size(x))
@@ -241,7 +241,7 @@ end
241241

242242
Base.size(p::InplaceTestPlan) = size(p.plan)
243243
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
244-
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
244+
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)
245245

246246
function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
247247
return InplaceTestPlan(plan_fft(x, region; kwargs...))

0 commit comments

Comments
 (0)