|
1 |
| -# ffts |
2 |
| -function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) |
3 |
| - y = fft(x, dims) |
4 |
| - Δy = fft(Δx, dims) |
5 |
| - return y, Δy |
6 |
| -end |
7 |
| -function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) |
8 |
| - y = fft(x, dims) |
9 |
| - project_x = ChainRulesCore.ProjectTo(x) |
10 |
| - function fft_pullback(ȳ) |
11 |
| - x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) |
12 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
13 |
| - end |
14 |
| - return y, fft_pullback |
15 |
| -end |
16 |
| - |
17 |
| -function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) |
18 |
| - y = rfft(x, dims) |
19 |
| - Δy = rfft(Δx, dims) |
20 |
| - return y, Δy |
21 |
| -end |
22 |
| -function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) |
23 |
| - y = rfft(x, dims) |
24 |
| - |
25 |
| - # compute scaling factors |
26 |
| - halfdim = first(dims) |
27 |
| - d = size(x, halfdim) |
28 |
| - n = size(y, halfdim) |
29 |
| - scale = reshape( |
30 |
| - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
31 |
| - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
32 |
| - ) |
33 |
| - |
34 |
| - project_x = ChainRulesCore.ProjectTo(x) |
35 |
| - function rfft_pullback(ȳ) |
36 |
| - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) |
37 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
38 |
| - end |
39 |
| - return y, rfft_pullback |
40 |
| -end |
41 |
| - |
42 |
| -function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) |
43 |
| - y = ifft(x, dims) |
44 |
| - Δy = ifft(Δx, dims) |
45 |
| - return y, Δy |
46 |
| -end |
47 |
| -function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) |
48 |
| - y = ifft(x, dims) |
49 |
| - invN = normalization(y, dims) |
50 |
| - project_x = ChainRulesCore.ProjectTo(x) |
51 |
| - function ifft_pullback(ȳ) |
52 |
| - x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) |
53 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
54 |
| - end |
55 |
| - return y, ifft_pullback |
56 |
| -end |
57 |
| - |
58 |
| -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) |
59 |
| - y = irfft(x, d, dims) |
60 |
| - Δy = irfft(Δx, d, dims) |
61 |
| - return y, Δy |
62 |
| -end |
63 |
| -function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) |
64 |
| - y = irfft(x, d, dims) |
65 |
| - |
66 |
| - # compute scaling factors |
67 |
| - halfdim = first(dims) |
68 |
| - n = size(x, halfdim) |
69 |
| - invN = normalization(y, dims) |
70 |
| - twoinvN = 2 * invN |
71 |
| - scale = reshape( |
72 |
| - [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], |
73 |
| - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
74 |
| - ) |
75 |
| - |
76 |
| - project_x = ChainRulesCore.ProjectTo(x) |
77 |
| - function irfft_pullback(ȳ) |
78 |
| - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
79 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 1 | +for f in (:fft, :bfft, :ifft, :rfft) |
| 2 | + pf = Symbol("plan_", f) |
| 3 | + @eval begin |
| 4 | + function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, dims) |
| 5 | + y = $f(x, dims) |
| 6 | + Δy = $f(Δx, dims) |
| 7 | + return y, Δy |
| 8 | + end |
| 9 | + function ChainRulesCore.rrule(::typeof($f), x::T, dims) where {T<:AbstractArray} |
| 10 | + y = $f(x, dims) |
| 11 | + project_x = ChainRulesCore.ProjectTo(x) |
| 12 | + ax = axes(x) |
| 13 | + function fft_pullback(ȳ) |
| 14 | + x̄ = project_x($pf(similar(T, ax), dims)' * ChainRulesCore.unthunk(ȳ)) |
| 15 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 16 | + end |
| 17 | + return y, fft_pullback |
| 18 | + end |
80 | 19 | end
|
81 |
| - return y, irfft_pullback |
82 | 20 | end
|
83 | 21 |
|
84 |
| -function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) |
85 |
| - y = bfft(x, dims) |
86 |
| - Δy = bfft(Δx, dims) |
87 |
| - return y, Δy |
88 |
| -end |
89 |
| -function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) |
90 |
| - y = bfft(x, dims) |
91 |
| - project_x = ChainRulesCore.ProjectTo(x) |
92 |
| - function bfft_pullback(ȳ) |
93 |
| - x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) |
94 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 22 | +for f in (:brfft, :irfft) |
| 23 | + pf = Symbol("plan_", f) |
| 24 | + @eval begin |
| 25 | + function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int, dims) |
| 26 | + y = $f(x, d::Int, dims) |
| 27 | + Δy = $f(Δx, d::Int, dims) |
| 28 | + return y, Δy |
| 29 | + end |
| 30 | + function ChainRulesCore.rrule(::typeof($f), x::T, d::Int, dims) where {T<:AbstractArray} |
| 31 | + y = $f(x, d, dims) |
| 32 | + project_x = ChainRulesCore.ProjectTo(x) |
| 33 | + ax = axes(x) |
| 34 | + function fft_pullback(ȳ) |
| 35 | + x̄ = project_x($pf(similar(T, ax), d, dims)' * ChainRulesCore.unthunk(ȳ)) |
| 36 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 37 | + end |
| 38 | + return y, fft_pullback |
| 39 | + end |
95 | 40 | end
|
96 |
| - return y, bfft_pullback |
97 |
| -end |
98 |
| - |
99 |
| -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) |
100 |
| - y = brfft(x, d, dims) |
101 |
| - Δy = brfft(Δx, d, dims) |
102 |
| - return y, Δy |
103 |
| -end |
104 |
| -function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) |
105 |
| - y = brfft(x, d, dims) |
106 |
| - |
107 |
| - # compute scaling factors |
108 |
| - halfdim = first(dims) |
109 |
| - n = size(x, halfdim) |
110 |
| - scale = reshape( |
111 |
| - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
112 |
| - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
113 |
| - ) |
114 |
| - |
115 |
| - project_x = ChainRulesCore.ProjectTo(x) |
116 |
| - function brfft_pullback(ȳ) |
117 |
| - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
118 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
119 |
| - end |
120 |
| - return y, brfft_pullback |
121 | 41 | end
|
122 | 42 |
|
123 | 43 | # shift functions
|
@@ -150,3 +70,19 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
|
150 | 70 | end
|
151 | 71 | return y, ifftshift_pullback
|
152 | 72 | end
|
| 73 | + |
| 74 | +# plans |
| 75 | +function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray) |
| 76 | + y = P * x |
| 77 | + Δy = P * Δx |
| 78 | + return y, Δy |
| 79 | +end |
| 80 | +function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray) |
| 81 | + y = P * x |
| 82 | + project_x = ChainRulesCore.ProjectTo(x) |
| 83 | + function fft_pullback(ȳ) |
| 84 | + x̄ = project_x(P' * ȳ) |
| 85 | + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄ |
| 86 | + end |
| 87 | + return y, fft_pullback |
| 88 | +end |
0 commit comments