1
1
# These are some macros (and supporting functions) to make it easier to define rules.
2
2
3
- """
4
- propagator_name(f, propname)
5
-
6
- Determines a reasonable name for the propagator function.
7
- The name doesn't really matter too much as it is a local function to be returned
8
- by `frule` or `rrule`, but a good name make debugging easier.
9
- `f` should be some form of AST representation of the actual function,
10
- `propname` should be either `:pullback` or `:pushforward`
11
-
12
- This is able to deal with fairly complex expressions for `f`:
13
-
14
- julia> propagator_name(:bar, :pushforward)
15
- :bar_pushforward
16
-
17
- julia> propagator_name(esc(:(Base.Random.foo)), :pullback)
18
- :foo_pullback
19
- """
20
- propagator_name (f:: Expr , propname:: Symbol ) = propagator_name (f. args[end ], propname)
21
- propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
22
- propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
23
-
24
-
25
- """
26
- propagation_expr(𝒟, Δs, ∂s)
27
-
28
- Returns the expression for the propagation of
29
- the input gradient `Δs` though the partials `∂s`.
30
-
31
- 𝒟 is an expression that when evaluated returns the type-of the input domain.
32
- For example if the derivative is being taken at the point `1` it returns `Int`.
33
- if it is taken at `1+1im` it returns `Complex{Int}`.
34
- At present it is ignored for non-Wirtinger derivatives.
35
- """
36
- function propagation_expr (𝒟, Δs, ∂s)
37
- wirtinger_indices = findall (∂s) do ex
38
- Meta. isexpr (ex, :call ) && ex. args[1 ] === :Wirtinger
39
- end
40
- ∂s = map (esc, ∂s)
41
- if isempty (wirtinger_indices)
42
- return standard_propagation_expr (Δs, ∂s)
43
- else
44
- return wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
45
- end
46
- end
47
-
48
- function standard_propagation_expr (Δs, ∂s)
49
- # This is basically Δs ⋅ ∂s
50
-
51
- # Notice: the thunking of `∂s[i] (potentially) saves us some computation
52
- # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
53
- # as the pullback is evaluated
54
- ∂_mul_Δs = [:(@thunk ($ (∂s[i])) * $ (Δs[i])) for i in 1 : length (∂s)]
55
- return :(+ ($ (∂_mul_Δs... )))
56
- end
57
-
58
- function wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
59
- ∂_mul_Δs_primal = Any[]
60
- ∂_mul_Δs_conjugate = Any[]
61
- ∂_wirtinger_defs = Any[]
62
- for i in 1 : length (∂s)
63
- if i in wirtinger_indices
64
- Δi = Δs[i]
65
- ∂i = Symbol (string (:∂ , i))
66
- push! (∂_wirtinger_defs, :($ ∂i = $ (∂s[i])))
67
- ∂f∂i_mul_Δ = :(wirtinger_primal ($ ∂i) * wirtinger_primal ($ Δi))
68
- ∂f∂ī_mul_Δ̄ = :(conj (wirtinger_conjugate ($ ∂i)) * wirtinger_conjugate ($ Δi))
69
- ∂f̄∂i_mul_Δ = :(wirtinger_conjugate ($ ∂i) * wirtinger_primal ($ Δi))
70
- ∂f̄∂ī_mul_Δ̄ = :(conj (wirtinger_primal ($ ∂i)) * wirtinger_conjugate ($ Δi))
71
- push! (∂_mul_Δs_primal, :($ ∂f∂i_mul_Δ + $ ∂f∂ī_mul_Δ̄))
72
- push! (∂_mul_Δs_conjugate, :($ ∂f̄∂i_mul_Δ + $ ∂f̄∂ī_mul_Δ̄))
73
- else
74
- ∂_mul_Δ = :(@thunk ($ (∂s[i])) * $ (Δs[i]))
75
- push! (∂_mul_Δs_primal, ∂_mul_Δ)
76
- push! (∂_mul_Δs_conjugate, ∂_mul_Δ)
77
- end
78
- end
79
- primal_sum = :(+ ($ (∂_mul_Δs_primal... )))
80
- conjugate_sum = :(+ ($ (∂_mul_Δs_conjugate... )))
81
- return quote # This will be a block, so will have value equal to last statement
82
- $ (∂_wirtinger_defs... )
83
- w = Wirtinger ($ primal_sum, $ conjugate_sum)
84
- refine_differential ($ 𝒟, w)
85
- end
86
- end
87
-
88
3
"""
89
4
@scalar_rule(f(x₁, x₂, ...),
90
5
@setup(statement₁, statement₂, ...),
@@ -151,9 +66,49 @@ is equivalent to:
151
66
152
67
For examples, see ChainRulesCore's `rules` directory.
153
68
154
- See also: [`frule`](@ref), [`rrule`](@ref).
69
+ See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
155
70
"""
156
71
macro scalar_rule(call, maybe_setup, partials...)
    # Canonicalise the raw macro arguments; everything returned here is already escaped.
    call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input(
        call, maybe_setup, partials
    )
    # After normalisation the first argument of the call is the (escaped) function.
    f = call.args[1]

    # Expression that, once evaluated, yields the type of the input domain
    # (the promoted type of all call arguments). It is spliced into the generated
    # rules wherever needed; repeated copies are expected to optimise away. If they
    # do not, the definition may have to move into the body of the `frule`/`rrule`.
    𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))

    # Build the forward rule first, then the reverse rule, with identical inputs.
    rule_defs = map((scalar_frule_expr, scalar_rrule_expr)) do build
        build(𝒟, f, call, setup_stmts, inputs, partials)
    end

    # The expression that replaces the macro call: a guard against closures/functors
    # (anything that is not a type but has fields), followed by both rule definitions.
    return quote
        if !($f isa Type) && fieldcount(typeof($f)) > 0
            throw(ArgumentError(
                "@scalar_rule cannot be used on closures/functors (such as $($f))"
            ))
        end

        $(rule_defs...)
    end
