Skip to content

Commit 09b8b38

Browse files
Apply suggestions from code review
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 87758c8 commit 09b8b38

File tree

3 files changed

+10
-14
lines changed

3 files changed

+10
-14
lines changed

docs/src/implementations.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ To define a new FFT implementation in your own module, you should
2020
* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
2121
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
2222

23-
* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
24-
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.
23+
* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.
2524

26-
* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method.
25+
* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method.
2726
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.
2827

2928
* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
@@ -33,7 +32,7 @@ To define a new FFT implementation in your own module, you should
3332

3433
* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
3534

36-
* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values:
35+
* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values:
3736
* `AbstractFFTs.NoProjectionStyle()`,
3837
* `AbstractFFTs.RealProjectionStyle()`, for plans which halve one of the output's dimensions analogously to [`rfft`](@ref),
3938
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.

src/chainrules.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,19 @@ end
170170

171171
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray)
172172
y = P * x
173-
Δy = P * Δx + ΔP.scale / P.scale * y
173+
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
174174
return y, Δy
175175
end
176176
function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray)
177177
y = P * x
178178
project_x = ChainRulesCore.ProjectTo(x)
179-
project_scale = ChainRulesCore.ProjectTo(P.scale)
180179
Pt = P'
181180
scale = P.scale
182-
function mul_plan_pullback(ȳ)
181+
function mul_scaledplan_pullback(ȳ)
183182
= ChainRulesCore.@thunk(project_x(Pt * ȳ))
184-
scale_tangent = ChainRulesCore.@thunk(project_scale(sum(conj(y) .* ȳ) / conj(scale)))
183+
scale_tangent = ChainRulesCore.@thunk(dot(y, ȳ) / conj(scale))
185184
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
186185
return ChainRulesCore.NoTangent(), plan_tangent, x̄
187186
end
188-
return y, mul_plan_pullback
187+
return y, mul_scaledplan_pullback
189188
end

test/runtests.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ end
206206
y = randn(size(x))
207207
for dims in unique((1, 1:N, N))
208208
P = plan_fft(x, dims)
209-
@test AbstractFFTs.output_size(P) == size(x)
209+
@test @inferred(AbstractFFTs.output_size(P)) == size(x)
210210
@test AbstractFFTs.output_size(P') == size(x)
211211
Pinv = plan_ifft(x)
212212
@test AbstractFFTs.output_size(Pinv) == size(x)
@@ -222,7 +222,7 @@ end
222222
Px_sz = size(P * x)
223223
@test AbstractFFTs.output_size(P) == Px_sz
224224
@test AbstractFFTs.output_size(P') == size(x)
225-
y = randn(Px_sz) .+ randn(Px_sz) * im
225+
y = randn(Complex{Float64}, Px_sz)
226226
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
227227
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
228228
@test AbstractFFTs.output_size(Pinv') == size(y)
@@ -256,9 +256,7 @@ end
256256
N = ndims(x)
257257
for dims in unique((1, 1:N, N))
258258
P = plan_rfft(x, dims)
259-
y_real = randn(size(P * x))
260-
y_imag = randn(size(P * x))
261-
y = y_real .+ y_imag .* im
259+
y = randn(Complex{Float64}, size(P * x))
262260
@test (P')' * x == P * x
263261
@test size(P') == AbstractFFTs.output_size(P)
264262
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)

0 commit comments

Comments
 (0)