Skip to content

Commit 33e365e

Browse files
gaurav-aryaGaurav Arya
authored andcommitted
Implement FFT chain rules using adjoint plans
1 parent ec2aef1 commit 33e365e

File tree

1 file changed

+52
-116
lines changed

1 file changed

+52
-116
lines changed

src/chainrules.jl

Lines changed: 52 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,43 @@
1-
# ffts
2-
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
3-
y = fft(x, dims)
4-
Δy = fft(Δx, dims)
5-
return y, Δy
6-
end
7-
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
8-
y = fft(x, dims)
9-
project_x = ChainRulesCore.ProjectTo(x)
10-
function fft_pullback(ȳ)
11-
= project_x(bfft(ChainRulesCore.unthunk(ȳ), dims))
12-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
13-
end
14-
return y, fft_pullback
15-
end
16-
17-
function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims)
18-
y = rfft(x, dims)
19-
Δy = rfft(Δx, dims)
20-
return y, Δy
21-
end
22-
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
23-
y = rfft(x, dims)
24-
25-
# compute scaling factors
26-
halfdim = first(dims)
27-
d = size(x, halfdim)
28-
n = size(y, halfdim)
29-
scale = reshape(
30-
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
31-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
32-
)
33-
34-
project_x = ChainRulesCore.ProjectTo(x)
35-
function rfft_pullback(ȳ)
36-
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
37-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
38-
end
39-
return y, rfft_pullback
40-
end
41-
42-
function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims)
43-
y = ifft(x, dims)
44-
Δy = ifft(Δx, dims)
45-
return y, Δy
46-
end
47-
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
48-
y = ifft(x, dims)
49-
invN = normalization(y, dims)
50-
project_x = ChainRulesCore.ProjectTo(x)
51-
function ifft_pullback(ȳ)
52-
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
53-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
54-
end
55-
return y, ifft_pullback
56-
end
57-
58-
function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims)
59-
y = irfft(x, d, dims)
60-
Δy = irfft(Δx, d, dims)
61-
return y, Δy
62-
end
63-
function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
64-
y = irfft(x, d, dims)
65-
66-
# compute scaling factors
67-
halfdim = first(dims)
68-
n = size(x, halfdim)
69-
invN = normalization(y, dims)
70-
twoinvN = 2 * invN
71-
scale = reshape(
72-
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
73-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
74-
)
75-
76-
project_x = ChainRulesCore.ProjectTo(x)
77-
function irfft_pullback(ȳ)
78-
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
79-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
1+
for f in (:fft, :bfft, :ifft, :rfft)
2+
pf = Symbol("plan_", f)
3+
@eval begin
4+
function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, dims)
5+
y = $f(x, dims)
6+
Δy = $f(Δx, dims)
7+
return y, Δy
8+
end
9+
function ChainRulesCore.rrule(::typeof($f), x::T, dims) where {T<:AbstractArray}
10+
y = $f(x, dims)
11+
project_x = ChainRulesCore.ProjectTo(x)
12+
ax = axes(x)
13+
function fft_pullback(ȳ)
14+
= project_x($pf(similar(T, ax), dims)' * ChainRulesCore.unthunk(ȳ))
15+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
16+
end
17+
return y, fft_pullback
18+
end
8019
end
81-
return y, irfft_pullback
8220
end
8321

84-
function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims)
85-
y = bfft(x, dims)
86-
Δy = bfft(Δx, dims)
87-
return y, Δy
88-
end
89-
function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims)
90-
y = bfft(x, dims)
91-
project_x = ChainRulesCore.ProjectTo(x)
92-
function bfft_pullback(ȳ)
93-
= project_x(fft(ChainRulesCore.unthunk(ȳ), dims))
94-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
22+
for f in (:brfft, :irfft)
23+
pf = Symbol("plan_", f)
24+
@eval begin
25+
function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int, dims)
26+
y = $f(x, d::Int, dims)
27+
Δy = $f(Δx, d::Int, dims)
28+
return y, Δy
29+
end
30+
function ChainRulesCore.rrule(::typeof($f), x::T, d::Int, dims) where {T<:AbstractArray}
31+
y = $f(x, d, dims)
32+
project_x = ChainRulesCore.ProjectTo(x)
33+
ax = axes(x)
34+
function fft_pullback(ȳ)
35+
= project_x($pf(similar(T, ax), d, dims)' * ChainRulesCore.unthunk(ȳ))
36+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
37+
end
38+
return y, fft_pullback
39+
end
9540
end
96-
return y, bfft_pullback
97-
end
98-
99-
function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims)
100-
y = brfft(x, d, dims)
101-
Δy = brfft(Δx, d, dims)
102-
return y, Δy
103-
end
104-
function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
105-
y = brfft(x, d, dims)
106-
107-
# compute scaling factors
108-
halfdim = first(dims)
109-
n = size(x, halfdim)
110-
scale = reshape(
111-
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
112-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
113-
)
114-
115-
project_x = ChainRulesCore.ProjectTo(x)
116-
function brfft_pullback(ȳ)
117-
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
118-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
119-
end
120-
return y, brfft_pullback
12141
end
12242

12343
# shift functions
@@ -150,3 +70,19 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
15070
end
15171
return y, ifftshift_pullback
15272
end
73+
74+
# plans
75+
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray)
76+
y = P * x
77+
Δy = P * Δx
78+
return y, Δy
79+
end
80+
function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
81+
y = P * x
82+
project_x = ChainRulesCore.ProjectTo(x)
83+
function fft_pullback(ȳ)
84+
= project_x(P' * ȳ)
85+
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
86+
end
87+
return y, fft_pullback
88+
end

0 commit comments

Comments
 (0)