Skip to content

Commit cb5fb37

Browse files
shashiYingboMa"Shashi Gowda"
committed
overdub with ReflectOn on the correct methods. makes f(x::Float64) work.
Co-authored-by: "Yingbo Ma" <mayingbo5@gmail.com> Co-authored-by: "Shashi Gowda" <gowda@mit.edu>
1 parent d464c4d commit cb5fb37

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/dual_context.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ Cassette.@context DualContext
6565

6666
const TaggedCtx{T} = Context{nametype(DualContext),T}
6767

68+
untagtype(::Type{<:Dual{Tag{T},V}}, ::Type{<:TaggedCtx{T}}) where {T,V} = V
69+
70+
@inline @generated function _overdub(ctx::TaggedCtx{T}, f, args...) where T
71+
F = Cassette.ReflectOn{Tuple{f, (untagtype(args[i], ctx) for i in 1:nfields(args))...}}
72+
:(overdub(ctx, $F(), f, args...))
73+
end
74+
6875
function dualcontext()
6976
# Note that the `dualtag()` is not of the same type as that of the
7077
# Duals constructed in this context, because it is called in the older context
@@ -136,7 +143,7 @@ end
136143

137144
# we call frule with an older context because the Dual numbers may
138145
# themselves contain Dual numbers that were created in an older context
139-
frule_result = overdub(ctx1, frule, f, vs..., dself, ps...)
146+
frule_result = _overdub(ctx1, frule, f, vs..., dself, ps...)
140147
else
141148
frule_result = frule(f, vs..., dself, ps...)
142149
end
@@ -146,7 +153,7 @@ end
146153
# We can't just do f(args...) here because `f` might be
147154
# a closure which closes over a Dual number, hence we call
148155
# recurse. Recurse overdubs the calls inside `f` and not `f` itself
149-
return Cassette.overdub(ctx, f, args...)
156+
return _overdub(ctx, f, args...)
150157
else
151158
# this means there exists an frule for this specific call.
152159
# frule_result is then a tuple (val, pushforward) where val
@@ -172,7 +179,7 @@ end
172179

173180
idx = find_dual(tag, args...)
174181
if f === Dual
175-
return overdub(ctx, f, args...)
182+
return _overdub(ctx, f, args...)
176183
elseif idx === 0
177184
# This is the base case for the recursion in this function which
178185
# tries to do the alternative with successively older contexts
@@ -183,7 +190,7 @@ end
183190
# none of the arguments have the same tag as the context
184191
# try with the parent context
185192
ctx1 = similarcontext(ctx, metadata=oldertag(ctx.metadata))
186-
return overdub(ctx1, f, args...)
193+
return _overdub(ctx1, f, args...)
187194
else
188195
# call ChainRules.frule to execute `f` and
189196
# get a function that computes the partials
@@ -193,7 +200,7 @@ end
193200

194201
function dualrun(f, args...)
195202
ctx = dualcontext()
196-
return overdub(ctx, f, args...)
203+
return _overdub(ctx, f, args...)
197204
end
198205

199206
const BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]

0 commit comments

Comments
 (0)