Skip to content

Commit de2bb62

Browse files
WIP work on new @scalar_rule
make real scalar rules work. correct @scalarrule forward rule return Wirtinger scalar working work WirtingerRule test as a test of @scalar_rule Fix spelling Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com> Oxford Comma Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com> spelling Co-Authored-By: Nick Robinson <npr251@gmail.com> docstring for propagator_name spelling Co-Authored-By: Nick Robinson <npr251@gmail.com>
1 parent 0b232a4 commit de2bb62

File tree

7 files changed

+213
-104
lines changed

7 files changed

+213
-104
lines changed

src/ChainRulesCore.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ChainRulesCore
22
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
33

44
export AbstractRule, Rule, frule, rrule
5+
export wirtinger_conjugate, wirtinger_primal, differential
56
export @scalar_rule, @thunk
67
export extern, cast, store!
78
export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
@@ -11,5 +12,5 @@ include("differentials.jl")
1112
include("differential_arithmetic.jl")
1213
include("rule_types.jl")
1314
include("rules.jl")
14-
#include("rule_definition_tools.jl")
15+
include("rule_definition_tools.jl")
1516
end # module

src/differentials.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,22 @@ Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
230230

231231
"""
232232
NO_FIELDS
233+
233234
Constant for the reverse-mode derivative with respect to a structure that has no fields.
234235
The most notable use for this is for the reverse-mode derivative with respect to the
235236
function itself, when that function is not a closure.
236237
"""
237238
const NO_FIELDS = DNE()
239+
240+
####
241+
"""
242+
differential(𝒟::Type, der)
243+
244+
For some differential (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
245+
convert it to another differential that is more suited for the domain given by
246+
the type 𝒟.
247+
"""
248+
function differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
249+
return wirtinger_primal(w) + wirtinger_conjugate(w)
250+
end
251+
differential(::Any, der) = der # most of the time leave it alone.

src/rule_definition_tools.jl

Lines changed: 137 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,89 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
22

3+
"""
4+
propagator_name(f, propname)
5+
Determines a reasonable name for the propagator function.
6+
The name doesn't really matter too much as it is a local function to be returned
7+
by `frule` or `rrule`, but a good name make debugging easier.
8+
`f` should be some form of AST representation of the actual function,
9+
`propname` should be either `:pullback` or `:pushforward`
10+
11+
This is able to deal with fairly complex expressions for `f`:
12+
13+
julia> propagator_name(:bar, :pushforward)
14+
:bar_pushforward
15+
16+
julia> propagator_name(esc(:(Base.Random.foo)), :pullback)
17+
:foo_pullback
18+
"""
19+
propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname)
20+
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
21+
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
22+
23+
24+
"""
25+
propagation_expr(𝒟, Δs, ∂s)
26+
27+
Returns the expression for the propagation of
28+
the input gradient `Δs` though the partials `∂s`.
29+
30+
𝒟 is an expression that when evaluated returns the type-of the input domain.
31+
For example if the derivative is being taken at the point `1` it returns `Int`.
32+
if it is taken at `1+1im` it returns `Complex{Int}`.
33+
At present it is ignored for non-Wirtinger derivatives.
34+
"""
35+
function propagation_expr(𝒟, Δs, ∂s)
36+
wirtinger_indices = findall(∂s) do ex
37+
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
38+
end
39+
∂s = map(esc, ∂s)
40+
if isempty(wirtinger_indices)
41+
return standard_propagation_expr(Δs, ∂s)
42+
else
43+
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
44+
end
45+
end
46+
47+
function standard_propagation_expr(Δs, ∂s)
48+
# This is basically Δs ⋅ ∂s
49+
50+
# Notice: the thunking of `∂s[i] (potentially) saves us some computation
51+
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
52+
# as the pullback is evaluated
53+
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
54+
return :(+($(∂_mul_Δs...)))
55+
end
56+
57+
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
58+
∂_mul_Δs_primal = Any[]
59+
∂_mul_Δs_conjugate = Any[]
60+
∂_wirtinger_defs = Any[]
61+
for i in 1:length(∂s)
62+
if i in wirtinger_indices
63+
Δi = Δs[i]
64+
∂i = Symbol(string(:∂, i))
65+
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
66+
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
67+
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
68+
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
69+
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
70+
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
71+
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
72+
else
73+
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
74+
push!(∂_mul_Δs_primal, ∂_mul_Δ)
75+
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
76+
end
77+
end
78+
primal_sum = :(+($(∂_mul_Δs_primal...)))
79+
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
80+
return quote # This will be a block, so will have value equal to last statement
81+
$(∂_wirtinger_defs...)
82+
w = Wirtinger($primal_sum, $conjugate_sum)
83+
differential($𝒟, w)
84+
end
85+
end
86+
387
"""
488
@scalar_rule(f(x₁, x₂, ...),
589
@setup(statement₁, statement₂, ...),
@@ -42,7 +126,7 @@ e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x
42126
At present this does not support defining for closures/functors.
43127
Thus in reverse-mode, the first returned partial,
44128
representing the derivative with respect to the function itself, is always `NO_FIELDS`.
45-
And in forwards-mode, the first input to the returned propergator is always ignored.
129+
And in forward-mode, the first input to the returned propagator is always ignored.
46130
47131
The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
48132
allows the primal result to be conveniently referenced (as `Ω`) within the
@@ -69,6 +153,9 @@ For examples, see ChainRulesCore' `rules` directory.
69153
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
70154
"""
71155
macro scalar_rule(call, maybe_setup, partials...)
156+
############################################################################
157+
# Setup: normalizing input form etc
158+
72159
if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
73160
setup_stmts = map(esc, maybe_setup.args[3:end])
74161
else
@@ -77,6 +164,7 @@ macro scalar_rule(call, maybe_setup, partials...)
77164
end
78165
@assert Meta.isexpr(call, :call)
79166
f = esc(call.args[1])
167+
80168
# Annotate all arguments in the signature as scalars
81169
inputs = map(call.args[2:end]) do arg
82170
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
@@ -90,6 +178,7 @@ macro scalar_rule(call, maybe_setup, partials...)
90178
end
91179
end
92180

