@@ -9,9 +9,9 @@ support, broadcast fusion, zero-elision, etc. into nicely separated parts.
9
9
10
10
All subtypes of `AbstractDifferential` implement the following operations:
11
11
12
- `add (a, b)`: linearly combine differential `a` and differential `b`
12
+ `+ (a, b)`: linearly combine differential `a` and differential `b`
13
13
14
- `mul (a, b)`: multiply the differential `a` by the differential `b`
14
+ `* (a, b)`: multiply the differential `a` by the differential `b`
15
15
16
16
`Base.conj(x)`: complex conjugate of the differential `x`
17
17
@@ -26,6 +26,8 @@ Additionally, all subtypes of `AbstractDifferential` support `Base.iterate` and
26
26
"""
27
27
abstract type AbstractDifferential end
28
28
29
+ Base.:+ (x:: AbstractDifferential ) = x
30
+
29
31
"""
30
32
extern(x)
31
33
@@ -39,40 +41,6 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
39
41
40
42
@inline Base. conj (x:: AbstractDifferential ) = x
41
43
42
- #=
43
- This `AbstractDifferential` algebra has a monad-y "fallthrough" implementation;
44
- each step handles an element of the algebra before dispatching to the next step.
45
- This way, we don't need to implement promotion/conversion rules between subtypes
46
- of `AbstractDifferential` to resolve potential ambiguities.
47
- =#
48
-
49
- const PRECEDENCE_LIST = [:wirtinger , :casted , :zero , :dne , :one , :thunk , :fallback ]
50
-
51
- global defs = Expr (:block )
52
-
53
- let previous_add_name = :add , previous_mul_name = :mul
54
- for name in PRECEDENCE_LIST
55
- next_add_name = Symbol (string (:add_ , name))
56
- next_mul_name = Symbol (string (:mul_ , name))
57
- push! (defs. args, quote
58
- @inline $ (previous_add_name)(a, b) = $ (next_add_name)(a, b)
59
- @inline $ (previous_mul_name)(a, b) = $ (next_mul_name)(a, b)
60
- end )
61
- previous_add_name = next_add_name
62
- previous_mul_name = next_mul_name
63
- end
64
- end
65
-
66
- eval (defs)
67
-
68
- @inline add_fallback (a, b) = a + b
69
-
70
- @inline mul_fallback (a, b) = a * b
71
-
72
- @inline add (x) = x
73
-
74
- @inline mul (x) = x
75
-
76
44
# ####
77
45
# #### `Wirtinger`
78
46
# ####
@@ -120,33 +88,6 @@ Base.iterate(::Wirtinger, ::Any) = nothing
120
88
121
89
Base. conj (x:: Wirtinger ) = error (" `conj(::Wirtinger)` not yet defined" )
122
90
123
- function add_wirtinger (a:: Wirtinger , b:: Wirtinger )
124
- return Wirtinger (add (a. primal, b. primal), add (a. conjugate, b. conjugate))
125
- end
126
-
127
- add_wirtinger (a:: Wirtinger , b) = add (a, Wirtinger (b, Zero ()))
128
- add_wirtinger (a, b:: Wirtinger ) = add (Wirtinger (a, Zero ()), b)
129
-
130
- function mul_wirtinger (a:: Wirtinger , b:: Wirtinger )
131
- error ("""
132
- cannot multiply two Wirtinger objects; this error likely means a
133
- `WirtingerRule` was inappropriately defined somewhere. Multiplication
134
- of two Wirtinger objects is not defined because chain rule application
135
- often expands into a non-commutative operation in the Wirtinger
136
- calculus. To put it another way: simply given two Wirtinger objects
137
- and no other information, we can't know "locally" which components to
138
- conjugate in order to implement the chain rule. We could pick a
139
- convention; for example, we could define `a::Wirtinger * b::Wirtinger`
140
- such that we assume the chain rule application is of the form `f_a ∘ f_b`
141
- instead of `f_b ∘ f_a`. However, picking such a convention is likely to
142
- lead to silently incorrect derivatives due to commutativity assumptions
143
- in downstream generic code that deals with the reals. Thus, ChainRulesCore
144
- makes this operation an error instead.
145
- """ )
146
- end
147
-
148
- mul_wirtinger (a:: Wirtinger , b) = Wirtinger (mul (a. primal, b), mul (a. conjugate, b))
149
- mul_wirtinger (a, b:: Wirtinger ) = Wirtinger (mul (a, b. primal), mul (a, b. conjugate))
150
91
151
92
# ####
152
93
# #### `Casted`
@@ -174,14 +115,6 @@ Base.iterate(x::Casted, state) = iterate(x.value, state)
174
115
175
116
Base. conj (x:: Casted ) = cast (conj, x. value)
176
117
177
- add_casted (a:: Casted , b:: Casted ) = Casted (broadcasted (add, a. value, b. value))
178
- add_casted (a:: Casted , b) = Casted (broadcasted (add, a. value, b))
179
- add_casted (a, b:: Casted ) = Casted (broadcasted (add, a, b. value))
180
-
181
- mul_casted (a:: Casted , b:: Casted ) = Casted (broadcasted (mul, a. value, b. value))
182
- mul_casted (a:: Casted , b) = Casted (broadcasted (mul, a. value, b))
183
- mul_casted (a, b:: Casted ) = Casted (broadcasted (mul, a, b. value))
184
-
185
118
# ####
186
119
# #### `Zero`
187
120
# ####
@@ -200,13 +133,6 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
200
133
Base. iterate (x:: Zero ) = (x, nothing )
201
134
Base. iterate (:: Zero , :: Any ) = nothing
202
135
203
- add_zero (:: Zero , :: Zero ) = Zero ()
204
- add_zero (:: Zero , b) = b
205
- add_zero (a, :: Zero ) = a
206
-
207
- mul_zero (:: Zero , :: Zero ) = Zero ()
208
- mul_zero (:: Zero , :: Any ) = Zero ()
209
- mul_zero (:: Any , :: Zero ) = Zero ()
210
136
211
137
# ####
212
138
# #### `DNE`
@@ -228,14 +154,6 @@ Base.Broadcast.broadcastable(::DNE) = Ref(DNE())
228
154
Base. iterate (x:: DNE ) = (x, nothing )
229
155
Base. iterate (:: DNE , :: Any ) = nothing
230
156
231
- add_dne (:: DNE , :: DNE ) = DNE ()
232
- add_dne (:: DNE , b) = b
233
- add_dne (a, :: DNE ) = a
234
-
235
- mul_dne (:: DNE , :: DNE ) = DNE ()
236
- mul_dne (:: DNE , :: Any ) = DNE ()
237
- mul_dne (:: Any , :: DNE ) = DNE ()
238
-
239
157
# ####
240
158
# #### `One`
241
159
# ####
@@ -254,13 +172,6 @@ Base.Broadcast.broadcastable(::One) = Ref(One())
254
172
Base. iterate (x:: One ) = (x, nothing )
255
173
Base. iterate (:: One , :: Any ) = nothing
256
174
257
- add_one (a:: One , b:: One ) = add (extern (a), extern (b))
258
- add_one (a:: One , b) = add (extern (a), b)
259
- add_one (a, b:: One ) = add (a, extern (b))
260
-
261
- mul_one (:: One , :: One ) = One ()
262
- mul_one (:: One , b) = b
263
- mul_one (a, :: One ) = a
264
175
265
176
# ####
266
177
# #### `Thunk`
295
206
end
296
207
297
208
Base. conj (x:: Thunk ) = @thunk (conj (extern (x)))
298
-
299
- add_thunk (a:: Thunk , b:: Thunk ) = add (extern (a), extern (b))
300
- add_thunk (a:: Thunk , b) = add (extern (a), b)
301
- add_thunk (a, b:: Thunk ) = add (a, extern (b))
302
-
303
- mul_thunk (a:: Thunk , b:: Thunk ) = mul (extern (a), extern (b))
304
- mul_thunk (a:: Thunk , b) = mul (extern (a), b)
305
- mul_thunk (a, b:: Thunk ) = mul (a, extern (b))
0 commit comments