1
1
using Cassette
2
2
using ChainRules
3
3
using ChainRulesCore
4
- import ChainRulesCore: Wirtinger, Zero
4
+ import ChainRulesCore: Zero
5
+
6
+ # TODO : remove the copy pasted code and add that package
7
+ # copyed from SpecializeVarargs.jl, written by @MasonProtter
8
+ using MacroTools: MacroTools, splitdef, combinedef, @capture
9
+
10
+ macro specialize_vararg (n:: Int , fdef:: Expr )
11
+ @assert n > 0
12
+
13
+ macros = Symbol[]
14
+ while fdef. head == :macrocall && length (fdef. args) == 3
15
+ push! (macros, fdef. args[1 ])
16
+ fdef = fdef. args[3 ]
17
+ end
18
+
19
+ d = splitdef (fdef)
20
+ args = d[:args ][end ]
21
+ @assert d[:args ][end ] isa Expr && d[:args ][end ]. head == Symbol (" ..." ) && d[:args ][end ]. args[] isa Symbol
22
+ args_symbol = d[:args ][end ]. args[]
23
+
24
+ fdefs = Expr (:block )
25
+
26
+ for i in 1 : n- 1
27
+ di = deepcopy (d)
28
+ pop! (di[:args ])
29
+ args = Tuple (gensym (" arg$j " ) for j in 1 : i)
30
+ Ts = Tuple (gensym (" T$j " ) for j in 1 : i)
31
+
32
+ args_with_Ts = ((arg, T) -> :($ arg :: $T )). (args, Ts)
33
+
34
+ di[:whereparams ] = (di[:whereparams ]. .. , Ts... )
35
+
36
+ push! (di[:args ], args_with_Ts... )
37
+ pushfirst! (di[:body ]. args, :($ args_symbol = $ (Expr (:tuple , args... ))))
38
+ cfdef = combinedef (di)
39
+ mcfdef = isempty (macros) ? cfdef : foldr ((m,f) -> Expr (:macrocall , m, nothing , f), macros, init= cfdef)
40
+ push! (fdefs. args, mcfdef)
41
+ end
42
+
43
+ di = deepcopy (d)
44
+ pop! (di[:args ])
45
+ args = tuple ((gensym () for j in 1 : n). .. , :($ (gensym (" args" )). .. ))
46
+ Ts = Tuple (gensym (" T$j " ) for j in 1 : n)
47
+
48
+ args_with_Ts = (((arg, T) -> :($ arg :: $T )). (args[1 : end - 1 ], Ts). .. , args[end ])
49
+
50
+ di[:whereparams ] = (di[:whereparams ]. .. , Ts... )
51
+
52
+ push! (di[:args ], args_with_Ts... )
53
+ pushfirst! (di[:body ]. args, :($ args_symbol = $ (Expr (:tuple , args... ))))
54
+
55
+ cfdef = combinedef (di)
56
+ mcfdef = isempty (macros) ? cfdef : foldr ((m,f) -> Expr (:macrocall , m, nothing , f), macros, init= cfdef)
57
+ push! (fdefs. args, mcfdef)
58
+
59
+ esc (fdefs)
60
+ end
5
61
6
62
using Cassette: overdub, Context, nametype, similarcontext
7
63
30
86
@inline _partials (:: Any , x) = Zero ()
31
87
@inline _partials (:: Tag{T} , d:: Dual{Tag{T}} ) where T = d. partials
32
88
33
- Wirtinger (primal, conjugate) = Wirtinger .(primal, conjugate)
34
-
35
89
@inline _values (S, xs) = map (x-> _value (S, x), xs)
36
90
@inline _partialss (S, xs) = map (x-> _partials (S, x), xs)
37
91
@@ -48,64 +102,54 @@ Wirtinger(primal, conjugate) = Wirtinger.(primal, conjugate)
48
102
end
49
103
50
104
# actually interesting:
51
-
52
105
@inline isinteresting (ctx:: TaggedCtx , f, a) = anydual (a)
53
106
@inline isinteresting (ctx:: TaggedCtx , f, a, b) = anydual (a, b)
54
107
@inline isinteresting (ctx:: TaggedCtx , f, a, b, c) = anydual (a, b, c)
55
108
@inline isinteresting (ctx:: TaggedCtx , f, a, b, c, d) = anydual (a, b, c, d)
56
- @inline isinteresting (ctx:: TaggedCtx , f, args... ) = false
57
- @inline isinteresting (ctx:: TaggedCtx , f:: typeof (Base. show), args... ) = false
109
+ @inline isinteresting (ctx:: TaggedCtx , f, args... ) = anydual (args... )
110
+ @inline isinteresting (ctx:: TaggedCtx , f:: Core.Builtin , args... ) = false
111
+ @inline isinteresting (ctx:: TaggedCtx , f:: Union {typeof (ForwardDiff2. find_dual),
112
+ typeof (ForwardDiff2. anydual)}, args... ) = false
58
113
59
- @inline function _frule_overdub2 (ctx:: TaggedCtx{T} , f, args... ) where T
114
+ @specialize_vararg 4 @ inline function _frule_overdub2 (ctx:: TaggedCtx{T} , f:: F , args... ) where {T,F}
60
115
# Here we can assume that one or more `args` is a Dual with tag
61
116
# of type T.
62
117
63
118
tag = Tag {T} ()
64
119
# unwrap only duals with the tag T.
65
120
vs = _values (tag, args)
66
121
122
+ # extract the partials only for the current tag
123
+ # so we can pass them to the pushforward
124
+ ps = _partialss (tag, args)
125
+
126
+ # default `dself` to `Zero()`
127
+ dself = Zero ()
128
+
67
129
# call frule to see if there is a rule for this call:
68
130
if ctx. metadata isa Tag
69
131
ctx1 = similarcontext (ctx, metadata= oldertag (ctx. metadata))
70
132
71
133
# we call frule with an older context because the Dual numbers may
72
134
# themselves contain Dual numbers that were created in an older context
73
- frule_result = overdub (ctx1, frule, f, vs... )
135
+ frule_result = overdub (ctx1, frule, f, vs... , dself, ps ... )
74
136
else
75
- frule_result = frule (f, vs... )
137
+ frule_result = frule (f, vs... , dself, ps ... )
76
138
end
77
139
78
140
if frule_result === nothing
79
141
# this means there is no frule
80
142
# We can't just do f(args...) here because `f` might be
81
143
# a closure which closes over a Dual number, hence we call
82
144
# recurse. Recurse overdubs the calls inside `f` and not `f` itself
83
-
84
145
return Cassette. overdub (ctx, f, args... )
85
146
else
86
147
# this means there exists an frule for this specific call.
87
148
# frule_result is then a tuple (val, pushforward) where val
88
149
# is the primal result. (Note: this may be Dual numbers but only
89
150
# with an older tag)
90
- val, pushforward = frule_result
91
-
92
- # extract the partials only for the current tag
93
- # so we can pass them to the pushforward
94
- ps = _partialss (tag, args)
95
-
96
- # Call the pushforward to get new partials
97
- # we call it with the older context because the partials
98
- # might themselves be Duals from older contexts
99
- if ctx. metadata isa Tag
100
- ctx1 = similarcontext (ctx, metadata= oldertag (ctx. metadata))
101
- ∂s = overdub (ctx1, pushforward, Zero (), ps... )
102
- else
103
- ∂s = pushforward (Zero (), ps... )
104
- end
151
+ val, ∂s = frule_result
105
152
106
- # Attach the new partials to the primal result
107
- # multi-output `f` such as result in the new partials being
108
- # a tuple, we handle both cases:
109
153
return if ∂s isa Tuple
110
154
map (val, ∂s) do v, ∂
111
155
Dual {Tag{T}} (v, ∂)
116
160
end
117
161
end
118
162
119
- @inline function alternative (ctx:: TaggedCtx{T} , f, args... ) where {T}
163
+ @specialize_vararg 4 @ inline function alternative (ctx:: TaggedCtx{T} , f:: F , args... ) where {T,F }
120
164
# This method only executes if `args` contains at least 1 Dual
121
165
# the question is what is its tag
122
166
161
205
162
206
163
207
# #### Inference Hacks
164
- # this makes `log` work by making throw_complex_domainerror inferable, but not really sure why
165
- @inline isinteresting (ctx:: TaggedCtx , f:: typeof (Core. throw), xs) = true
166
- # add `DualContext` here to avoid ambiguity
167
- @noinline alternative (ctx:: Union{DualContext,TaggedCtx} , f:: typeof (Core. throw), arg) = throw (arg)
168
-
169
- @inline isinteresting (ctx:: TaggedCtx , f:: typeof (Base. print_to_string), args... ) = true
170
- @noinline alternative (ctx:: Union{DualContext,TaggedCtx} , f:: typeof (Base. print_to_string), args... ) = f (args... )
208
+ @inline isinteresting (ctx:: TaggedCtx , f:: Union{typeof(Base.print_to_string),typeof(hash)} , args... ) = false
209
+ @inline Cassette. overdub (ctx:: TaggedCtx , f:: Union{typeof(Base.print_to_string),typeof(hash)} , args... ) = f (args... )
210
+ @inline Cassette. overdub (ctx:: TaggedCtx , f:: Core.Builtin , args... ) = f (args... )
0 commit comments