Skip to content

Commit 87758c8

Browse files
committed
Add rules and tests for ScaledPlan
1 parent 8ddfa97 commit 87758c8

File tree

3 files changed

+48
-14
lines changed

3 files changed

+48
-14
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ julia = "^1.0"
1212

1313
[extras]
1414
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
15+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1819

1920
[targets]
20-
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
21+
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]

src/chainrules.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,23 @@ function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
167167
end
168168
return y, mul_plan_pullback
169169
end
170+
171+
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray)
172+
y = P * x
173+
Δy = P * Δx + ΔP.scale / P.scale * y
174+
return y, Δy
175+
end
176+
function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray)
177+
y = P * x
178+
project_x = ChainRulesCore.ProjectTo(x)
179+
project_scale = ChainRulesCore.ProjectTo(P.scale)
180+
Pt = P'
181+
scale = P.scale
182+
function mul_plan_pullback(ȳ)
183+
= ChainRulesCore.@thunk(project_x(Pt * ȳ))
184+
scale_tangent = ChainRulesCore.@thunk(project_scale(sum(conj(y) .* ȳ) / conj(scale)))
185+
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
186+
return ChainRulesCore.NoTangent(), plan_tangent, x̄
187+
end
188+
return y, mul_plan_pullback
189+
end

test/runtests.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
22

33
using AbstractFFTs
4-
using AbstractFFTs: Plan
4+
using AbstractFFTs: Plan, ScaledPlan
55
using ChainRulesTestUtils
6-
using ChainRulesCore: NoTangent
6+
using FiniteDifferences
7+
import ChainRulesCore
78

89
using LinearAlgebra
910
using Random
@@ -293,9 +294,21 @@ end
293294
end
294295

295296
@testset "fft" begin
296-
for x in (randn(2), randn(2, 3), randn(3, 4, 5))
297-
N = ndims(x)
298-
complex_x = complex.(x)
297+
# Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256
298+
InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan}
299+
function FiniteDifferences.to_vec(x::InnerPlan)
300+
function FFTPlan_from_vec(x_vec::Vector)
301+
return x
302+
end
303+
return Bool[], FFTPlan_from_vec
304+
end
305+
ChainRulesTestUtils.test_approx(::ChainRulesCore.AbstractZero, x::InnerPlan, msg=""; kwargs...) = true
306+
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::InnerPlan) = ChainRulesCore.NoTangent()
307+
308+
for x_shape in ((2,), (2, 3), (3, 4, 5))
309+
N = length(x_shape)
310+
x = randn(x_shape)
311+
complex_x = x + randn(x_shape) * im
299312
for dims in unique((1, 1:N, N))
300313
# fft, ifft, bfft
301314
for f in (fft, ifft, bfft)
@@ -305,17 +318,17 @@ end
305318
test_rrule(f, complex_x, dims)
306319
end
307320
for pf in (plan_fft, plan_ifft, plan_bfft)
308-
test_frule(*, pf(x, dims) NoTangent(), x)
309-
test_rrule(*, pf(x, dims) NoTangent(), x)
310-
test_frule(*, pf(complex_x, dims) NoTangent(), complex_x)
311-
test_rrule(*, pf(complex_x, dims) NoTangent(), complex_x)
321+
test_frule(*, pf(x, dims), x)
322+
test_rrule(*, pf(x, dims), x)
323+
test_frule(*, pf(complex_x, dims), complex_x)
324+
test_rrule(*, pf(complex_x, dims), complex_x)
312325
end
313326

314327
# rfft
315328
test_frule(rfft, x, dims)
316329
test_rrule(rfft, x, dims)
317-
test_frule(*, plan_rfft(x, dims) NoTangent(), x)
318-
test_rrule(*, plan_rfft(x, dims) NoTangent(), x)
330+
test_frule(*, plan_rfft(x, dims), x)
331+
test_rrule(*, plan_rfft(x, dims), x)
319332

320333
# irfft, brfft
321334
for f in (irfft, brfft)
@@ -328,8 +341,8 @@ end
328341
end
329342
for pf in (plan_irfft, plan_brfft)
330343
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
331-
test_frule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
332-
test_rrule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
344+
test_frule(*, pf(complex_x, d, dims), complex_x)
345+
test_rrule(*, pf(complex_x, d, dims), complex_x)
333346
end
334347
end
335348
end

0 commit comments

Comments
 (0)