Skip to content

Commit d967aa2

Browse files
committed
More tweaks to address code review
1 parent 09b8b38 commit d967aa2

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/chainrules.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,13 @@ function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::Abst
175175
end
176176
function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray)
177177
y = P * x
178-
project_x = ChainRulesCore.ProjectTo(x)
179178
Pt = P'
180179
scale = P.scale
180+
project_x = ChainRulesCore.ProjectTo(x)
181+
project_scale = ChainRulesCore.ProjectTo(scale)
181182
function mul_scaledplan_pullback(ȳ)
182183
= ChainRulesCore.@thunk(project_x(Pt * ȳ))
183-
scale_tangent = ChainRulesCore.@thunk(dot(y, ȳ) / conj(scale))
184+
scale_tangent = ChainRulesCore.@thunk(project_scale(dot(y, ȳ) / conj(scale)))
184185
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
185186
return ChainRulesCore.NoTangent(), plan_tangent, x̄
186187
end

test/runtests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ end
238238
y = randn(size(x))
239239
for dims in unique((1, 1:N, N))
240240
P = plan_fft(x, dims)
241-
@test (P')' * x == P * x # test adjoint of adjoint
241+
@test (P')' === P # test adjoint of adjoint
242242
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
243243
@test dot(y, P * x) dot(P' * y, x) # test validity of adjoint
244244
@test dot(y, P \ x) dot(P' \ y, x)
@@ -259,13 +259,13 @@ end
259259
y = randn(Complex{Float64}, size(P * x))
260260
@test (P')' * x == P * x
261261
@test size(P') == AbstractFFTs.output_size(P)
262-
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)
263-
@test dot(y_real, real.(P' \ x)) + dot(y_imag, imag.(P' \ x)) dot(P \ y, x)
262+
@test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) dot(P' * y, x)
263+
@test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) dot(P \ y, x)
264264
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
265265
@test (Pinv')' * y == Pinv * y
266266
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
267-
@test dot(x, Pinv * y) dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x))
268-
@test dot(x, Pinv' \ y) dot(y_real, real.(Pinv \ x)) + dot(y_imag, imag.(Pinv \ x))
267+
@test dot(x, Pinv * y) dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x))
268+
@test dot(x, Pinv' \ y) dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x))
269269
end
270270
end
271271
end

0 commit comments

Comments
 (0)