end
99
+
100
+
101
+ """
102
+ _normalize_scalarrules_macro_input(call, maybe_setup, partials)
103
+
104
+ returns (in order) the correctly escaped:
105
+ - `call` without any type constraints
106
+ - `setup_stmts`: the content of `@setup` or `nothing` if that is not provided,
107
+ - `inputs`: with all args having the constraints removed from call, or
108
+ defaulting to `Number`
109
+ - `partials`: which are all `Expr{:tuple,...}`
110
+ """
111
+ function _normalize_scalarrules_macro_input (call, maybe_setup, partials)
157
112
# ###########################################################################
158
113
# Setup: normalizing input form etc
159
114
@@ -164,12 +119,12 @@ macro scalar_rule(call, maybe_setup, partials...)
164
119
partials = (maybe_setup, partials... )
165
120
end
166
121
@assert Meta. isexpr (call, :call )
167
- f = esc (call. args[1 ])
168
122
169
123
# Annotate all arguments in the signature as scalars
170
124
inputs = map (call. args[2 : end ]) do arg
171
125
esc (Meta. isexpr (arg, :(:: )) ? arg : Expr (:(:: ), arg, :Number ))
172
126
end
127
+
173
128
# Remove annotations and escape names for the call
174
129
for (i, arg) in enumerate (call. args)
175
130
if Meta. isexpr (arg, :(:: ))
@@ -189,78 +144,153 @@ macro scalar_rule(call, maybe_setup, partials...)
189
144
end
190
145
end
191
146
192
- # ###########################################################################
193
- # Main body: defining the results of the frule/rrule
194
-
195
- # An expression that when evaluated will return the type of the input domain.
196
- # Multiple repetitions of this expression should optimize out. But if it does not then
197
- # may need to move its definition into the body of the `rrule`/`frule`
198
- 𝒟 = :(typeof (first (promote ($ (call. args[2 : end ]. .. )))))
147
+ return call, setup_stmts, inputs, partials
148
+ end
199
149
150
"""
    scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)

Build the expression defining `ChainRulesCore.frule` for `f`.

The generated method evaluates `call` into `Ω`, runs `setup_stmts`, and returns `Ω`
together with a pushforward function. The pushforward takes one ignored argument
(the derivative w.r.t. function internals) plus one `Δ` per entry of `inputs`, and
returns one propagated expression per entry of `partials` (a tuple only when there
is more than one output).
"""
function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
    output_count = length(partials)

    # One incoming tangent per function input: Δ1, Δ2, ...
    Δs = [Symbol(:Δ, k) for k in 1:length(inputs)]

    # One propagated expression per output, each combining all input tangents.
    per_output = map(1:output_count) do k
        propagation_expr(𝒟, Δs, partials[k].args)
    end

    # Forward mode only returns a tuple when the primal output is itself a tuple.
    pushforward_body = output_count > 1 ? Expr(:tuple, per_output...) : per_output[1]

    pushforward = quote
        # `_` is the input derivative w.r.t. function internals. Closures/functors
        # are not allowed with @scalar_rule, so it is always ignored.
        function $(propagator_name(f, :pushforward))(_, $(Δs...))
            $pushforward_body
        end
    end

    return quote
        function ChainRulesCore.frule(::typeof($f), $(inputs...))
            $(esc(:Ω)) = $call
            $(setup_stmts...)
            return $(esc(:Ω)), $pushforward
        end
    end
end
184
+
185
"""
    scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)

Build the expression defining `ChainRulesCore.rrule` for `f`.

The generated method evaluates `call` into `Ω`, runs `setup_stmts`, and returns `Ω`
together with a pullback function. The pullback takes one `Δ` per entry of
`partials` (i.e. per output) and returns `NO_FIELDS` followed by one propagated
expression per entry of `inputs`.
"""
function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
    # One incoming cotangent per function output: Δ1, Δ2, ...
    Δs = [Symbol(:Δ, k) for k in 1:length(partials)]

    # For each input, gather its partial from every output and propagate.
    per_input = map(1:length(inputs)) do k
        partials_wrt_input = [p.args[k] for p in partials]
        propagation_expr(𝒟, Δs, partials_wrt_input)
    end

    pullback = quote
        function $(propagator_name(f, :pullback))($(Δs...))
            return (NO_FIELDS, $(per_input...))
        end
    end

    return quote
        function ChainRulesCore.rrule(::typeof($f), $(inputs...))
            $(esc(:Ω)) = $call
            $(setup_stmts...)
            return $(esc(:Ω)), $pullback
        end
    end
