1
- partial (x:: TangentBundle , i) = x. partials[i]
2
- partial (x:: TaylorBundle{1} , i) = x. coeffs[i]
3
- partial (x:: UniformBundle , i) = x. partial
4
- partial (x:: CompositeBundle{N, B} , i) where {N, B} = Tangent {B} (map (x-> partial (x, i), x. tup)... )
5
- partial (x:: ZeroTangent , i) = ZeroTangent ()
1
+ partial (x:: TangentBundle , i) = partial (getfield (x, :tangent ), i)
2
+ partial (x:: ExplicitTangent , i) = getfield (getfield (x, :partials ), i)
3
+ partial (x:: TaylorTangent , i) = getfield (getfield (x, :coeffs ), i)
4
+ partial (x:: UniformTangent , i) = getfield (x, :val )
5
+ partial (x:: ProductTangent , i) = ProductTangent (map (x-> partial (x, i), getfield (x, :factors )))
6
+ partial (x:: AbstractZero , i) = x
7
+ partial (x:: CompositeBundle{N, B} , i) where {N, B} = Tangent {B} (map (x-> partial (x, i), getfield (x, :tup ))... )
6
8
primal (x:: AbstractTangentBundle ) = x. primal
7
9
primal (z:: ZeroTangent ) = ZeroTangent ()
8
10
9
- first_partial (x:: TangentBundle{1} ) = getfield (getfield (x, :partials ), 1 )
10
- first_partial (x:: TaylorBundle{1} ) = getfield (getfield (x, :coeffs ), 1 )
11
- first_partial (x:: UniformBundle ) = getfield (x, :partial )
12
- first_partial (x:: CompositeBundle ) = map (first_partial, getfield (x, :tup ))
11
+ first_partial (x) = partial (x, 1 )
13
12
14
13
# TODO : Which version do we want in ChainRules?
15
14
function my_frule (args:: ATB{1} ...)
@@ -24,22 +23,22 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
24
23
(:: ∂☆{N})(:: ZeroBundle{N, typeof(my_frule)} , :: ZeroBundle{N, ZeroBundle{1, typeof(my_frule)}} , args:: ATB{N} ...) where {N} = ZeroBundle {N} (nothing )
25
24
26
25
shuffle_down (b:: UniformBundle{N, B, U} ) where {N, B, U} =
27
- UniformBundle {minus1(N), <:Any, U} (UniformBundle {1, B, U} (b. primal, b. partial ), b. partial )
26
+ UniformBundle {minus1(N), <:Any, U} (UniformBundle {1, B, U} (b. primal, b. tangent . val ), b. tangent . val )
28
27
29
- function shuffle_down (b:: TangentBundle {N, B} ) where {N, B}
28
+ function shuffle_down (b:: ExplicitTangentBundle {N, B} ) where {N, B}
30
29
# N.B: This depends on the special properties of the canonical tangent index order
31
- TangentBundle {N-1} (
32
- TangentBundle {1} (b. primal, (partial (b, 1 ),)),
30
+ ExplicitTangentBundle {N-1} (
31
+ ExplicitTangentBundle {1} (b. primal, (partial (b, 1 ),)),
33
32
ntuple (2 ^ (N- 1 )- 1 ) do i
34
- TangentBundle {1} (partial (b, 2 * i), (partial (b, 2 * i+ 1 ),))
33
+ ExplicitTangentBundle {1} (partial (b, 2 * i), (partial (b, 2 * i+ 1 ),))
35
34
end )
36
35
end
37
36
38
37
function shuffle_down (b:: TaylorBundle{N, B} ) where {N, B}
39
38
TaylorBundle {N-1} (
40
- TangentBundle {1} (b. primal, (b. coeffs[1 ],)),
39
+ ExplicitTangentBundle {1} (b. primal, (b. tangent . coeffs[1 ],)),
41
40
ntuple (N- 1 ) do i
42
- TangentBundle {1} (b. coeffs[i], (b. coeffs[i+ 1 ],))
41
+ ExplicitTangentBundle {1} (b. tangent . coeffs[i], (b. tangent . coeffs[i+ 1 ],))
43
42
end )
44
43
end
45
44
@@ -60,7 +59,7 @@ function shuffle_up(r::CompositeBundle{1})
60
59
if z₁ == z₂
61
60
return TaylorBundle {2} (z₀, (z₁, z₁₂))
62
61
else
63
- return TangentBundle {2} (z₀, (z₁, z₂, z₁₂))
62
+ return ExplicitTangentBundle {2} (z₀, (z₁, z₂, z₁₂))
64
63
end
65
64
end
66
65
@@ -86,14 +85,14 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
86
85
N+ 1 ))
87
86
else
88
87
return TangentBundle {N+1} (r. tup[1 ]. primal,
89
- (r. tup[1 ]. partials... , primal (b),
88
+ (r. tup[1 ]. tangent . partials... , primal (b),
90
89
ntuple (i-> partial (b,i), 2 ^ (N+ 1 )- 1 )... ))
91
90
end
92
91
end
93
92
94
93
function shuffle_up (r:: UniformBundle{N, B, U} ) where {N, B, U}
95
94
(a, b) = primal (r)
96
- if r. partial === b
95
+ if r. tangent . val === b
97
96
u = b
98
97
elseif b == NoTangent () && U === ZeroTangent
99
98
u = b
107
106
struct ∂☆internal{N}; end
108
107
struct ∂☆shuffle{N}; end
109
108
110
- shuffle_base (r) = TangentBundle {1} (r[1 ], (r[2 ],))
109
+ shuffle_base (r) = ExplicitTangentBundle {1} (r[1 ], (r[2 ],))
111
110
112
111
function (:: ∂☆internal{1 })(args:: AbstractTangentBundle{1} ...)
113
112
r = my_frule (args... )
@@ -119,7 +118,7 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
119
118
end
120
119
121
120
function ChainRulesCore. frule_via_ad (:: DiffractorRuleConfig , partials, args... )
122
- bundles = map ((p,a) -> TangentBundle {1} (a, (p,)), partials, args)
121
+ bundles = map ((p,a) -> ExplicitTangentBundle {1} (a, (p,)), partials, args)
123
122
result = ∂☆internal {1} ()(bundles... )
124
123
primal (result), first_partial (result)
125
124
end
@@ -142,14 +141,14 @@ end
142
141
# Special case rules for performance
143
142
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: TangentBundle{N} , s:: AbstractTangentBundle{N} ) where {N}
144
143
s = primal (s)
145
- TangentBundle {N} (getfield (primal (x), s),
146
- map (x-> lifted_getfield (x, s), x. partials))
144
+ ExplicitTangentBundle {N} (getfield (primal (x), s),
145
+ map (x-> lifted_getfield (x, s), x. tangent . partials))
147
146
end
148
147
149
148
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: TaylorBundle{N} , s:: AbstractTangentBundle{N} ) where {N}
150
149
s = primal (s)
151
150
TaylorBundle {N} (getfield (primal (x), s),
152
- map (y-> lifted_getfield (y, s), x. coeffs))
151
+ map (y-> lifted_getfield (y, s), x. tangent . coeffs))
153
152
end
154
153
155
154
@Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N} , s:: AbstractTangentBundle{N, Int} ) where {N}
@@ -162,16 +161,16 @@ end
162
161
163
162
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: ATB{N} , s:: ATB{N} , inbounds:: ATB{N} ) where {N}
164
163
s = primal (s)
165
- TangentBundle {N} (getfield (primal (x), s, primal (inbounds)),
166
- map (x-> lifted_getfield (x, s), x. partials))
164
+ ExplicitTangentBundle {N} (getfield (primal (x), s, primal (inbounds)),
165
+ map (x-> lifted_getfield (x, s), x. tangent . partials))
167
166
end
168
167
169
168
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: UniformBundle{N, <:Any, U} , s:: AbstractTangentBundle{N} ) where {N, U}
170
- UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s)), x. partial )
169
+ UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s)), x. tangent . val )
171
170
end
172
171
173
172
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: UniformBundle{N, <:Any, U} , s:: AbstractTangentBundle{N} , inbounds:: AbstractTangentBundle{N} ) where {N, U}
174
- UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s), primal (inbounds)), x. partial )
173
+ UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s), primal (inbounds)), x. tangent . val )
175
174
end
176
175
177
176
function (:: ∂☆{N})(f:: ATB{N, typeof(tuple)} , args:: AbstractTangentBundle{N} ...) where {N}
0 commit comments