Skip to content

Commit 049695a

Browse files
authored
Remove all uses of Cassette and define the arithmetic directly (#35)
* =Remove all uses of Cassette and difine the arithmetic directly * =Remove use of add from accumulate and define single-arg + * Update test/rule_types.jl * Update src/differential_arithmetic.jl * Update test/rules.jl * clean up setup code * use infix * make ambiguity test not count as 20,000 tests * remove whitespace * fix and confirm fixed typos in Wirtinger arthmetic
1 parent b004320 commit 049695a

10 files changed

+167
-153
lines changed

Project.toml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.1.1"
4-
5-
[deps]
6-
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
3+
version = "0.2.0"
74

85
[compat]
9-
Cassette = "^0.2"
106
julia = "^1.0"
117

128
[extras]

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module ChainRulesCore
2-
using Cassette
32
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
43

54
export AbstractRule, Rule, frule, rrule
65
export @scalar_rule, @thunk
76
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
87

98
include("differentials.jl")
9+
include("differential_arithmetic.jl")
1010
include("rule_types.jl")
1111
include("rules.jl")
1212
include("rule_definition_tools.jl")

src/differential_arithmetic.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#==
2+
All differentials need to define + and *.
3+
That happens here.
4+
5+
We just use @eval to define all the combinations for AbstractDifferential
6+
subtypes, as we know the full set that might be encountered.
7+
Thus we can avoid any ambiguities.
8+
9+
Notice:
10+
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any)
11+
Thus each of the @eval loops creating definitions of + and *
12+
defines the combination this type with all types of lower precidence.
13+
This means each eval loops is 1 item smaller than the previous.
14+
==#
15+
16+
17+
function Base.:*(a::Wirtinger, b::Wirtinger)
18+
error("""
19+
Cannot multiply two Wirtinger objects; this error likely means a
20+
`WirtingerRule` was inappropriately defined somewhere. Multiplication
21+
of two Wirtinger objects is not defined because chain rule application
22+
often expands into a non-commutative operation in the Wirtinger
23+
calculus. To put it another way: simply given two Wirtinger objects
24+
and no other information, we can't know "locally" which components to
25+
conjugate in order to implement the chain rule. We could pick a
26+
convention; for example, we could define `a::Wirtinger * b::Wirtinger`
27+
such that we assume the chain rule application is of the form `f_a ∘ f_b`
28+
instead of `f_b ∘ f_a`. However, picking such a convention is likely to
29+
lead to silently incorrect derivatives due to commutativity assumptions
30+
in downstream generic code that deals with the reals. Thus, ChainRulesCore
31+
makes this operation an error instead.
32+
""")
33+
end
34+
35+
function Base.:+(a::Wirtinger, b::Wirtinger)
36+
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
37+
end
38+
39+
for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any)
40+
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
41+
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
42+
43+
@eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b)
44+
@eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)
45+
end
46+
47+
48+
Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
49+
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
50+
for T in (:Zero, :DNE, :One, :Thunk, :Any)
51+
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
52+
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))
53+
54+
@eval Base.:*(a::Casted, b::$T) = Casted(broadcasted(*, a.value, b))
55+
@eval Base.:*(a::$T, b::Casted) = Casted(broadcasted(*, a, b.value))
56+
end
57+
58+
59+
Base.:+(::Zero, b::Zero) = Zero()
60+
Base.:*(::Zero, ::Zero) = Zero()
61+
for T in (:DNE, :One, :Thunk, :Any)
62+
@eval Base.:+(::Zero, b::$T) = b
63+
@eval Base.:+(a::$T, ::Zero) = a
64+
65+
@eval Base.:*(::Zero, ::$T) = Zero()
66+
@eval Base.:*(::$T, ::Zero) = Zero()
67+
end
68+
69+
70+
Base.:+(::DNE, ::DNE) = DNE()
71+
Base.:*(::DNE, ::DNE) = DNE()
72+
for T in (:One, :Thunk, :Any)
73+
@eval Base.:+(::DNE, b::$T) = b
74+
@eval Base.:+(a::$T, ::DNE) = a
75+
76+
@eval Base.:*(::DNE, ::$T) = DNE()
77+
@eval Base.:*(::$T, ::DNE) = DNE()
78+
end
79+
80+
81+
Base.:+(a::One, b::One) = extern(a) + extern(b)
82+
Base.:*(::One, ::One) = One()
83+
for T in (:Thunk, :Any)
84+
@eval Base.:+(a::One, b::$T) = extern(a) + b
85+
@eval Base.:+(a::$T, b::One) = a + extern(b)
86+
87+
@eval Base.:*(::One, b::$T) = b
88+
@eval Base.:*(a::$T, ::One) = a
89+
end
90+
91+
92+
Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b)
93+
Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b)
94+
for T in (:Any,) #This loop is redundant but for consistency...
95+
@eval Base.:+(a::Thunk, b::$T) = extern(a) + b
96+
@eval Base.:+(a::$T, b::Thunk) = a + extern(b)
97+
98+
@eval Base.:*(a::Thunk, b::$T) = extern(a) * b
99+
@eval Base.:*(a::$T, b::Thunk) = a * extern(b)
100+
end

