diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl
index f9eaf59f6..bc58a4ef8 100644
--- a/src/ChainRulesCore.jl
+++ b/src/ChainRulesCore.jl
@@ -7,7 +7,7 @@ using Compat: hasfield, hasproperty
 
 export frule, rrule # core function
 # rule configurations
-export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
+export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode, HasChunkedMode, NoChunkedMode
 export frule_via_ad, rrule_via_ad
 # definition helper macros
 export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
diff --git a/src/config.jl b/src/config.jl
index 04757e838..0bc3bc662 100644
--- a/src/config.jl
+++ b/src/config.jl
@@ -89,3 +89,22 @@ See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on
 [rule configurations and calling back into AD](@ref config)
 """
 function rrule_via_ad end
+
+# Supertype of the two chunked-mode capability traits defined below.
+abstract type ChunkedRuleCapability end
+
+"""
+    HasChunkedMode
+
+This trait indicates that a `RuleConfig{>:HasChunkedMode}` can perform chunked AD.
+"""
+struct HasChunkedMode <: ChunkedRuleCapability end
+
+"""
+    NoChunkedMode
+
+This is the complement to [`HasChunkedMode`](@ref).
+To avoid ambiguities, [`RuleConfig`](@ref)s that do not support chunked AD should be
+`RuleConfig{>:NoChunkedMode}`.
+"""
+struct NoChunkedMode <: ChunkedRuleCapability end
diff --git a/src/rules.jl b/src/rules.jl
index d99e54a01..b003a2f97 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -78,6 +78,45 @@ function (::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args.
     return frule_kwfunc(kws, frule, args...)
 end
 
+# Holds several tangents (`partials`) so one rule call can be applied to each of them.
+struct ProductTangent{P}
+    partials::P
+end
+
+# With ordinary tangents a chunked config simply defers to the config-free `frule`.
+function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...)
+    return frule((Δf, Δx...), f, args...)
+end
+
+# Single-argument chunked case: call the config-free `frule` once per partial.
+function frule(
+    ::RuleConfig{>:HasChunkedMode}, (Δf, Δx)::Tuple{Any,ProductTangent}, f, args...
+)
+    results = map(Δrow -> frule((Δf, Δrow), f, args...), Δx.partials)
+    # A `nothing` result means no rule applied; pass that on rather than indexing into it.
+    any(r -> r === nothing, results) && return nothing
+    fx = first(results)[1]
+    dfx = ProductTangent(map(r -> r[2], results))
+    return (fx, dfx)
+end
+
+# Wrap the pullback so it can also be called with a `ProductTangent` cotangent.
+function rrule(::RuleConfig{>:HasChunkedMode}, args...)
+    res = rrule(args...)
+    res === nothing && return nothing  # no rule for this call
+    y, back = res
+    return y, ApplyBack(back)
+end
+
+# Callable wrapper around a pullback; see the two call methods below.
+struct ApplyBack{F}
+    back::F
+end
+
+(a::ApplyBack)(dy) = a.back(dy)
+# A `ProductTangent` cotangent is pulled back one partial at a time.
+(a::ApplyBack)(dy::ProductTangent) = ProductTangent(map(a.back, dy.partials)) # or some Tangent recursion?
+
 """
     rrule([::RuleConfig,] f, x...)
 
@@ -149,7 +188,7 @@ const NO_RRULE_DOC = """
 This is an piece of infastructure supporting opting out of [`rrule`](@ref).
 It follows the signature for `rrule` exactly.
 A collection of type-tuples is stored in its method-table.
-If something has this defined, it means that it must having a must also have a `rrule`,
+If something has this defined, it means that it must also have a `rrule`,
 defined that returns `nothing`.
 
 !!! warning "Do not overload no_rrule directly"
diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl
index 8baa006e8..958d082d4 100644
--- a/src/tangent_types/thunks.jl
+++ b/src/tangent_types/thunks.jl
@@ -135,8 +135,9 @@ Define a [`Thunk`](@ref) wrapping the `expr`, to lazily defer its evaluation.
 macro thunk(body)
     # Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
     # so we get useful stack traces if it errors.
-    func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
-    return :(Thunk($(esc(func))))
+    # NOTE(review): thunk construction is switched off here, so `@thunk expr` now expands to
+    # plain `expr` and is evaluated eagerly. Presumably a temporary workaround — confirm.
+    return esc(body)
 end
 
 """