181+
# For consistency in code that follows we make all partials tuple expressions
93182
partials = map(partials) do partial
94183
if Meta.isexpr(partial, :tuple)
95184
partial
@@ -98,59 +187,58 @@ macro scalar_rule(call, maybe_setup, partials...)
98187
Expr(:tuple, partial)
99188
end
100189
end
101-
@show partials
102-
103-
############################################################
104-
# Make pullback
105-
#(TODO: move to own function)
106-
# TODO: Wirtinger
107-
108-
Δs = [Symbol(string(, i)) for i in 1:length(partials)]
109-
pullback_returns = map(eachindex(inputs)) do input_i
110-
∂s = [partials.args[input_i] for partial in partials]
111-
∂s = map(esc, ∂s)
112-
113-
# Notice: the thunking of `∂s[i] (potentially) saves us some computation
114-
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
115-
# as the pullback is evaluated
116-
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
117-
:(+($(∂_mul_Δs...)))
118-
else
119190

120-
pullback = quote
121-
function $(Symbol(nameof(f), :_pullback))($(Δs...))
122-
return (ChainRulesCore.NO_FIELDS, $(pullback_returns...))
191+
############################################################################
192+
# Main body: defining the results of the frule/rrule
193+
194+
# An expression that when evaluated will return the type of the input domain.
195+
# Multiple repetitions of this expression should optimize ot. But if it does not then
196+
# may need to move its definition into the body of the `rrule`/`frule`
197+
𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))
198+
199+
n_outputs = length(partials)
200+
n_inputs = length(inputs)
201+
202+
pushforward = let
203+
# Δs is the input to the propagator rule
204+
# because this is push-forward there is one per input to the function
205+
Δs = [Symbol(string(, i)) for i in 1:n_inputs]
206+
pushforward_returns = map(1:n_outputs) do output_i
207+
∂s = partials[output_i].args
208+
propagation_expr(𝒟, Δs, ∂s)
123209
end
124-
end
125210

126-
########################################
127-
quote
128-
function ChainRulesCore.rrule(::typeof($f), $(inputs...))
129-
$(esc()) = $call
130-
$(setup_stmts...)
131-
return $(esc()), $esc(pullback)
211+
quote
212+
# _ is the input derivative w.r.t. function internals. since we do not
213+
# allow closures/functors with @scalar_rule, it is always ignored
214+
function $(propagator_name(f, :pushforward))(_, $(Δs...))
215+
return $(Expr(:tuple, pushforward_returns...))
216+
end
132217
end
133218
end
134-
end
135-
#==
136-
if !all(Meta.isexpr(partial, :tuple) for partial in partials)
137-
input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input
138-
forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials]
139-
reverse_rules = map(1:length(inputs) do i
140-
reverse_partials = [partial.args[i] for partial in partials]
141-
push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...))
219+
220+
pullback = let
221+
# Δs is the input to the propagator rule
222+
# because this is a pull-back there is one per output of function
223+
Δs = [Symbol(string(, i)) for i in 1:n_outputs]
224+
225+
# 1 partial derivative per input
226+
pullback_returns = map(1:n_inputs) do input_i
227+
∂s = [partial.args[input_i] for partial in partials]
228+
propagation_expr(𝒟, Δs, ∂s)
229+
end
230+
231+
quote
232+
function $(propagator_name(f, :pullback))($(Δs...))
233+
return (NO_FIELDS, $(pullback_returns...))
234+
end
142235
end
143-
else
144-
@assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials)
145-
forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials]
146-
reverse_rules = Any[rule_from_partials(inputs[1], partials...)]
147236
end
148237

