Skip to content

Commit e51ff80

Browse files
committed
split up scalar_rule into a bunch of functions
1 parent 1869be1 commit e51ff80

File tree

1 file changed

+170
-140
lines changed

1 file changed

+170
-140
lines changed

src/rule_definition_tools.jl

Lines changed: 170 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,5 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
22

3-
"""
4-
propagator_name(f, propname)
5-
6-
Determines a reasonable name for the propagator function.
7-
The name doesn't really matter too much as it is a local function to be returned
8-
by `frule` or `rrule`, but a good name make debugging easier.
9-
`f` should be some form of AST representation of the actual function,
10-
`propname` should be either `:pullback` or `:pushforward`
11-
12-
This is able to deal with fairly complex expressions for `f`:
13-
14-
julia> propagator_name(:bar, :pushforward)
15-
:bar_pushforward
16-
17-
julia> propagator_name(esc(:(Base.Random.foo)), :pullback)
18-
:foo_pullback
19-
"""
20-
propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname)
21-
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
22-
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
23-
24-
25-
"""
26-
propagation_expr(𝒟, Δs, ∂s)
27-
28-
Returns the expression for the propagation of
29-
the input gradient `Δs` though the partials `∂s`.
30-
31-
𝒟 is an expression that when evaluated returns the type-of the input domain.
32-
For example if the derivative is being taken at the point `1` it returns `Int`.
33-
if it is taken at `1+1im` it returns `Complex{Int}`.
34-
At present it is ignored for non-Wirtinger derivatives.
35-
"""
36-
function propagation_expr(𝒟, Δs, ∂s)
37-
wirtinger_indices = findall(∂s) do ex
38-
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
39-
end
40-
∂s = map(esc, ∂s)
41-
if isempty(wirtinger_indices)
42-
return standard_propagation_expr(Δs, ∂s)
43-
else
44-
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
45-
end
46-
end
47-
48-
function standard_propagation_expr(Δs, ∂s)
49-
# This is basically Δs ⋅ ∂s
50-
51-
# Notice: the thunking of `∂s[i] (potentially) saves us some computation
52-
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
53-
# as the pullback is evaluated
54-
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
55-
return :(+($(∂_mul_Δs...)))
56-
end
57-
58-
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
59-
∂_mul_Δs_primal = Any[]
60-
∂_mul_Δs_conjugate = Any[]
61-
∂_wirtinger_defs = Any[]
62-
for i in 1:length(∂s)
63-
if i in wirtinger_indices
64-
Δi = Δs[i]
65-
∂i = Symbol(string(:∂, i))
66-
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
67-
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
68-
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
69-
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
70-
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
71-
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
72-
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
73-
else
74-
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
75-
push!(∂_mul_Δs_primal, ∂_mul_Δ)
76-
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
77-
end
78-
end
79-
primal_sum = :(+($(∂_mul_Δs_primal...)))
80-
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
81-
return quote # This will be a block, so will have value equal to last statement
82-
$(∂_wirtinger_defs...)
83-
w = Wirtinger($primal_sum, $conjugate_sum)
84-
refine_differential($𝒟, w)
85-
end
86-
end
87-
883
"""
894
@scalar_rule(f(x₁, x₂, ...),
905
@setup(statement₁, statement₂, ...),
@@ -151,9 +66,49 @@ is equivalent to:
15166
15267
For examples, see ChainRulesCore' `rules` directory.
15368
154-
See also: [`frule`](@ref), [`rrule`](@ref).
69+
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
15570
"""
15671
macro scalar_rule(call, maybe_setup, partials...)
72+
call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input(
73+
call, maybe_setup, partials
74+
)
75+
f = call.args[1]
76+
77+
# An expression that when evaluated will return the type of the input domain.
78+
# Multiple repetitions of this expression should optimize out. But if it does not then
79+
# may need to move its definition into the body of the `rrule`/`frule`
80+
𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))
81+
82+
frule_expr = scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
83+
rrule_expr = scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
84+
85+
86+
############################################################################
87+
# Final return: building the expression to insert in the place of this macro
88+
code = quote
89+
if !($f isa Type) && fieldcount(typeof($f)) > 0
90+
throw(ArgumentError(
91+
"@scalar_rule cannot be used on closures/functors (such as $($f))"
92+
))
93+
end
94+
95+
$(frule_expr)
96+
$(rrule_expr)
97+
end
98+
end
99+
100+
101+
"""
102+
_normalize_scalarrules_macro_input(call, maybe_setup, partials)
103+
104+
returns (in order) the correctly escaped:
105+
- `call` with out any type constraints
106+
- `setup_stmts`: the content of `@setup` or `nothing` if that is not provided,
107+
- `inputs`: with all args having the constraints removed from call, or
108+
defaulting to `Number`
109+
- `partials`: which are all `Expr{:tuple,...}`
110+
"""
111+
function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
157112
############################################################################
158113
# Setup: normalizing input form etc
159114

