|
1 |
| -for f in (:fft, :bfft, :ifft, :rfft) |
2 |
| - pf = Symbol("plan_", f) |
3 |
| - @eval begin |
4 |
| - function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, dims) |
5 |
| - y = $f(x, dims) |
6 |
| - Δy = $f(Δx, dims) |
7 |
| - return y, Δy |
8 |
| - end |
9 |
| - function ChainRulesCore.rrule(::typeof($f), x::T, dims) where {T<:AbstractArray} |
10 |
| - y = $f(x, dims) |
11 |
| - project_x = ChainRulesCore.ProjectTo(x) |
12 |
| - ax = axes(x) |
13 |
| - function fft_pullback(ȳ) |
14 |
| - x̄ = project_x($pf(similar(T, ax), dims)' * ChainRulesCore.unthunk(ȳ)) |
15 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
16 |
| - end |
17 |
| - return y, fft_pullback |
18 |
| - end |
| 1 | +# ffts |
| 2 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) |
| 3 | + y = fft(x, dims) |
| 4 | + Δy = fft(Δx, dims) |
| 5 | + return y, Δy |
| 6 | +end |
| 7 | +function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) |
| 8 | + y = fft(x, dims) |
| 9 | + project_x = ChainRulesCore.ProjectTo(x) |
| 10 | + function fft_pullback(ȳ) |
| 11 | + x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) |
| 12 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 13 | + end |
| 14 | + return y, fft_pullback |
| 15 | +end |
| 16 | + |
| 17 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) |
| 18 | + y = rfft(x, dims) |
| 19 | + Δy = rfft(Δx, dims) |
| 20 | + return y, Δy |
| 21 | +end |
| 22 | +function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) |
| 23 | + y = rfft(x, dims) |
| 24 | + |
| 25 | + # compute scaling factors |
| 26 | + halfdim = first(dims) |
| 27 | + d = size(x, halfdim) |
| 28 | + n = size(y, halfdim) |
| 29 | + scale = reshape( |
| 30 | + [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
| 31 | + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
| 32 | + ) |
| 33 | + |
| 34 | + project_x = ChainRulesCore.ProjectTo(x) |
| 35 | + function rfft_pullback(ȳ) |
| 36 | + x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) |
| 37 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 38 | + end |
| 39 | + return y, rfft_pullback |
| 40 | +end |
| 41 | + |
| 42 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) |
| 43 | + y = ifft(x, dims) |
| 44 | + Δy = ifft(Δx, dims) |
| 45 | + return y, Δy |
| 46 | +end |
| 47 | +function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) |
| 48 | + y = ifft(x, dims) |
| 49 | + invN = normalization(y, dims) |
| 50 | + project_x = ChainRulesCore.ProjectTo(x) |
| 51 | + function ifft_pullback(ȳ) |
| 52 | + x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) |
| 53 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
19 | 54 | end
|
| 55 | + return y, ifft_pullback |
| 56 | +end |
| 57 | + |
| 58 | +function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) |
| 59 | + y = irfft(x, d, dims) |
| 60 | + Δy = irfft(Δx, d, dims) |
| 61 | + return y, Δy |
20 | 62 | end
|
| 63 | +function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) |
| 64 | + y = irfft(x, d, dims) |
| 65 | + |
| 66 | + # compute scaling factors |
| 67 | + halfdim = first(dims) |
| 68 | + n = size(x, halfdim) |
| 69 | + invN = normalization(y, dims) |
| 70 | + twoinvN = 2 * invN |
| 71 | + scale = reshape( |
| 72 | + [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], |
| 73 | + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
| 74 | + ) |
21 | 75 |
|
22 |
| -for f in (:brfft, :irfft) |
23 |
| - pf = Symbol("plan_", f) |
24 |
| - @eval begin |
25 |
| - function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int, dims) |
26 |
| - y = $f(x, d::Int, dims) |
27 |
| - Δy = $f(Δx, d::Int, dims) |
28 |
| - return y, Δy |
29 |
| - end |
30 |
| - function ChainRulesCore.rrule(::typeof($f), x::T, d::Int, dims) where {T<:AbstractArray} |
31 |
| - y = $f(x, d, dims) |
32 |
| - project_x = ChainRulesCore.ProjectTo(x) |
33 |
| - ax = axes(x) |
34 |
| - function fft_pullback(ȳ) |
35 |
| - x̄ = project_x($pf(similar(T, ax), d, dims)' * ChainRulesCore.unthunk(ȳ)) |
36 |
| - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
37 |
| - end |
38 |
| - return y, fft_pullback |
39 |
| - end |
| 76 | + project_x = ChainRulesCore.ProjectTo(x) |
| 77 | + function irfft_pullback(ȳ) |
| 78 | + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
| 79 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 80 | + end |
| 81 | + return y, irfft_pullback |
| 82 | +end |
| 83 | + |
| 84 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) |
| 85 | + y = bfft(x, dims) |
| 86 | + Δy = bfft(Δx, dims) |
| 87 | + return y, Δy |
| 88 | +end |
| 89 | +function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) |
| 90 | + y = bfft(x, dims) |
| 91 | + project_x = ChainRulesCore.ProjectTo(x) |
| 92 | + function bfft_pullback(ȳ) |
| 93 | + x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) |
| 94 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 95 | + end |
| 96 | + return y, bfft_pullback |
| 97 | +end |
| 98 | + |
| 99 | +function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) |
| 100 | + y = brfft(x, d, dims) |
| 101 | + Δy = brfft(Δx, d, dims) |
| 102 | + return y, Δy |
| 103 | +end |
| 104 | +function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) |
| 105 | + y = brfft(x, d, dims) |
| 106 | + |
| 107 | + # compute scaling factors |
| 108 | + halfdim = first(dims) |
| 109 | + n = size(x, halfdim) |
| 110 | + scale = reshape( |
| 111 | + [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
| 112 | + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
| 113 | + ) |
| 114 | + |
| 115 | + project_x = ChainRulesCore.ProjectTo(x) |
| 116 | + function brfft_pullback(ȳ) |
| 117 | + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
| 118 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
40 | 119 | end
|
| 120 | + return y, brfft_pullback |
41 | 121 | end
|
42 | 122 |
|
43 | 123 | # shift functions
|
|
0 commit comments