Skip to content

Commit 70c1ce8

Browse files
committed
Split core out into AbstractChainRules,
Requires-ize 3rd party packages. Reorganize all files
1 parent 91c5d92 commit 70c1ce8

33 files changed

+168
-1027
lines changed

Project.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.0.1"
3+
version = "0.1.0"
44

55
[deps]
6-
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
6+
AbstractChainRules = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8-
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
9-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
8+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
9+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
Cassette = "^0.2"
1413
FDM = "^0.6"
1514
julia = "^1.0"
15+
AbstractChainRules = "^0.1"
1616

1717
[extras]
1818
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
19+
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2022
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2123

2224
[targets]
23-
test = ["FDM", "Random", "Test"]
25+
test = ["FDM", "Random", "Test", "SpecialFunctions", "NaNMath"]

REQUIRE

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/ChainRules.jl

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
module ChainRules
2+
using Reexport
3+
@reexport using AbstractChainRules
4+
# basically everything this package does is overloading these
5+
import AbstractChainRules: rrule, frule
6+
7+
# deal with name clashes
8+
const accumulate = AbstractChainRules.accumulate
9+
const accumulate! = AbstractChainRules.accumulate!
10+
211

3-
using Cassette
412
using LinearAlgebra
513
using LinearAlgebra.BLAS
14+
using Requires
615
using Statistics
716
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
817

@@ -13,22 +22,32 @@ if VERSION < v"1.3.0-DEV.142"
1322
import LinearAlgebra: dot
1423
end
1524

16-
import NaNMath, SpecialFunctions
17-
18-
export AbstractRule, Rule, frule, rrule
19-
20-
include("differentials.jl")
21-
include("rules.jl")
22-
include("rules/base.jl")
23-
include("rules/array.jl")
24-
include("rules/broadcast.jl")
25-
include("rules/mapreduce.jl")
26-
include("rules/linalg/utils.jl")
27-
include("rules/linalg/blas.jl")
28-
include("rules/linalg/dense.jl")
29-
include("rules/linalg/structured.jl")
30-
include("rules/linalg/factorization.jl")
31-
include("rules/nanmath.jl")
32-
include("rules/specialfunctions.jl")
25+
include("helper_functions.jl")
26+
27+
include("rulesets/Base/base.jl")
28+
include("rulesets/Base/array.jl")
29+
include("rulesets/Base/broadcast.jl")
30+
include("rulesets/Base/mapreduce.jl")
31+
32+
include("rulesets/LinearAlgebra/utils.jl")
33+
include("rulesets/LinearAlgebra/blas.jl")
34+
include("rulesets/LinearAlgebra/dense.jl")
35+
include("rulesets/LinearAlgebra/structured.jl")
36+
include("rulesets/LinearAlgebra/factorization.jl")
37+
38+
# Note: The following is only required because package authors do not use
39+
# declare their own rules using AbstractChainRules. For arguably good reasons.
40+
# so we define them here for them.
41+
function __init__()
42+
@require NaNMath="77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin
43+
include("rulesets/packages/NaNMath.jl")
44+
using .NaNMathGlue
45+
end
46+
47+
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
48+
include("rulesets/packages/SpecialFunctions.jl")
49+
using .SpecialFunctionGlue
50+
end
51+
end
3352

3453
end # module

0 commit comments

Comments
 (0)