Skip to content

Commit 2bae074

Browse files
authored
Add ChainRules definitions (#58)
* Add ChainRules definitions * Add tests * Move test plans to separate file * Fix type inference of `fftshift` and `ifftshift` on Julia 1.0 * Disable type inference checks for `fftshift` and `ifftshift` in old Julia versions * Bump version
1 parent d007201 commit 2bae074

File tree

6 files changed

+536
-75
lines changed

6 files changed

+536
-75
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
name = "AbstractFFTs"
22
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
version = "1.0.1"
3+
version = "1.1.0"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78

89
[compat]
10+
ChainRulesCore = "1"
911
julia = "^1.0"
1012

1113
[extras]
14+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
15+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1216
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1317
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1418

1519
[targets]
16-
test = ["Test", "Unitful"]
20+
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]

src/AbstractFFTs.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
module AbstractFFTs
22

3+
import ChainRulesCore
4+
35
export fft, ifft, bfft, fft!, ifft!, bfft!,
46
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
57
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
68
fftshift, ifftshift, Frequencies, fftfreq, rfftfreq
79

810
include("definitions.jl")
11+
include("chainrules.jl")
912

1013
end # module

src/chainrules.jl

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# ffts
2+
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
3+
y = fft(x, dims)
4+
Δy = fft(Δx, dims)
5+
return y, Δy
6+
end
7+
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
8+
y = fft(x, dims)
9+
project_x = ChainRulesCore.ProjectTo(x)
10+
function fft_pullback(ȳ)
11+
= project_x(bfft(ChainRulesCore.unthunk(ȳ), dims))
12+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
13+
end
14+
return y, fft_pullback
15+
end
16+
17+
function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims)
18+
y = rfft(x, dims)
19+
Δy = rfft(Δx, dims)
20+
return y, Δy
21+
end
22+
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
23+
y = rfft(x, dims)
24+
25+
# compute scaling factors
26+
halfdim = first(dims)
27+
d = size(x, halfdim)
28+
n = size(y, halfdim)
29+
scale = reshape(
30+
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
31+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
32+
)
33+
34+
project_x = ChainRulesCore.ProjectTo(x)
35+
function rfft_pullback(ȳ)
36+
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
37+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
38+
end
39+
return y, rfft_pullback
40+
end
41+
42+
function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims)
43+
y = ifft(x, dims)
44+
Δy = ifft(Δx, dims)
45+
return y, Δy
46+
end
47+
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
48+
y = ifft(x, dims)
49+
invN = normalization(y, dims)
50+
project_x = ChainRulesCore.ProjectTo(x)
51+
function ifft_pullback(ȳ)
52+
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
53+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
54+
end
55+
return y, ifft_pullback
56+
end
57+
58+
function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims)
59+
y = irfft(x, d, dims)
60+
Δy = irfft(Δx, d, dims)
61+
return y, Δy
62+
end
63+
function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
64+
y = irfft(x, d, dims)
65+
66+
# compute scaling factors
67+
halfdim = first(dims)
68+
n = size(x, halfdim)
69+
invN = normalization(y, dims)
70+
twoinvN = 2 * invN
71+
scale = reshape(
72+
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
73+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
74+
)
75+
76+
project_x = ChainRulesCore.ProjectTo(x)
77+
function irfft_pullback(ȳ)
78+
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
79+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
80+
end
81+
return y, irfft_pullback
82+
end
83+
84+
function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims)
85+
y = bfft(x, dims)
86+
Δy = bfft(Δx, dims)
87+
return y, Δy
88+
end
89+
function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims)
90+
y = bfft(x, dims)
91+
project_x = ChainRulesCore.ProjectTo(x)
92+
function bfft_pullback(ȳ)
93+
= project_x(fft(ChainRulesCore.unthunk(ȳ), dims))
94+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
95+
end
96+
return y, bfft_pullback
97+
end
98+
99+
function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims)
100+
y = brfft(x, d, dims)
101+
Δy = brfft(Δx, d, dims)
102+
return y, Δy
103+
end
104+
function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
105+
y = brfft(x, d, dims)
106+
107+
# compute scaling factors
108+
halfdim = first(dims)
109+
n = size(x, halfdim)
110+
scale = reshape(
111+
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
112+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
113+
)
114+
115+
project_x = ChainRulesCore.ProjectTo(x)
116+
function brfft_pullback(ȳ)
117+
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
118+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
119+
end
120+
return y, brfft_pullback
121+
end
122+
123+
# shift functions
124+
function ChainRulesCore.frule((_, Δx, _), ::typeof(fftshift), x::AbstractArray, dims)
125+
y = fftshift(x, dims)
126+
Δy = fftshift(Δx, dims)
127+
return y, Δy
128+
end
129+
function ChainRulesCore.rrule(::typeof(fftshift), x::AbstractArray, dims)
130+
y = fftshift(x, dims)
131+
project_x = ChainRulesCore.ProjectTo(x)
132+
function fftshift_pullback(ȳ)
133+
= project_x(ifftshift(ChainRulesCore.unthunk(ȳ), dims))
134+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
135+
end
136+
return y, fftshift_pullback
137+
end
138+
139+
function ChainRulesCore.frule((_, Δx, _), ::typeof(ifftshift), x::AbstractArray, dims)
140+
y = ifftshift(x, dims)
141+
Δy = ifftshift(Δx, dims)
142+
return y, Δy
143+
end
144+
function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
145+
y = ifftshift(x, dims)
146+
project_x = ChainRulesCore.ProjectTo(x)
147+
function ifftshift_pullback(ȳ)
148+
= project_x(fftshift(ChainRulesCore.unthunk(ȳ), dims))
149+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
150+
end
151+
return y, ifftshift_pullback
152+
end

src/definitions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
256256
*(p::Plan, I::UniformScaling) = ScaledPlan(p, I.λ)
257257

258258
# Normalization for ifft, given unscaled bfft, is 1/prod(dimensions)
259-
normalization(::Type{T}, sz, region) where T = one(T) / Int(prod([sz...][[region...]]))::Int
259+
normalization(::Type{T}, sz, region) where T = one(T) / Int(prod(sz[r] for r in region))::Int
260260
normalization(X, region) = normalization(real(eltype(X)), size(X), region)
261261

262262
plan_ifft(x::AbstractArray, region; kws...) =
@@ -360,7 +360,7 @@ If `dim` is not given then the signal is shifted along each dimension.
360360
fftshift
361361

362362
function fftshift(x, dim = 1:ndims(x))
363-
s = ntuple(d -> d in dim ? div(size(x,d),2) : 0, ndims(x))
363+
s = ntuple(d -> d in dim ? div(size(x,d),2) : 0, Val(ndims(x)))
364364
circshift(x, s)
365365
end
366366

@@ -380,7 +380,7 @@ If `dim` is not given then the signal is shifted along each dimension.
380380
ifftshift
381381

382382
function ifftshift(x, dim = 1:ndims(x))
383-
s = ntuple(d -> d in dim ? -div(size(x,d),2) : 0, ndims(x))
383+
s = ntuple(d -> d in dim ? -div(size(x,d),2) : 0, Val(ndims(x)))
384384
circshift(x, s)
385385
end
386386

0 commit comments

Comments
 (0)