1
1
# These are some macros (and supporting functions) to make it easier to define rules.
2
2
3
"""
    propagator_name(f, propname)

Build a readable name for the propagator function returned by `frule` or `rrule`.

The propagator is a local function, so its exact name is not important for
correctness; a descriptive name simply makes debugging easier.

# Arguments
- `f`: some form of AST representation of the actual function: a `Symbol`, a
  `QuoteNode`, or an `Expr` whose last argument (recursively) is one of those.
- `propname`: should be either `:pullback` or `:pushforward`.

# Examples

    julia> propagator_name(:bar, :pushforward)
    :bar_pushforward

    julia> propagator_name(esc(:(Base.Random.foo)), :pullback)
    :foo_pullback
"""
function propagator_name(f::Expr, propname::Symbol)
    # The trailing argument carries the innermost name: the wrapped expression of an
    # `esc(...)`, or the `foo` part of a dotted path such as `Base.Random.foo`.
    innermost = last(f.args)
    return propagator_name(innermost, propname)
end

function propagator_name(fname::Symbol, propname::Symbol)
    return Symbol(fname, :_, propname)
end

function propagator_name(fname::QuoteNode, propname::Symbol)
    # Dotted paths store their final component as a `QuoteNode`; unwrap it.
    return propagator_name(fname.value, propname)
end
22
+
23
+
24
"""
    propagation_expr(𝒟, Δs, ∂s)

Return the expression for the propagation of the input gradient `Δs` through the
partials `∂s`.

`𝒟` is an expression that, when evaluated, returns the type of the input domain.
For example, if the derivative is being taken at the point `1` it returns `Int`;
if it is taken at `1+1im` it returns `Complex{Int}`.
At present it is ignored for non-Wirtinger derivatives.

A partial is treated as a Wirtinger derivative when it is a call expression whose
callee is exactly the symbol `Wirtinger`.
"""
function propagation_expr(𝒟, Δs, ∂s)
    # Detection has to run on the raw partial expressions, before they are escaped.
    is_wirtinger_call(ex) = Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
    wirtinger_indices = findall(is_wirtinger_call, ∂s)

    escaped_partials = map(esc, ∂s)
    if isempty(wirtinger_indices)
        return standard_propagation_expr(Δs, escaped_partials)
    end
    return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, escaped_partials)
end
46
+
47
# Build the expression `∂s[1] * Δs[1] + ∂s[2] * Δs[2] + ...`, i.e. basically Δs ⋅ ∂s,
# with one term per element of `∂s`.
function standard_propagation_expr(Δs, ∂s)
    # Notice: the thunking of `∂s[i]` (potentially) saves us some computation
    # if `Δs[i]` is an `AbstractDifferential`; otherwise it is computed as soon
    # as the pullback is evaluated.
    products = map(eachindex(∂s)) do i
        partial, seed = ∂s[i], Δs[i]
        :(@thunk($partial) * $seed)
    end
    return Expr(:call, :+, products...)
end
56
+
57
"""
    wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)

Build the propagation expression for the case where at least one partial in `∂s`
is a Wirtinger derivative.

The returned expression is a block that:
1. binds every partial listed in `wirtinger_indices` to a local variable `∂i`
   (with `i` its position in `∂s`),
2. assembles `w = Wirtinger(primal_sum, conjugate_sum)`, and
3. evaluates to `differential(𝒟, w)`.

Partials whose index is not in `wirtinger_indices` contribute the same thunked
product `@thunk(∂s[i]) * Δs[i]` to both the primal and the conjugate sum.
"""
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
    primal_terms = Any[]
    conjugate_terms = Any[]
    wirtinger_bindings = Any[]

    for i in eachindex(∂s)
        if !(i in wirtinger_indices)
            # Ordinary partial: identical contribution to both components.
            plain_term = :(@thunk($(∂s[i])) * $(Δs[i]))
            push!(primal_terms, plain_term)
            push!(conjugate_terms, plain_term)
            continue
        end

        seed = Δs[i]
        # The partial is referenced four times below, so bind it to a name first.
        bound = Symbol(:∂, i)
        push!(wirtinger_bindings, Expr(:(=), bound, ∂s[i]))

        primal_direct = :(wirtinger_primal($bound) * wirtinger_primal($seed))
        primal_cross = :(conj(wirtinger_conjugate($bound)) * wirtinger_conjugate($seed))
        conjugate_direct = :(wirtinger_conjugate($bound) * wirtinger_primal($seed))
        conjugate_cross = :(conj(wirtinger_primal($bound)) * wirtinger_conjugate($seed))

        push!(primal_terms, Expr(:call, :+, primal_direct, primal_cross))
        push!(conjugate_terms, Expr(:call, :+, conjugate_direct, conjugate_cross))
    end

    primal_sum = Expr(:call, :+, primal_terms...)
    conjugate_sum = Expr(:call, :+, conjugate_terms...)

    # This is a block, so its value is that of the last statement.
    return quote
        $(wirtinger_bindings...)
        w = Wirtinger($primal_sum, $conjugate_sum)
        differential($𝒟, w)
    end