@@ -164,12 +119,12 @@ macro scalar_rule(call, maybe_setup, partials...)
164119
partials = (maybe_setup, partials...)
165120
end
166121
@assert Meta.isexpr(call, :call)
167-
f = esc(call.args[1])
168122

169123
# Annotate all arguments in the signature as scalars
170124
inputs = map(call.args[2:end]) do arg
171125
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
172126
end
127+
173128
# Remove annotations and escape names for the call
174129
for (i, arg) in enumerate(call.args)
175130
if Meta.isexpr(arg, :(::))
@@ -189,78 +144,153 @@ macro scalar_rule(call, maybe_setup, partials...)
189144
end
190145
end
191146

192-
############################################################################
193-
# Main body: defining the results of the frule/rrule
194-
195-
# An expression that when evaluated will return the type of the input domain.
196-
# Multiple repetitions of this expression should optimize out. But if it does not then
197-
# may need to move its definition into the body of the `rrule`/`frule`
198-
𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))
147+
return call, setup_stmts, inputs, partials
148+
end
199149

150+
function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
200151
n_outputs = length(partials)
201152
n_inputs = length(inputs)
202153

203-
pushforward = let
204-
# Δs is the input to the propagator rule
205-
# because this is push-forward there is one per input to the function
206-
Δs = [Symbol(string(, i)) for i in 1:n_inputs]
207-
pushforward_returns = map(1:n_outputs) do output_i
208-
∂s = partials[output_i].args
209-
propagation_expr(𝒟, Δs, ∂s)
210-
end
211-
if n_outputs > 1
212-
# For forward-mode we only return a tuple if output actually a tuple.
213-
pushforward_returns = Expr(:tuple, pushforward_returns...)
214-
else
215-
pushforward_returns = pushforward_returns[1]
216-
end
217-
quote
218-
# _ is the input derivative w.r.t. function internals. since we do not
219-
# allow closures/functors with @scalar_rule, it is always ignored
220-
function $(propagator_name(f, :pushforward))(_, $(Δs...))
221-
$pushforward_returns
222-
end
223-
end
154+
# Δs is the input to the propagator rule
155+
# because this is push-forward there is one per input to the function
156+
Δs = [Symbol(string(, i)) for i in 1:n_inputs]
157+
pushforward_returns = map(1:n_outputs) do output_i
158+
∂s = partials[output_i].args
159+
propagation_expr(𝒟, Δs, ∂s)
224160
end
225-
226-
pullback = let
227-
# Δs is the input to the propagator rule
228-
# because this is a pull-back there is one per output of function
229-
Δs = [Symbol(string(, i)) for i in 1:n_outputs]
230-
231-
# 1 partial derivative per input
232-
pullback_returns = map(1:n_inputs) do input_i
233-
∂s = [partial.args[input_i] for partial in partials]
234-
propagation_expr(𝒟, Δs, ∂s)
235-
end
236-
237-
quote
238-
function $(propagator_name(f, :pullback))($(Δs...))
239-
return (NO_FIELDS, $(pullback_returns...))
240-
end
241-
end
161+
if n_outputs > 1
162+
# For forward-mode we only return a tuple if output actually a tuple.
163+
pushforward_returns = Expr(:tuple, pushforward_returns...)
164+
else
165+
pushforward_returns = pushforward_returns[1]
242166
end
243167

244-
############################################################################
245-
# Final return: building the expression to insert in the place of this macro
246-
247-
code = quote
248-
if fieldcount(typeof($f)) > 0
249-
throw(ArgumentError(
250-
"@scalar_rule cannot be used on closures/functors (such as $f)"
251-
))
168+
pushforward = quote
169+
# _ is the input derivative w.r.t. function internals. since we do not
170+
# allow closures/functors with @scalar_rule, it is always ignored
171+
function $(propagator_name(f, :pushforward))(_, $(Δs...))
172+
$pushforward_returns
252173
end
174+
end
253175

