Skip to content

Commit f6979ac

Browse files
committed
Make accumulate apply to Differentials
update accumulate to work on differentials
1 parent eb3c292 commit f6979ac

File tree

1 file changed

+65
-56
lines changed

1 file changed

+65
-56
lines changed

src/rule_types.jl

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -61,59 +61,6 @@ function Base.getindex(rule::AbstractRule, i::Integer)
6161
return i == 1 ? rule : throw(BoundsError())
6262
end
6363

64-
"""
65-
accumulate(Δ, rule::AbstractRule, args...)
66-
67-
Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore'
68-
various `AbstractDifferential` types.
69-
70-
This method intended to be customizable for specific rules/input types. For
71-
example, here is pseudocode to overload `accumulate` w.r.t. a specific forward
72-
differentiation rule for a given function `f`:
73-
74-
```
75-
df(x) = # forward differentiation primitive implementation
76-
77-
frule(::typeof(f), x) = (f(x), Rule(df))
78-
79-
accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation
80-
```
81-
82-
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
83-
"""
84-
accumulate(Δ, rule, args...) = Δ + rule(args...)
85-
86-
"""
87-
accumulate!(Δ, rule::AbstractRule, args...)
88-
89-
Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place,
90-
storing the result in `Δ`.
91-
92-
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
93-
94-
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
95-
"""
96-
function accumulate!(Δ, rule, args...)
97-
return materialize!(Δ, broadcastable(cast(Δ) + rule(args...)))
98-
end
99-
100-
accumulate!::Number, rule, args...) = accumulate(Δ, rule, args...)
101-
102-
"""
103-
store!(Δ, rule::AbstractRule, args...)
104-
105-
Compute `rule(args...)` and store the result in `Δ`, potentially avoiding
106-
intermediate temporary allocations that might be necessary for alternative
107-
approaches (e.g. `copyto!(Δ, extern(rule(args...)))`)
108-
109-
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
110-
111-
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
112-
to be customizable for specific rules/input types.
113-
114-
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
115-
"""
116-
store!(Δ, rule, args...) = materialize!(Δ, broadcastable(rule(args...)))
11764

11865
#####
11966
##### `Rule`
@@ -157,9 +104,6 @@ Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
157104
Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))")
158105
Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))")
159106

160-
# Specialized accumulation
161-
# TODO: Does this need to be overdubbed in the rule context?
162-
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
163107

164108
#####
165109
##### `DNERule`
@@ -211,3 +155,68 @@ function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
211155
return WirtingerRule(primal, conjugate)
212156
end
213157
end
158+
159+
160+
"""
161+
accumulate(Δ, ∂)
162+
163+
Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore'
164+
various `AbstractDifferential` types.
165+
166+
#TODO: update these docs
167+
168+
This method intended to be customizable for specific rules/input types. For
169+
example, here is pseudocode to overload `accumulate` w.r.t. a specific forward
170+
differentiation rule for a given function `f`:
171+
172+
```
173+
df(x) = # forward differentiation primitive implementation
174+
175+
frule(::typeof(f), x) = (f(x), Rule(df))
176+
177+
accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation
178+
```
179+
180+
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
181+
"""
182+
accumulate(Δ, ∂) = Δ +
183+
184+
"""
185+
accumulate!(Δ, rule::AbstractRule, args...)
186+
187+
# TODO: Update these docs
188+
189+
Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place,
190+
storing the result in `Δ`.
191+
192+
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
193+
194+
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
195+
"""
196+
function accumulate!(Δ, ∂)
197+
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
198+
end
199+
200+
accumulate!::Number, ∂) = accumulate(Δ, ∂)
201+
202+
# TODO: replace this:
203+
# accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
204+
205+
206+
"""
207+
store!(Δ, ∂)
208+
209+
TODO: Rewrite these docs
210+
211+
Compute `rule(args...)` and store the result in `Δ`, potentially avoiding
212+
intermediate temporary allocations that might be necessary for alternative
213+
approaches (e.g. `copyto!(Δ, extern(rule(args...)))`)
214+
215+
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
216+
217+
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
218+
to be customizable for specific rules/input types.
219+
220+
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
221+
"""
222+
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂))

0 commit comments

Comments
 (0)