src/differentials.jl

Lines changed: 4 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ support, broadcast fusion, zero-elision, etc. into nicely separated parts.
99
1010
All subtypes of `AbstractDifferential` implement the following operations:
1111
12-
`add(a, b)`: linearly combine differential `a` and differential `b`
12+
`+(a, b)`: linearly combine differential `a` and differential `b`
1313
14-
`mul(a, b)`: multiply the differential `a` by the differential `b`
14+
`*(a, b)`: multiply the differential `a` by the differential `b`
1515
1616
`Base.conj(x)`: complex conjugate of the differential `x`
1717
@@ -26,6 +26,8 @@ Additionally, all subtypes of `AbstractDifferential` support `Base.iterate` and
2626
"""
2727
abstract type AbstractDifferential end
2828

29+
Base.:+(x::AbstractDifferential) = x
30+
2931
"""
3032
extern(x)
3133
@@ -39,40 +41,6 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
3941

4042
@inline Base.conj(x::AbstractDifferential) = x
4143

42-
#=
43-
This `AbstractDifferential` algebra has a monad-y "fallthrough" implementation;
44-
each step handles an element of the algebra before dispatching to the next step.
45-
This way, we don't need to implement promotion/conversion rules between subtypes
46-
of `AbstractDifferential` to resolve potential ambiguities.
47-
=#
48-
49-
const PRECEDENCE_LIST = [:wirtinger, :casted, :zero, :dne, :one, :thunk, :fallback]
50-
51-
global defs = Expr(:block)
52-
53-
let previous_add_name = :add, previous_mul_name = :mul
54-
for name in PRECEDENCE_LIST
55-
next_add_name = Symbol(string(:add_, name))
56-
next_mul_name = Symbol(string(:mul_, name))
57-
push!(defs.args, quote
58-
@inline $(previous_add_name)(a, b) = $(next_add_name)(a, b)
59-
@inline $(previous_mul_name)(a, b) = $(next_mul_name)(a, b)
60-
end)
61-
previous_add_name = next_add_name
62-
previous_mul_name = next_mul_name
63-
end
64-
end
65-
66-
eval(defs)
67-
68-
@inline add_fallback(a, b) = a + b
69-
70-
@inline mul_fallback(a, b) = a * b
71-
72-
@inline add(x) = x
73-
74-
@inline mul(x) = x
75-
7644
#####
7745
##### `Wirtinger`
7846
#####
@@ -120,33 +88,6 @@ Base.iterate(::Wirtinger, ::Any) = nothing
12088

12189
Base.conj(x::Wirtinger) = error("`conj(::Wirtinger)` not yet defined")
12290