176+
return quote
254177
function ChainRulesCore.frule(::typeof($f), $(inputs...))
255178
$(esc()) = $call
256179
$(setup_stmts...)
257180
return $(esc()), $pushforward
258181
end
182+
end
183+
end
184+
185+
function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
186+
n_outputs = length(partials)
187+
n_inputs = length(inputs)
188+
189+
# Δs is the input to the propagator rule
190+
# because this is a pull-back there is one per output of function
191+
Δs = [Symbol(string(, i)) for i in 1:n_outputs]
192+
193+
# 1 partial derivative per input
194+
pullback_returns = map(1:n_inputs) do input_i
195+
∂s = [partial.args[input_i] for partial in partials]
196+
propagation_expr(𝒟, Δs, ∂s)
197+
end
198+
199+
pullback = quote
200+
function $(propagator_name(f, :pullback))($(Δs...))
201+
return (NO_FIELDS, $(pullback_returns...))
202+
end
203+
end
259204

205+
return quote
260206
function ChainRulesCore.rrule(::typeof($f), $(inputs...))
261207
$(esc()) = $call
262208
$(setup_stmts...)
263209
return $(esc()), $pullback
264210
end
265211
end
266212
end
213+
214+
"""
215+
propagation_expr(𝒟, Δs, ∂s)
216+
217+
Returns the expression for the propagation of
218+
the input gradient `Δs` though the partials `∂s`.
219+
220+
𝒟 is an expression that when evaluated returns the type-of the input domain.
221+
For example if the derivative is being taken at the point `1` it returns `Int`.
222+
if it is taken at `1+1im` it returns `Complex{Int}`.
223+
At present it is ignored for non-Wirtinger derivatives.
224+
"""
225+
function propagation_expr(𝒟, Δs, ∂s)
226+
wirtinger_indices = findall(∂s) do ex
227+
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
228+
end
229+
∂s = map(esc, ∂s)
230+
if isempty(wirtinger_indices)
231+
return standard_propagation_expr(Δs, ∂s)
232+
else
233+
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
234+
end
235+
end
236+
237+
function standard_propagation_expr(Δs, ∂s)
238+
# This is basically Δs ⋅ ∂s
239+
240+
# Notice: the thunking of `∂s[i] (potentially) saves us some computation
241+
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
242+
# as the pullback is evaluated
243+
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
244+
return :(+($(∂_mul_Δs...)))
245+
end
246+
247+
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
248+
∂_mul_Δs_primal = Any[]
249+
∂_mul_Δs_conjugate = Any[]
250+
∂_wirtinger_defs = Any[]
251+
for i in 1:length(∂s)
252+
if i in wirtinger_indices
253+
Δi = Δs[i]
254+
∂i = Symbol(string(:∂, i))
255+
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
256+
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
257+
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
258+
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
259+
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
260+
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
261+
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
262+
else
263+
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
264+
push!(∂_mul_Δs_primal, ∂_mul_Δ)
265+
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
266+
end
267+
end
268+
primal_sum = :(+($(∂_mul_Δs_primal...)))
269+
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
270+
return quote # This will be a block, so will have value equal to last statement
271+
$(∂_wirtinger_defs...)
272+
w = Wirtinger($primal_sum, $conjugate_sum)
273+
refine_differential($𝒟, w)
274+
end
275+
end
276+
277+
"""
278+
propagator_name(f, propname)
279+
280+
Determines a reasonable name for the propagator function.
281+
The name doesn't really matter too much as it is a local function to be returned
282+
by `frule` or `rrule`, but a good name make debugging easier.
283+
`f` should be some form of AST representation of the actual function,
284+
`propname` should be either `:pullback` or `:pushforward`
285+
286+
This is able to deal with fairly complex expressions for `f`:
287+
288+
julia> propagator_name(:bar, :pushforward)
289+
:bar_pushforward
290+
291+
julia> propagator_name(esc(:(Base.Random.foo)), :pullback)
292+
:foo_pullback
293+
"""
294+
propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname)
295+
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
296+
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)

0 commit comments

Comments
 (0)