|
| 1 | +# ffts |
| 2 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) |
| 3 | + y = fft(x, dims) |
| 4 | + Δy = fft(Δx, dims) |
| 5 | + return y, Δy |
| 6 | +end |
| 7 | +function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) |
| 8 | + y = fft(x, dims) |
| 9 | + project_x = ChainRulesCore.ProjectTo(x) |
| 10 | + function fft_pullback(ȳ) |
| 11 | + x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) |
| 12 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 13 | + end |
| 14 | + return y, fft_pullback |
| 15 | +end |
| 16 | + |
| 17 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) |
| 18 | + y = rfft(x, dims) |
| 19 | + Δy = rfft(Δx, dims) |
| 20 | + return y, Δy |
| 21 | +end |
| 22 | +function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) |
| 23 | + y = rfft(x, dims) |
| 24 | + |
| 25 | + # compute scaling factors |
| 26 | + halfdim = first(dims) |
| 27 | + d = size(x, halfdim) |
| 28 | + n = size(y, halfdim) |
| 29 | + scale = reshape( |
| 30 | + [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
| 31 | + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
| 32 | + ) |
| 33 | + |
| 34 | + project_x = ChainRulesCore.ProjectTo(x) |
| 35 | + function rfft_pullback(ȳ) |
| 36 | + x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) |
| 37 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 38 | + end |
| 39 | + return y, rfft_pullback |
| 40 | +end |
| 41 | + |
| 42 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) |
| 43 | + y = ifft(x, dims) |
| 44 | + Δy = ifft(Δx, dims) |
| 45 | + return y, Δy |
| 46 | +end |
| 47 | +function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) |
| 48 | + y = ifft(x, dims) |
| 49 | + invN = normalization(y, dims) |
| 50 | + project_x = ChainRulesCore.ProjectTo(x) |
| 51 | + function ifft_pullback(ȳ) |
| 52 | + x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) |
| 53 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 54 | + end |
| 55 | + return y, ifft_pullback |
| 56 | +end |
| 57 | + |
| 58 | +function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) |
| 59 | + y = irfft(x, d, dims) |
| 60 | + Δy = irfft(Δx, d, dims) |
| 61 | + return y, Δy |
| 62 | +end |
| 63 | +function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) |
| 64 | + y = irfft(x, d, dims) |
| 65 | + |
| 66 | + # compute scaling factors |
| 67 | + halfdim = first(dims) |
| 68 | + n = size(x, halfdim) |
| 69 | + invN = normalization(y, dims) |
| 70 | + twoinvN = 2 * invN |
| 71 | + scale = reshape( |
| 72 | + [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], |
| 73 | + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
| 74 | + ) |
| 75 | + |
| 76 | + project_x = ChainRulesCore.ProjectTo(x) |
| 77 | + function irfft_pullback(ȳ) |
| 78 | + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
| 79 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 80 | + end |
| 81 | + return y, irfft_pullback |
| 82 | +end |
| 83 | + |
| 84 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) |
| 85 | + y = bfft(x, dims) |
| 86 | + Δy = bfft(Δx, dims) |
| 87 | + return y, Δy |
| 88 | +end |
| 89 | +function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) |
| 90 | + y = bfft(x, dims) |
| 91 | + project_x = ChainRulesCore.ProjectTo(x) |
| 92 | + function bfft_pullback(ȳ) |
| 93 | + x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) |
| 94 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 95 | + end |
| 96 | + return y, bfft_pullback |
| 97 | +end |
| 98 | + |
| 99 | +function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) |
| 100 | + y = brfft(x, d, dims) |
| 101 | + Δy = brfft(Δx, d, dims) |
| 102 | + return y, Δy |
| 103 | +end |
| 104 | +function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) |
| 105 | + y = brfft(x, d, dims) |
| 106 | + |
| 107 | + # compute scaling factors |
| 108 | + halfdim = first(dims) |
| 109 | + n = size(x, halfdim) |
| 110 | + scale = reshape( |
| 111 | + [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
| 112 | + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
| 113 | + ) |
| 114 | + |
| 115 | + project_x = ChainRulesCore.ProjectTo(x) |
| 116 | + function brfft_pullback(ȳ) |
| 117 | + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
| 118 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 119 | + end |
| 120 | + return y, brfft_pullback |
| 121 | +end |
| 122 | + |
| 123 | +# shift functions |
| 124 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(fftshift), x::AbstractArray, dims) |
| 125 | + y = fftshift(x, dims) |
| 126 | + Δy = fftshift(Δx, dims) |
| 127 | + return y, Δy |
| 128 | +end |
| 129 | +function ChainRulesCore.rrule(::typeof(fftshift), x::AbstractArray, dims) |
| 130 | + y = fftshift(x, dims) |
| 131 | + project_x = ChainRulesCore.ProjectTo(x) |
| 132 | + function fftshift_pullback(ȳ) |
| 133 | + x̄ = project_x(ifftshift(ChainRulesCore.unthunk(ȳ), dims)) |
| 134 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 135 | + end |
| 136 | + return y, fftshift_pullback |
| 137 | +end |
| 138 | + |
| 139 | +function ChainRulesCore.frule((_, Δx, _), ::typeof(ifftshift), x::AbstractArray, dims) |
| 140 | + y = ifftshift(x, dims) |
| 141 | + Δy = ifftshift(Δx, dims) |
| 142 | + return y, Δy |
| 143 | +end |
| 144 | +function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) |
| 145 | + y = ifftshift(x, dims) |
| 146 | + project_x = ChainRulesCore.ProjectTo(x) |
| 147 | + function ifftshift_pullback(ȳ) |
| 148 | + x̄ = project_x(fftshift(ChainRulesCore.unthunk(ȳ), dims)) |
| 149 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 150 | + end |
| 151 | + return y, ifftshift_pullback |
| 152 | +end |
0 commit comments