@@ -65,6 +65,13 @@ Cassette.@context DualContext
65
65
66
66
const TaggedCtx{T} = Context{nametype (DualContext),T}
67
67
68
+ untagtype (:: Type{<:Dual{Tag{T},V}} , :: Type{<:TaggedCtx{T}} ) where {T,V} = V
69
+
70
+ @inline @generated function _overdub (ctx:: TaggedCtx{T} , f, args... ) where T
71
+ F = Cassette. ReflectOn{Tuple{f, (untagtype (args[i], ctx) for i in 1 : nfields (args)). .. }}
72
+ :(overdub (ctx, $ F (), f, args... ))
73
+ end
74
+
68
75
function dualcontext ()
69
76
# Note that the `dualtag()` is not of the same type as that of the
70
77
# Duals constructed in this context, because it is called in the older context
136
143
137
144
# we call frule with an older context because the Dual numbers may
138
145
# themselves contain Dual numbers that were created in an older context
139
- frule_result = overdub (ctx1, frule, f, vs... , dself, ps... )
146
+ frule_result = _overdub (ctx1, frule, f, vs... , dself, ps... )
140
147
else
141
148
frule_result = frule (f, vs... , dself, ps... )
142
149
end
146
153
# We can't just do f(args...) here because `f` might be
147
154
# a closure which closes over a Dual number, hence we call
148
155
# recurse. Recurse overdubs the calls inside `f` and not `f` itself
149
- return Cassette . overdub (ctx, f, args... )
156
+ return _overdub (ctx, f, args... )
150
157
else
151
158
# this means there exists an frule for this specific call.
152
159
# frule_result is then a tuple (val, pushforward) where val
172
179
173
180
idx = find_dual (tag, args... )
174
181
if f === Dual
175
- return overdub (ctx, f, args... )
182
+ return _overdub (ctx, f, args... )
176
183
elseif idx === 0
177
184
# This is the base case for the recursion in this function which
178
185
# tries to do the alternative with successively older contexts
183
190
# none of the arguments have the same tag as the context
184
191
# try with the parent context
185
192
ctx1 = similarcontext (ctx, metadata= oldertag (ctx. metadata))
186
- return overdub (ctx1, f, args... )
193
+ return _overdub (ctx1, f, args... )
187
194
else
188
195
# call ChainRules.frule to execute `f` and
189
196
# get a function that computes the partials
193
200
194
201
function dualrun (f, args... )
195
202
ctx = dualcontext ()
196
- return overdub (ctx, f, args... )
203
+ return _overdub (ctx, f, args... )
197
204
end
198
205
199
206
const BINARY_PREDICATES = Symbol[:isequal , :isless , :< , :> , :(== ), :(!= ), :(<= ), :(>= )]
0 commit comments