1
1
module ChainRules
2
+ using Reexport
3
+ @reexport using ChainRulesCore
4
+ # Basically everything this package does is overloading these, so we make an exception
5
+ # to the normal rule of only overload via `AbstractChainRules.rrule`.
6
+ import ChainRulesCore: rrule, frule
7
+
8
+ # Deal with name clashes, by defining in this module which one we mean.
9
+ const accumulate = ChainRulesCore. accumulate
10
+ const accumulate! = ChainRulesCore. accumulate!
11
+
2
12
3
- using Cassette
4
13
using LinearAlgebra
5
14
using LinearAlgebra. BLAS
15
+ using Requires
6
16
using Statistics
7
17
using Base. Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
8
18
@@ -13,22 +23,32 @@ if VERSION < v"1.3.0-DEV.142"
13
23
import LinearAlgebra: dot
14
24
end
15
25
16
- import NaNMath, SpecialFunctions
17
-
18
- export AbstractRule, Rule, frule, rrule
19
-
20
- include (" differentials.jl" )
21
- include (" rules.jl" )
22
- include (" rules/base.jl" )
23
- include (" rules/array.jl" )
24
- include (" rules/broadcast.jl" )
25
- include (" rules/mapreduce.jl" )
26
- include (" rules/linalg/utils.jl" )
27
- include (" rules/linalg/blas.jl" )
28
- include (" rules/linalg/dense.jl" )
29
- include (" rules/linalg/structured.jl" )
30
- include (" rules/linalg/factorization.jl" )
31
- include (" rules/nanmath.jl" )
32
- include (" rules/specialfunctions.jl" )
26
+ include (" helper_functions.jl" )
27
+
28
+ include (" rulesets/Base/base.jl" )
29
+ include (" rulesets/Base/array.jl" )
30
+ include (" rulesets/Base/broadcast.jl" )
31
+ include (" rulesets/Base/mapreduce.jl" )
32
+
33
+ include (" rulesets/LinearAlgebra/utils.jl" )
34
+ include (" rulesets/LinearAlgebra/blas.jl" )
35
+ include (" rulesets/LinearAlgebra/dense.jl" )
36
+ include (" rulesets/LinearAlgebra/structured.jl" )
37
+ include (" rulesets/LinearAlgebra/factorization.jl" )
38
+
39
+ # Note: The following is only required because package authors sometimes do not
40
+ # declare their own rules using `ChainRulesCore.jl`. For arguably good reasons.
41
+ # So we define them here for them.
42
+ function __init__ ()
43
+ @require NaNMath= " 77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin
44
+ include (" rulesets/packages/NaNMath.jl" )
45
+ using . NaNMathGlue
46
+ end
47
+
48
+ @require SpecialFunctions= " 276daf66-3868-5448-9aa4-cd146d93841b" begin
49
+ include (" rulesets/packages/SpecialFunctions.jl" )
50
+ using . SpecialFunctionsGlue
51
+ end
52
+ end
33
53
34
54
end # module
0 commit comments