1
1
module ChainRules
2
+ using Reexport
3
+ @reexport using AbstractChainRules
4
+ # basically everything this package does is overloading these
5
+ import AbstractChainRules: rrule, frule
6
+
7
+ # deal with name clashes
8
+ const accumulate = AbstractChainRules. accumulate
9
+ const accumulate! = AbstractChainRules. accumulate!
10
+
2
11
3
- using Cassette
4
12
using LinearAlgebra
5
13
using LinearAlgebra. BLAS
14
+ using Requires
6
15
using Statistics
7
16
using Base. Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
8
17
@@ -13,22 +22,32 @@ if VERSION < v"1.3.0-DEV.142"
13
22
import LinearAlgebra: dot
14
23
end
15
24
16
- import NaNMath, SpecialFunctions
17
-
18
- export AbstractRule, Rule, frule, rrule
19
-
20
- include (" differentials.jl" )
21
- include (" rules.jl" )
22
- include (" rules/base.jl" )
23
- include (" rules/array.jl" )
24
- include (" rules/broadcast.jl" )
25
- include (" rules/mapreduce.jl" )
26
- include (" rules/linalg/utils.jl" )
27
- include (" rules/linalg/blas.jl" )
28
- include (" rules/linalg/dense.jl" )
29
- include (" rules/linalg/structured.jl" )
30
- include (" rules/linalg/factorization.jl" )
31
- include (" rules/nanmath.jl" )
32
- include (" rules/specialfunctions.jl" )
25
+ include (" helper_functions.jl" )
26
+
27
+ include (" rulesets/Base/base.jl" )
28
+ include (" rulesets/Base/array.jl" )
29
+ include (" rulesets/Base/broadcast.jl" )
30
+ include (" rulesets/Base/mapreduce.jl" )
31
+
32
+ include (" rulesets/LinearAlgebra/utils.jl" )
33
+ include (" rulesets/LinearAlgebra/blas.jl" )
34
+ include (" rulesets/LinearAlgebra/dense.jl" )
35
+ include (" rulesets/LinearAlgebra/structured.jl" )
36
+ include (" rulesets/LinearAlgebra/factorization.jl" )
37
+
38
+ # Note: The following is only required because package authors do not use
39
+ # declare their own rules using AbstractChainRules. For arguably good reasons.
40
+ # so we define them here for them.
41
+ function __init__ ()
42
+ @require NaNMath= " 77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin
43
+ include (" rulesets/packages/NaNMath.jl" )
44
+ using . NaNMathGlue
45
+ end
46
+
47
+ @require SpecialFunctions= " 276daf66-3868-5448-9aa4-cd146d93841b" begin
48
+ include (" rulesets/packages/SpecialFunctions.jl" )
49
+ using . SpecialFunctionGlue
50
+ end
51
+ end
33
52
34
53
end # module
0 commit comments