@@ -61,59 +61,6 @@ function Base.getindex(rule::AbstractRule, i::Integer)
61
61
return i == 1 ? rule : throw (BoundsError ())
62
62
end
63
63
64
- """
65
- accumulate(Δ, rule::AbstractRule, args...)
66
-
67
- Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore'
68
- various `AbstractDifferential` types.
69
-
70
- This method intended to be customizable for specific rules/input types. For
71
- example, here is pseudocode to overload `accumulate` w.r.t. a specific forward
72
- differentiation rule for a given function `f`:
73
-
74
- ```
75
- df(x) = # forward differentiation primitive implementation
76
-
77
- frule(::typeof(f), x) = (f(x), Rule(df))
78
-
79
- accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation
80
- ```
81
-
82
- See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
83
- """
84
- accumulate (Δ, rule, args... ) = Δ + rule (args... )
85
-
86
- """
87
- accumulate!(Δ, rule::AbstractRule, args...)
88
-
89
- Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place,
90
- storing the result in `Δ`.
91
-
92
- Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
93
-
94
- See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
95
- """
96
- function accumulate! (Δ, rule, args... )
97
- return materialize! (Δ, broadcastable (cast (Δ) + rule (args... )))
98
- end
99
-
100
- accumulate! (Δ:: Number , rule, args... ) = accumulate (Δ, rule, args... )
101
-
102
- """
103
- store!(Δ, rule::AbstractRule, args...)
104
-
105
- Compute `rule(args...)` and store the result in `Δ`, potentially avoiding
106
- intermediate temporary allocations that might be necessary for alternative
107
- approaches (e.g. `copyto!(Δ, extern(rule(args...)))`)
108
-
109
- Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
110
-
111
- Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
112
- to be customizable for specific rules/input types.
113
-
114
- See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
115
- """
116
- store! (Δ, rule, args... ) = materialize! (Δ, broadcastable (rule (args... )))
117
64
118
65
# ####
119
66
# #### `Rule`
@@ -157,9 +104,6 @@ Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
157
104
Base. show (io:: IO , rule:: Rule{<:Any, Nothing} ) = print (io, " Rule($(rule. f) )" )
158
105
Base. show (io:: IO , rule:: Rule ) = print (io, " Rule($(rule. f) , $(rule. u) )" )
159
106
160
- # Specialized accumulation
161
- # TODO : Does this need to be overdubbed in the rule context?
162
- accumulate! (Δ, rule:: Rule{F,U} , args... ) where {F,U<: Function } = rule. u (Δ, args... )
163
107
164
108
# ####
165
109
# #### `DNERule`
@@ -211,3 +155,68 @@ function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
211
155
return WirtingerRule (primal, conjugate)
212
156
end
213
157
end
158
+
159
+
160
+ """
161
+ accumulate(Δ, ∂)
162
+
163
+ Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore'
164
+ various `AbstractDifferential` types.
165
+
166
+ #TODO: update these docs
167
+
168
+ This method intended to be customizable for specific rules/input types. For
169
+ example, here is pseudocode to overload `accumulate` w.r.t. a specific forward
170
+ differentiation rule for a given function `f`:
171
+
172
+ ```
173
+ df(x) = # forward differentiation primitive implementation
174
+
175
+ frule(::typeof(f), x) = (f(x), Rule(df))
176
+
177
+ accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation
178
+ ```
179
+
180
+ See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
181
+ """
182
+ accumulate (Δ, ∂) = Δ + ∂
183
+
184
+ """
185
+ accumulate!(Δ, rule::AbstractRule, args...)
186
+
187
+ # TODO: Update these docs
188
+
189
+ Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place,
190
+ storing the result in `Δ`.
191
+
192
+ Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
193
+
194
+ See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
195
+ """
196
+ function accumulate! (Δ, ∂)
197
+ return materialize! (Δ, broadcastable (cast (Δ) + ∂))
198
+ end
199
+
200
+ accumulate! (Δ:: Number , ∂) = accumulate (Δ, ∂)
201
+
202
+ # TODO : replace this:
203
+ # accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
204
+
205
+
206
+ """
207
+ store!(Δ, ∂)
208
+
209
+ TODO: Rewrite these docs
210
+
211
+ Compute `rule(args...)` and store the result in `Δ`, potentially avoiding
212
+ intermediate temporary allocations that might be necessary for alternative
213
+ approaches (e.g. `copyto!(Δ, extern(rule(args...)))`)
214
+
215
+ Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
216
+
217
+ Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
218
+ to be customizable for specific rules/input types.
219
+
220
+ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
221
+ """
222
+ store! (Δ, ∂) = materialize! (Δ, broadcastable (∂))
0 commit comments