123-
function add_wirtinger(a::Wirtinger, b::Wirtinger)
124-
return Wirtinger(add(a.primal, b.primal), add(a.conjugate, b.conjugate))
125-
end
126-
127-
add_wirtinger(a::Wirtinger, b) = add(a, Wirtinger(b, Zero()))
128-
add_wirtinger(a, b::Wirtinger) = add(Wirtinger(a, Zero()), b)
129-
130-
function mul_wirtinger(a::Wirtinger, b::Wirtinger)
131-
error("""
132-
cannot multiply two Wirtinger objects; this error likely means a
133-
`WirtingerRule` was inappropriately defined somewhere. Multiplication
134-
of two Wirtinger objects is not defined because chain rule application
135-
often expands into a non-commutative operation in the Wirtinger
136-
calculus. To put it another way: simply given two Wirtinger objects
137-
and no other information, we can't know "locally" which components to
138-
conjugate in order to implement the chain rule. We could pick a
139-
convention; for example, we could define `a::Wirtinger * b::Wirtinger`
140-
such that we assume the chain rule application is of the form `f_a ∘ f_b`
141-
instead of `f_b ∘ f_a`. However, picking such a convention is likely to
142-
lead to silently incorrect derivatives due to commutativity assumptions
143-
in downstream generic code that deals with the reals. Thus, ChainRulesCore
144-
makes this operation an error instead.
145-
""")
146-
end
147-
148-
mul_wirtinger(a::Wirtinger, b) = Wirtinger(mul(a.primal, b), mul(a.conjugate, b))
149-
mul_wirtinger(a, b::Wirtinger) = Wirtinger(mul(a, b.primal), mul(a, b.conjugate))
15091

15192
#####
15293
##### `Casted`
@@ -174,14 +115,6 @@ Base.iterate(x::Casted, state) = iterate(x.value, state)
174115

175116
Base.conj(x::Casted) = cast(conj, x.value)
176117

177-
add_casted(a::Casted, b::Casted) = Casted(broadcasted(add, a.value, b.value))
178-
add_casted(a::Casted, b) = Casted(broadcasted(add, a.value, b))
179-
add_casted(a, b::Casted) = Casted(broadcasted(add, a, b.value))
180-
181-
mul_casted(a::Casted, b::Casted) = Casted(broadcasted(mul, a.value, b.value))
182-
mul_casted(a::Casted, b) = Casted(broadcasted(mul, a.value, b))
183-
mul_casted(a, b::Casted) = Casted(broadcasted(mul, a, b.value))
184-
185118
#####
186119
##### `Zero`
187120
#####
@@ -200,13 +133,6 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
200133
Base.iterate(x::Zero) = (x, nothing)
201134
Base.iterate(::Zero, ::Any) = nothing
202135

203-
add_zero(::Zero, ::Zero) = Zero()
204-
add_zero(::Zero, b) = b
205-
add_zero(a, ::Zero) = a
206-
207-
mul_zero(::Zero, ::Zero) = Zero()
208-
mul_zero(::Zero, ::Any) = Zero()
209-
mul_zero(::Any, ::Zero) = Zero()
210136

211137
#####
212138
##### `DNE`
@@ -228,14 +154,6 @@ Base.Broadcast.broadcastable(::DNE) = Ref(DNE())
228154
Base.iterate(x::DNE) = (x, nothing)
229155
Base.iterate(::DNE, ::Any) = nothing
230156

231-
add_dne(::DNE, ::DNE) = DNE()
232-
add_dne(::DNE, b) = b
233-
add_dne(a, ::DNE) = a
234-
235-
mul_dne(::DNE, ::DNE) = DNE()
236-
mul_dne(::DNE, ::Any) = DNE()
237-
mul_dne(::Any, ::DNE) = DNE()
238-
239157
#####
240158
##### `One`
241159
#####
@@ -254,13 +172,6 @@ Base.Broadcast.broadcastable(::One) = Ref(One())
254172
Base.iterate(x::One) = (x, nothing)
255173
Base.iterate(::One, ::Any) = nothing
256174

257-
add_one(a::One, b::One) = add(extern(a), extern(b))
258-
add_one(a::One, b) = add(extern(a), b)
259-
add_one(a, b::One) = add(a, extern(b))
260-
261-
mul_one(::One, ::One) = One()
262-
mul_one(::One, b) = b
263-
mul_one(a, ::One) = a
264175

265176
#####
266177
##### `Thunk`
@@ -295,11 +206,3 @@ end
295206
end
296207

297208
Base.conj(x::Thunk) = @thunk(conj(extern(x)))
298-
299-
add_thunk(a::Thunk, b::Thunk) = add(extern(a), extern(b))
300-
add_thunk(a::Thunk, b) = add(extern(a), b)
301-
add_thunk(a, b::Thunk) = add(a, extern(b))
302-
303-
mul_thunk(a::Thunk, b::Thunk) = mul(extern(a), extern(b))
304-
mul_thunk(a::Thunk, b) = mul(extern(a), b)
305-
mul_thunk(a, b::Thunk) = mul(a, extern(b))

