Skip to content

Commit 6fd67a4

Browse files
gaurav-aryaGaurav Arya
authored andcommitted
Revert plan-based rules for fft, rfft, etc.
1 parent 91102a1 commit 6fd67a4

File tree

1 file changed

+116
-36
lines changed

1 file changed

+116
-36
lines changed

src/chainrules.jl

Lines changed: 116 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,123 @@
1-
for f in (:fft, :bfft, :ifft, :rfft)
2-
pf = Symbol("plan_", f)
3-
@eval begin
4-
function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, dims)
5-
y = $f(x, dims)
6-
Δy = $f(Δx, dims)
7-
return y, Δy
8-
end
9-
function ChainRulesCore.rrule(::typeof($f), x::T, dims) where {T<:AbstractArray}
10-
y = $f(x, dims)
11-
project_x = ChainRulesCore.ProjectTo(x)
12-
ax = axes(x)
13-
function fft_pullback(ȳ)
14-
= project_x($pf(similar(T, ax), dims)' * ChainRulesCore.unthunk(ȳ))
15-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
16-
end
17-
return y, fft_pullback
18-
end
1+
# ffts
2+
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
3+
y = fft(x, dims)
4+
Δy = fft(Δx, dims)
5+
return y, Δy
6+
end
7+
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
8+
y = fft(x, dims)
9+
project_x = ChainRulesCore.ProjectTo(x)
10+
function fft_pullback(ȳ)
11+
= project_x(bfft(ChainRulesCore.unthunk(ȳ), dims))
12+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
13+
end
14+
return y, fft_pullback
15+
end
16+
17+
function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims)
18+
y = rfft(x, dims)
19+
Δy = rfft(Δx, dims)
20+
return y, Δy
21+
end
22+
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
23+
y = rfft(x, dims)
24+
25+
# compute scaling factors
26+
halfdim = first(dims)
27+
d = size(x, halfdim)
28+
n = size(y, halfdim)
29+
scale = reshape(
30+
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
31+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
32+
)
33+
34+
project_x = ChainRulesCore.ProjectTo(x)
35+
function rfft_pullback(ȳ)
36+
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
37+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
38+
end
39+
return y, rfft_pullback
40+
end
41+
42+
function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims)
43+
y = ifft(x, dims)
44+
Δy = ifft(Δx, dims)
45+
return y, Δy
46+
end
47+
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
48+
y = ifft(x, dims)
49+
invN = normalization(y, dims)
50+
project_x = ChainRulesCore.ProjectTo(x)
51+
function ifft_pullback(ȳ)
52+
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
53+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
1954
end
55+
return y, ifft_pullback
56+
end
57+
58+
function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims)
59+
y = irfft(x, d, dims)
60+
Δy = irfft(Δx, d, dims)
61+
return y, Δy
2062
end
63+
function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
64+
y = irfft(x, d, dims)
65+
66+
# compute scaling factors
67+
halfdim = first(dims)
68+
n = size(x, halfdim)
69+
invN = normalization(y, dims)
70+
twoinvN = 2 * invN
71+
scale = reshape(
72+
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
73+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
74+
)
2175

22-
for f in (:brfft, :irfft)
23-
pf = Symbol("plan_", f)
24-
@eval begin
25-
function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int, dims)
26-
y = $f(x, d::Int, dims)
27-
Δy = $f(Δx, d::Int, dims)
28-
return y, Δy
29-
end
30-
function ChainRulesCore.rrule(::typeof($f), x::T, d::Int, dims) where {T<:AbstractArray}
31-
y = $f(x, d, dims)
32-
project_x = ChainRulesCore.ProjectTo(x)
33-
ax = axes(x)
34-
function fft_pullback(ȳ)
35-
= project_x($pf(similar(T, ax), d, dims)' * ChainRulesCore.unthunk(ȳ))
36-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
37-
end
38-
return y, fft_pullback
39-
end
76+
project_x = ChainRulesCore.ProjectTo(x)
77+
function irfft_pullback(ȳ)
78+
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
79+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
80+
end
81+
return y, irfft_pullback
82+
end
83+
84+
function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims)
85+
y = bfft(x, dims)
86+
Δy = bfft(Δx, dims)
87+
return y, Δy
88+
end
89+
function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims)
90+
y = bfft(x, dims)
91+
project_x = ChainRulesCore.ProjectTo(x)
92+
function bfft_pullback(ȳ)
93+
= project_x(fft(ChainRulesCore.unthunk(ȳ), dims))
94+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
95+
end
96+
return y, bfft_pullback
97+
end
98+
99+
function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims)
100+
y = brfft(x, d, dims)
101+
Δy = brfft(Δx, d, dims)
102+
return y, Δy
103+
end
104+
function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
105+
y = brfft(x, d, dims)
106+
107+
# compute scaling factors
108+
halfdim = first(dims)
109+
n = size(x, halfdim)
110+
scale = reshape(
111+
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
112+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
113+
)
114+
115+
project_x = ChainRulesCore.ProjectTo(x)
116+
function brfft_pullback(ȳ)
117+
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
118+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
40119
end
120+
return y, brfft_pullback
41121
end
42122

43123
# shift functions

0 commit comments

Comments
 (0)