149-
# First pseudo-partial is derivative WRT function itself. Since this macro does not
150-
# support closures, it is just the empty NamedTuple
151-
forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...)
152-
reverse_rules = Expr(:tuple, NO_FIELDS, reverse_rules...)
153-
return quote
238+
############################################################################
239+
# Final return: building the expression to insert in the place of this macro
240+
241+
code = quote
154242
if fieldcount(typeof($f)) > 0
155243
throw(ArgumentError(
156244
"@scalar_rule cannot be used on closures/functors (such as $f)"
@@ -160,57 +248,13 @@ end
160248
function ChainRulesCore.frule(::typeof($f), $(inputs...))
161249
$(esc()) = $call
162250
$(setup_stmts...)
163-
return $(esc(:Ω)), $forward_rules
251+
return $(esc()), $pushforward
164252
end
253+
165254
function ChainRulesCore.rrule(::typeof($f), $(inputs...))
166255
$(esc()) = $call
167256
$(setup_stmts...)
168-
return $(esc(:Ω)), $reverse_rules
169-
end
170-
end
171-
end
172-
==#
173-
174-
@macroexpand(@scalar_rule(one(x), Zero()))
175-
176-
177-
178-
#==
179-
function rule_from_partials(input_arg, ∂s...)
180-
wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
181-
∂s = map(esc, ∂s)
182-
Δs = [Symbol(string(:Δ, i)) for i in 1:length(∂s)]
183-
Δs_tuple = Expr(:tuple, Δs...)
184-
if isempty(wirtinger_indices)
185-
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
186-
return :(Rule($Δs_tuple -> +($(∂_mul_Δs...))))
187-
else
188-
∂_mul_Δs_primal = Any[]
189-
∂_mul_Δs_conjugate = Any[]
190-
∂_wirtinger_defs = Any[]
191-
for i in 1:length(∂s)
192-
if i in wirtinger_indices
193-
Δi = Δs[i]
194-
∂i = Symbol(string(:∂, i))
195-
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
196-
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
197-
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
198-
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
199-
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
200-
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
201-
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
202-
else
203-
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
204-
push!(∂_mul_Δs_primal, ∂_mul_Δ)
205-
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
206-
end
207-
end
208-
primal_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_primal...))))
209-
conjugate_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_conjugate...))))
210-
return quote
211-
$(∂_wirtinger_defs...)
212-
AbstractRule(typeof($input_arg), $primal_rule, $conjugate_rule)
257+
return $(esc()), $pullback
213258
end
214259
end
215260
end
216-
==#

test/differentials.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,20 @@
7979
]
8080
@test isempty(ambig_methods)
8181
end
82+
83+
84+
@testset "Differential" begin
85+
@test differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2)
86+
@test differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2)
87+
88+
@test differential(typeof(1.2), Wirtinger(2,2)) == 4
89+
@test differential(typeof([1.2]), Wirtinger(2,2)) == 4
90+
91+
# For most differentials, in most domains, this does nothing
92+
for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
93+
for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2]))
94+
@test differential(𝒟, der) === der
95+
end
96+
end
97+
end
8298
end

test/rule_types.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11

22
@testset "rule types" begin
3-
#==
3+
# The following is deprecated and should be remove next release
44
@testset "iterating and indexing rules" begin
5-
_, rule = frule(dummy_identity, 1)
5+
rule = Rule(identity)
66
i = 0
77
for r in rule
88
@test r === rule
@@ -12,8 +12,8 @@
1212
@test rule[1] == rule
1313
@test_throws BoundsError rule[2]
1414
end
15-
==#
16-
15+
16+
1717
@testset "Rule" begin
1818
@testset "show" begin
1919
@test occursin(r"^Rule\(.*foo.*\)$", repr(Rule(function foo() 1 end)))

0 commit comments

Comments
 (0)