src/rule_definition_tools.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ function rule_from_partials(input_arg, ∂s...)
114114
Δs = [Symbol(string(, i)) for i in 1:length(∂s)]
115115
Δs_tuple = Expr(:tuple, Δs...)
116116
if isempty(wirtinger_indices)
117-
∂_mul_Δs = [:(mul(@thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
118-
return :(Rule($Δs_tuple -> add($(∂_mul_Δs...))))
117+
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
118+
return :(Rule($Δs_tuple -> +($(∂_mul_Δs...))))
119119
else
120120
∂_mul_Δs_primal = Any[]
121121
∂_mul_Δs_conjugate = Any[]
@@ -125,20 +125,20 @@ function rule_from_partials(input_arg, ∂s...)
125125
Δi = Δs[i]
126126
∂i = Symbol(string(:∂, i))
127127
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
128-
∂f∂i_mul_Δ = :(mul(wirtinger_primal($∂i), wirtinger_primal($Δi)))
129-
∂f∂ī_mul_Δ̄ = :(mul(conj(wirtinger_conjugate($∂i)), wirtinger_conjugate($Δi)))
130-
∂f̄∂i_mul_Δ = :(mul(wirtinger_conjugate($∂i), wirtinger_primal($Δi)))
131-
∂f̄∂ī_mul_Δ̄ = :(mul(conj(wirtinger_primal($∂i)), wirtinger_conjugate($Δi)))
132-
push!(∂_mul_Δs_primal, :(add($∂f∂i_mul_Δ, $∂f∂ī_mul_Δ̄)))
133-
push!(∂_mul_Δs_conjugate, :(add($∂f̄∂i_mul_Δ, $∂f̄∂ī_mul_Δ̄)))
128+
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
129+
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
130+
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
131+
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
132+
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
133+
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
134134
else
135-
∂_mul_Δ = :(mul(@thunk($(∂s[i])), $(Δs[i])))
135+
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
136136
push!(∂_mul_Δs_primal, ∂_mul_Δ)
137137
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
138138
end
139139
end
140-
primal_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_primal...))))
141-
conjugate_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_conjugate...))))
140+
primal_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_primal...))))
141+
conjugate_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_conjugate...))))
142142
return quote
143143
$(∂_wirtinger_defs...)
144144
AbstractRule(typeof($input_arg), $primal_rule, $conjugate_rule)

src/rule_types.jl

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementa
7878
7979
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
8080
"""
81-
accumulate(Δ, rule::AbstractRule, args...) = add(Δ, rule(args...))
81+
accumulate(Δ, rule::AbstractRule, args...) = Δ + rule(args...)
8282

8383
"""
8484
accumulate!(Δ, rule::AbstractRule, args...)
@@ -91,7 +91,7 @@ Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
9191
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
9292
"""
9393
function accumulate!(Δ, rule::AbstractRule, args...)
94-
return materialize!(Δ, broadcastable(add(cast(Δ), rule(args...))))
94+
return materialize!(Δ, broadcastable(cast(Δ) + rule(args...)))
9595
end
9696

9797
accumulate!::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...)
@@ -116,15 +116,6 @@ store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(ar
116116
##### `Rule`
117117
#####
118118

119-
Cassette.@context RuleContext
120-
121-
const RULE_CONTEXT = Cassette.disablehooks(RuleContext())
122-
123-
Cassette.overdub(::RuleContext, ::typeof(+), a, b) = add(a, b)
124-
Cassette.overdub(::RuleContext, ::typeof(*), a, b) = mul(a, b)
125-
126-
Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b)
127-
Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b)
128119

129120
"""
130121
Rule(propation_function[, updating_function])
@@ -158,7 +149,7 @@ end
158149
# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}`
159150
Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
160151

161-
(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)
152+
(rule::Rule)(args...) = rule.f(args...)
162153

163154
Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))")
164155
Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))")
@@ -212,7 +203,7 @@ otherwise return `WirtingerRule(P, C)`.
212203
"""
213204
function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
214205
if 𝒟 <: Real || eltype(𝒟) <: Real
215-
return Rule((args...) -> add(primal(args...), conjugate(args...)))
206+
return Rule((args...) -> (primal(args...) + conjugate(args...)))
216207
else
217208
return WirtingerRule(primal, conjugate)
218209
end

0 commit comments

Comments
 (0)