Skip to content

Commit e24c190

Browse files
authored
Merge pull request #67 from JuliaDiff/ox/reorg
Big re-organization
2 parents 91c5d92 + b6cf13f commit e24c190

33 files changed

+174
-1027
lines changed

Project.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.0.1"
3+
version = "0.1.0"
44

55
[deps]
6-
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8-
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
9-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
8+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
9+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
Cassette = "^0.2"
13+
ChainRulesCore = "^0.1"
1414
FDM = "^0.6"
1515
julia = "^1.0"
1616

1717
[extras]
1818
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
19+
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2022
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2123

2224
[targets]
23-
test = ["FDM", "Random", "Test"]
25+
test = ["FDM", "Random", "Test", "SpecialFunctions", "NaNMath"]

REQUIRE

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/ChainRules.jl

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
module ChainRules
2+
using Reexport
3+
@reexport using ChainRulesCore
4+
# Basically everything this package does is overloading these, so we make an exception
5+
# to the normal rule of only overload via `AbstractChainRules.rrule`.
6+
import ChainRulesCore: rrule, frule
7+
8+
# Deal with name clashes, by defining in this module which one we mean.
9+
const accumulate = ChainRulesCore.accumulate
10+
const accumulate! = ChainRulesCore.accumulate!
11+
212

3-
using Cassette
413
using LinearAlgebra
514
using LinearAlgebra.BLAS
15+
using Requires
616
using Statistics
717
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
818

@@ -13,22 +23,32 @@ if VERSION < v"1.3.0-DEV.142"
1323
import LinearAlgebra: dot
1424
end
1525

16-
import NaNMath, SpecialFunctions
17-
18-
export AbstractRule, Rule, frule, rrule
19-
20-
include("differentials.jl")
21-
include("rules.jl")
22-
include("rules/base.jl")
23-
include("rules/array.jl")
24-
include("rules/broadcast.jl")
25-
include("rules/mapreduce.jl")
26-
include("rules/linalg/utils.jl")
27-
include("rules/linalg/blas.jl")
28-
include("rules/linalg/dense.jl")
29-
include("rules/linalg/structured.jl")
30-
include("rules/linalg/factorization.jl")
31-
include("rules/nanmath.jl")
32-
include("rules/specialfunctions.jl")
26+
include("helper_functions.jl")
27+
28+
include("rulesets/Base/base.jl")
29+
include("rulesets/Base/array.jl")
30+
include("rulesets/Base/broadcast.jl")
31+
include("rulesets/Base/mapreduce.jl")
32+
33+
include("rulesets/LinearAlgebra/utils.jl")
34+
include("rulesets/LinearAlgebra/blas.jl")
35+
include("rulesets/LinearAlgebra/dense.jl")
36+
include("rulesets/LinearAlgebra/structured.jl")
37+
include("rulesets/LinearAlgebra/factorization.jl")
38+
39+
# Note: The following is only required because package authors sometimes do not
40+
# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons.
41+
# So we define them here for them.
42+
function __init__()
43+
@require NaNMath="77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin
44+
include("rulesets/packages/NaNMath.jl")
45+
using .NaNMathGlue
46+
end
47+
48+
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
49+
include("rulesets/packages/SpecialFunctions.jl")
50+
using .SpecialFunctionsGlue
51+
end
52+
end
3353

3454
end # module

0 commit comments

Comments
 (0)