Skip to content

Commit 0b232a4

Browse files
committed
Remove duplicate of rule_types from bad rebase
WIP only one pullback with many partials (re #38)
1 parent 7789a67 commit 0b232a4

File tree

7 files changed

+93
-269
lines changed

7 files changed

+93
-269
lines changed

src/ChainRulesCore.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ export AbstractRule, Rule, frule, rrule
55
export @scalar_rule, @thunk
66
export extern, cast, store!
77
export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
8-
export NO_FIELDS_RULE, ZERO_RULE
8+
export NO_FIELDS
99

1010
include("differentials.jl")
1111
include("differential_arithmetic.jl")
1212
include("rule_types.jl")
1313
include("rules.jl")
14-
include("rule_definition_tools.jl")
14+
#include("rule_definition_tools.jl")
1515
end # module

src/differentials.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,11 @@ end
227227
Base.conj(x::Thunk) = @thunk(conj(extern(x)))
228228

229229
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
230+
231+
"""
232+
NO_FIELDS
233+
Constant for the reverse-mode derivative with respect to a structure that has no fields.
234+
The most notable use for this is for the reverse-mode derivative with respect to the
235+
function itself, when that function is not a closure.
236+
"""
237+
const NO_FIELDS = DNE()

src/rule_definition_tools.jl

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,22 @@ methods for `frule` and `rrule`:
1414
function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
1515
Ω = f(x₁, x₂, ...)
1616
\$(statement₁, statement₂, ...)
17-
return Ω, (ZERO_RULE,
18-
Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
19-
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
20-
...)
17+
return Ω, (_, Δx₁, Δx₂, ...) -> (
18+
(∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
19+
(∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
20+
...
21+
)
2122
end
2223
2324
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
2425
Ω = f(x₁, x₂, ...)
2526
\$(statement₁, statement₂, ...)
26-
return Ω, (NO_FIELDS_RULE,
27-
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
28-
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
29-
...)
27+
return Ω, (ΔΩ₁, ΔΩ₂, ...) -> (
28+
NO_FIELDS,
29+
∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
30+
∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
31+
...
32+
)
3033
end
3134
3235
If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are
@@ -36,10 +39,10 @@ Constraints may also be explicitly be provided to override the `Number` constrai
3639
e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
3740
`Number`.
3841
39-
At present this does not support defining rules for closures/functors.
40-
This the first returned rule, representing the derivative with respect to the
41-
function itself, is always the `NO_FIELDS_RULE` (reverse-mode),
42-
or `ZERO_RULE` (forward-mode).
42+
At present this does not support defining rules for closures/functors.
43+
Thus in reverse-mode, the first returned partial,
44+
representing the derivative with respect to the function itself, is always `NO_FIELDS`.
45+
And in forward-mode, the first input to the returned propagator is always ignored.
4346
4447
The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
4548
allows the primal result to be conveniently referenced (as `Ω`) within the
@@ -86,11 +89,54 @@ macro scalar_rule(call, maybe_setup, partials...)
8689
call.args[i] = esc(arg)
8790
end
8891
end
89-
if all(Meta.isexpr(partial, :tuple) for partial in partials)
92+
93+
partials = map(partials) do partial
94+
if Meta.isexpr(partial, :tuple)
95+
partial
96+
else
97+
@assert length(inputs) == 1
98+
Expr(:tuple, partial)
99+
end
100+
end
101+
@show partials
102+
103+
############################################################
104+
# Make pullback
105+
#(TODO: move to own function)
106+
# TODO: Wirtinger
107+
108+
Δs = [Symbol(string(:Δ, i)) for i in 1:length(partials)]
109+
pullback_returns = map(eachindex(inputs)) do input_i
110+
∂s = [partials.args[input_i] for partial in partials]
111+
∂s = map(esc, ∂s)
112+
113+
# Notice: the thunking of `∂s[i]` (potentially) saves us some computation
114+
# if `Δs[i]` is an `AbstractDifferential`, otherwise it is computed as soon
115+
# as the pullback is evaluated
116+
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
117+
:(+($(∂_mul_Δs...)))
118+
else
119+
120+
pullback = quote
121+
function $(Symbol(nameof(f), :_pullback))($(Δs...))
122+
return (ChainRulesCore.NO_FIELDS, $(pullback_returns...))
123+
end
124+
end
125+
126+
########################################
127+
quote
128+
function ChainRulesCore.rrule(::typeof($f), $(inputs...))
129+
$(esc(:Ω)) = $call
130+
$(setup_stmts...)
131+
return $(esc(:Ω)), $esc(pullback)
132+
end
133+
end
134+
end
135+
#==
136+
if !all(Meta.isexpr(partial, :tuple) for partial in partials)
90137
input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input
91138
forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials]
92-
reverse_rules = Any[]
93-
for i in 1:length(inputs)
139+
reverse_rules = map(1:length(inputs) do i
94140
reverse_partials = [partial.args[i] for partial in partials]
95141
push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...))
96142
end
@@ -103,7 +149,7 @@ macro scalar_rule(call, maybe_setup, partials...)
103149
# First pseudo-partial is derivative WRT function itself. Since this macro does not
104150
# support closures, it is just the empty NamedTuple
105151
forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...)
106-
reverse_rules = Expr(:tuple, NO_FIELDS_RULE, reverse_rules...)
152+
reverse_rules = Expr(:tuple, NO_FIELDS, reverse_rules...)
107153
return quote
108154
if fieldcount(typeof($f)) > 0
109155
throw(ArgumentError(
@@ -123,7 +169,13 @@ macro scalar_rule(call, maybe_setup, partials...)
123169
end
124170
end
125171
end
172+
==#
173+
174+
@macroexpand(@scalar_rule(one(x), Zero()))
175+
176+
126177

178+
#==
127179
function rule_from_partials(input_arg, ∂s...)
128180
wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
129181
∂s = map(esc, ∂s)
@@ -161,3 +213,4 @@ function rule_from_partials(input_arg, ∂s...)
161213
end
162214
end
163215
end
216+
==#

src/rule_types.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,15 @@ See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [
5151
abstract type AbstractRule end
5252

5353
# this ensures that consumers don't have to special-case rule destructuring
54-
Base.iterate(rule::AbstractRule) = (rule, nothing)
54+
Base.iterate(rule::AbstractRule) = (@warn "iterating rules is going away"; (rule, nothing))
5555
Base.iterate(::AbstractRule, ::Any) = nothing
5656

5757
# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple
5858
# in order to get the `i`th rule (assuming it's 1)
59-
Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError())
59+
function Base.getindex(rule::AbstractRule, i::Integer)
60+
@warn "iterating rules is going away"
61+
return i == 1 ? rule : throw(BoundsError())
62+
end
6063

6164
"""
6265
accumulate(Δ, rule::AbstractRule, args...)
@@ -78,7 +81,7 @@ accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementa
7881
7982
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
8083
"""
81-
accumulate(Δ, rule::AbstractRule, args...) = Δ + rule(args...)
84+
accumulate(Δ, rule, args...) = Δ + rule(args...)
8285

8386
"""
8487
accumulate!(Δ, rule::AbstractRule, args...)
@@ -90,11 +93,11 @@ Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
9093
9194
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
9295
"""
93-
function accumulate!(Δ, rule::AbstractRule, args...)
96+
function accumulate!(Δ, rule, args...)
9497
return materialize!(Δ, broadcastable(cast(Δ) + rule(args...)))
9598
end
9699

97-
accumulate!(Δ::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...)
100+
accumulate!(Δ::Number, rule, args...) = accumulate(Δ, rule, args...)
98101

99102
"""
100103
store!(Δ, rule::AbstractRule, args...)
@@ -110,7 +113,7 @@ to be customizable for specific rules/input types.
110113
111114
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
112115
"""
113-
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))
116+
store!(Δ, rule, args...) = materialize!(Δ, broadcastable(rule(args...)))
114117

115118
#####
116119
##### `Rule`

0 commit comments

Comments
 (0)