Skip to content

Commit b2dd69c

Browse files
authored
Make ChainRulesCore a weak dependency on Julia >= 1.9 (#85)
* Make ChainRulesCore a weak dependency on Julia >= 1.9 * Qualify `normalization` * Check on nightly if extension works correctly
1 parent 7d698db commit b2dd69c

File tree

4 files changed

+23
-9
lines changed

4 files changed

+23
-9
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
version:
1616
- '1.0'
1717
- '1'
18-
# - 'nightly'
18+
- 'nightly'
1919
os:
2020
- ubuntu-latest
2121
- macOS-latest

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
name = "AbstractFFTs"
22
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
version = "1.2.1"
3+
version = "1.3.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

9+
[weakdeps]
10+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
12+
[extensions]
13+
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
14+
915
[compat]
1016
ChainRulesCore = "1"
1117
julia = "^1.0"
1218

1319
[extras]
20+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1421
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
1522
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1623
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1724
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1825

1926
[targets]
20-
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
27+
test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"]

src/chainrules.jl renamed to ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# ffts
1+
module AbstractFFTsChainRulesCoreExt
2+
3+
using AbstractFFTs
4+
import ChainRulesCore
5+
26
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
37
y = fft(x, dims)
48
Δy = fft(Δx, dims)
@@ -46,7 +50,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim
4650
end
4751
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
4852
y = ifft(x, dims)
49-
invN = normalization(y, dims)
53+
invN = AbstractFFTs.normalization(y, dims)
5054
project_x = ChainRulesCore.ProjectTo(x)
5155
function ifft_pullback(ȳ)
5256
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
@@ -66,7 +70,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
6670
# compute scaling factors
6771
halfdim = first(dims)
6872
n = size(x, halfdim)
69-
invN = normalization(y, dims)
73+
invN = AbstractFFTs.normalization(y, dims)
7074
twoinvN = 2 * invN
7175
scale = reshape(
7276
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
@@ -150,3 +154,5 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
150154
end
151155
return y, ifftshift_pullback
152156
end
157+
158+
end # module

src/AbstractFFTs.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module AbstractFFTs
22

3-
import ChainRulesCore
4-
53
export fft, ifft, bfft, fft!, ifft!, bfft!,
64
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
75
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
86
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
97

108
include("definitions.jl")
11-
include("chainrules.jl")
9+
10+
if !isdefined(Base, :get_extension)
11+
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
12+
end
1213

1314
end # module

0 commit comments

Comments
 (0)