end
213
+
214
"""
    propagation_expr(𝒟, Δs, ∂s)

Return the expression that propagates the input gradients `Δs` through the
partials `∂s`.

`𝒟` is an expression that, when evaluated, returns the type of the input domain:
`Int` if the derivative is taken at the point `1`, `Complex{Int}` if it is taken at
`1+1im`. It is only forwarded to the Wirtinger code path; when none of the partials
is a `Wirtinger(...)` call it is ignored.
"""
function propagation_expr(𝒟, Δs, ∂s)
    # Detection has to happen on the raw partials, before they are wrapped by `esc`.
    is_wirtinger_call(ex) = Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
    wirtinger_indices = findall(is_wirtinger_call, ∂s)

    escaped_partials = map(esc, ∂s)
    if isempty(wirtinger_indices)
        return standard_propagation_expr(Δs, escaped_partials)
    end
    return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, escaped_partials)
end
236
+
237
"""
    standard_propagation_expr(Δs, ∂s)

Return the expression `+(@thunk(∂s[1]) * Δs[1], @thunk(∂s[2]) * Δs[2], ...)`,
i.e. essentially `Δs ⋅ ∂s`, with one term per entry of `∂s`.
"""
function standard_propagation_expr(Δs, ∂s)
    # Each partial is wrapped in `@thunk` so that its evaluation can (potentially)
    # be skipped when the matching `Δs[i]` is an `AbstractDifferential`; otherwise
    # it is computed as soon as the propagator is evaluated.
    terms = map(1:length(∂s)) do i
        :(@thunk($(∂s[i])) * $(Δs[i]))
    end
    return Expr(:call, :+, terms...)
end
246
+
247
"""
    wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)

Return a block expression that propagates `Δs` through `∂s` when at least one
partial is a Wirtinger derivative.

The block first binds every partial listed in `wirtinger_indices` to a local name
`∂i`, then builds `w = Wirtinger(primal_sum, conjugate_sum)` and finally evaluates
`refine_differential(𝒟, w)`, which is the value of the block. Partials whose index
is not in `wirtinger_indices` contribute the same `@thunk(∂) * Δ` term to both sums.
"""
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
    primal_terms = Any[]
    conjugate_terms = Any[]
    partial_bindings = Any[]

    for (i, partial) in enumerate(∂s)
        Δ = Δs[i]
        if i in wirtinger_indices
            # Bind the partial once so it can be referenced four times below.
            name = Symbol(:∂, i)
            push!(partial_bindings, :($name = $partial))

            primal_from_primal = :(wirtinger_primal($name) * wirtinger_primal($Δ))
            primal_from_conj = :(conj(wirtinger_conjugate($name)) * wirtinger_conjugate($Δ))
            conj_from_primal = :(wirtinger_conjugate($name) * wirtinger_primal($Δ))
            conj_from_conj = :(conj(wirtinger_primal($name)) * wirtinger_conjugate($Δ))

            push!(primal_terms, :($primal_from_primal + $primal_from_conj))
            push!(conjugate_terms, :($conj_from_primal + $conj_from_conj))
        else
            # Non-Wirtinger partial: identical thunked product in both sums.
            term = :(@thunk($partial) * $Δ)
            push!(primal_terms, term)
            push!(conjugate_terms, term)
        end
    end

    primal_sum = Expr(:call, :+, primal_terms...)
    conjugate_sum = Expr(:call, :+, conjugate_terms...)

    # A block evaluates to its last statement, i.e. the refined differential.
    return quote
        $(partial_bindings...)
        w = Wirtinger($primal_sum, $conjugate_sum)
        refine_differential($𝒟, w)
    end
end
276
+
277
"""
    propagator_name(f, propname)

Pick a readable name for the propagator function.

The name is not critical, since the propagator is a local function returned by
`frule` or `rrule`, but a good name makes debugging easier.
`f` should be some form of AST representation of the actual function and
`propname` should be either `:pullback` or `:pushforward`.

Nested expressions for `f` are unwrapped by repeatedly taking their last argument:

    julia> propagator_name(:bar, :pushforward)
    :bar_pushforward

    julia> propagator_name(esc(:(Base.Random.foo)), :pullback)
    :foo_pullback
"""
function propagator_name(f::Expr, propname::Symbol)
    # The innermost name sits in the last argument (e.g. of an `escape` or `.` node).
    innermost = f.args[end]
    return propagator_name(innermost, propname)
end

function propagator_name(fname::Symbol, propname::Symbol)
    # Base case: join the function name and propagator kind with an underscore.
    return Symbol(string(fname, :_, propname))
end

function propagator_name(fname::QuoteNode, propname::Symbol)
    # Qualified names end in a QuoteNode holding the actual symbol.
    return propagator_name(fname.value, propname)
end
0 commit comments