end
86
+
3
87
"""
4
88
@scalar_rule(f(x₁, x₂, ...),
5
89
@setup(statement₁, statement₂, ...),
@@ -42,7 +126,7 @@ e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x
42
126
At present this does not support defining for closures/functors.
43
127
Thus in reverse-mode, the first returned partial,
44
128
representing the derivative with respect to the function itself, is always `NO_FIELDS`.
45
- And in forwards -mode, the first input to the returned propergator is always ignored.
129
+ And in forward -mode, the first input to the returned propagator is always ignored.
46
130
47
131
The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
48
132
allows the primal result to be conveniently referenced (as `Ω`) within the
@@ -69,6 +153,9 @@ For examples, see ChainRulesCore' `rules` directory.
69
153
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
70
154
"""
71
155
macro scalar_rule (call, maybe_setup, partials... )
156
+ # ###########################################################################
157
+ # Setup: normalizing input form etc
158
+
72
159
if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
73
160
setup_stmts = map (esc, maybe_setup. args[3 : end ])
74
161
else
@@ -77,6 +164,7 @@ macro scalar_rule(call, maybe_setup, partials...)
77
164
end
78
165
@assert Meta. isexpr (call, :call )
79
166
f = esc (call. args[1 ])
167
+
80
168
# Annotate all arguments in the signature as scalars
81
169
inputs = map (call. args[2 : end ]) do arg
82
170
esc (Meta. isexpr (arg, :(:: )) ? arg : Expr (:(:: ), arg, :Number ))
@@ -90,6 +178,7 @@ macro scalar_rule(call, maybe_setup, partials...)
90
178
end
91
179
end
92
180
181
+ # For consistency in code that follows we make all partials tuple expressions
93
182
partials = map (partials) do partial
94
183
if Meta. isexpr (partial, :tuple )
95
184
partial
@@ -98,59 +187,58 @@ macro scalar_rule(call, maybe_setup, partials...)
98
187
Expr (:tuple , partial)
99
188
end
100
189
end
101
- @show partials
102
-
103
- # ###########################################################
104
- # Make pullback
105
- # (TODO : move to own function)
106
- # TODO : Wirtinger
107
-
108
- Δs = [Symbol (string (:Δ , i)) for i in 1 : length (partials)]
109
- pullback_returns = map (eachindex (inputs)) do input_i
110
- ∂s = [partials. args[input_i] for partial in partials]
111
- ∂s = map (esc, ∂s)
112
-
113
- # Notice: the thunking of `∂s[i] (potentially) saves us some computation
114
- # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
115
- # as the pullback is evaluated
116
- ∂_mul_Δs = [:(@thunk ($ (∂s[i])) * $ (Δs[i])) for i in 1 : length (∂s)]
117
- :(+ ($ (∂_mul_Δs... )))
118
- else
119
190
120
- pullback = quote
121
- function $ (Symbol (nameof (f), :_pullback ))($ (Δs... ))
122
- return (ChainRulesCore. NO_FIELDS, $ (pullback_returns... ))
191
+ # ###########################################################################
192
+ # Main body: defining the results of the frule/rrule
193
+
194
+ # An expression that when evaluated will return the type of the input domain.
195
+ # Multiple repetitions of this expression should optimize out. But if it does not then
196
+ # may need to move its definition into the body of the `rrule`/`frule`
197
+ 𝒟 = :(typeof (first (promote ($ (call. args[2 : end ]. .. )))))
198
+
199
+ n_outputs = length (partials)
200
+ n_inputs = length (inputs)
201
+
202
+ pushforward = let
203
+ # Δs is the input to the propagator rule
204
+ # because this is push-forward there is one per input to the function
205
+ Δs = [Symbol (string (:Δ , i)) for i in 1 : n_inputs]
206
+ pushforward_returns = map (1 : n_outputs) do output_i
207
+ ∂s = partials[output_i]. args
208
+ propagation_expr (𝒟, Δs, ∂s)
123
209
end
124
- end
125
210
126
- # #######################################
127
- quote
128
- function ChainRulesCore . rrule ( :: typeof ( $ f), $ (inputs ... ))
129
- $ ( esc ( :Ω )) = $ call
130
- $ (setup_stmts ... )
131
- return $ ( esc ( :Ω )), $ esc (pullback)
211
+ quote
212
+ # _ is the input derivative w.r.t. function internals. since we do not
213
+ # allow closures/functors with @scalar_rule, it is always ignored
214
+ function $ ( propagator_name (f, :pushforward ))(_, $ (Δs ... ))
215
+ return $ ( Expr ( :tuple , pushforward_returns ... ) )
216
+ end
132
217
end
133
218
end
134
- end
135
- #= =
136
- if !all(Meta.isexpr(partial, :tuple) for partial in partials)
137
- input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input
138
- forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials]
139
- reverse_rules = map(1:length(inputs) do i
140
- reverse_partials = [partial.args[i] for partial in partials]
141
- push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...))
219
+
220
+ pullback = let
221
+ # Δs is the input to the propagator rule
222
+ # because this is a pull-back there is one per output of function
223
+ Δs = [Symbol (string (:Δ , i)) for i in 1 : n_outputs]
224
+
225
+ # 1 partial derivative per input
226
+ pullback_returns = map (1 : n_inputs) do input_i
227
+ ∂s = [partial. args[input_i] for partial in partials]
228
+ propagation_expr (𝒟, Δs, ∂s)
229
+ end
230
+
231
+ quote
232
+ function $ (propagator_name (f, :pullback ))($ (Δs... ))
233
+ return (NO_FIELDS, $ (pullback_returns... ))
234
+ end
142
235
end
143
- else
144
- @assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials)
145
- forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials]
146
- reverse_rules = Any[rule_from_partials(inputs[1], partials...)]
147
236
end
148
237
149
- # First pseudo-partial is derivative WRT function itself. Since this macro does not
150
- # support closures, it is just the empty NamedTuple
151
- forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...)
152
- reverse_rules = Expr(:tuple, NO_FIELDS, reverse_rules...)
153
- return quote
238
+ # ###########################################################################
239
+ # Final return: building the expression to insert in the place of this macro
240
+
241
+ code = quote
154
242
if fieldcount (typeof ($ f)) > 0
155
243
throw (ArgumentError (
156
244
" @scalar_rule cannot be used on closures/functors (such as $f )"
@@ -160,57 +248,13 @@ end
160
248
function ChainRulesCore. frule (:: typeof ($ f), $ (inputs... ))
161
249
$ (esc (:Ω )) = $ call
162
250
$ (setup_stmts... )
163
- return $(esc(:Ω)), $forward_rules
251
+ return $ (esc (:Ω )), $ pushforward
164
252
end
253
+
165
254
function ChainRulesCore. rrule (:: typeof ($ f), $ (inputs... ))
166
255
$ (esc (:Ω )) = $ call
167
256
$ (setup_stmts... )
168
- return $(esc(:Ω)), $reverse_rules
169
- end
170
- end
171
- end
172
- ==#
173
-
174
- @macroexpand (@scalar_rule (one (x), Zero ()))
175
-
176
-
177
-
178
- #= =
179
- function rule_from_partials(input_arg, ∂s...)
180
- wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
181
- ∂s = map(esc, ∂s)
182
- Δs = [Symbol(string(:Δ, i)) for i in 1:length(∂s)]
183
- Δs_tuple = Expr(:tuple, Δs...)
184
- if isempty(wirtinger_indices)
185
- ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
186
- return :(Rule($Δs_tuple -> +($(∂_mul_Δs...))))
187
- else
188
- ∂_mul_Δs_primal = Any[]
189
- ∂_mul_Δs_conjugate = Any[]
190
- ∂_wirtinger_defs = Any[]
191
- for i in 1:length(∂s)
192
- if i in wirtinger_indices
193
- Δi = Δs[i]
194
- ∂i = Symbol(string(:∂, i))
195
- push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
196
- ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
197
- ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
198
- ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
199
- ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
200
- push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
201
- push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
202
- else
203
- ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
204
- push!(∂_mul_Δs_primal, ∂_mul_Δ)
205
- push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
206
- end
207
- end
208
- primal_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_primal...))))
209
- conjugate_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_conjugate...))))
210
- return quote
211
- $(∂_wirtinger_defs...)
212
- AbstractRule(typeof($input_arg), $primal_rule, $conjugate_rule)
257
+ return $ (esc (:Ω )), $ pullback
213
258
end
214
259
end
215
260
end
216
- ==#
0 commit comments