From 1f2a3a33de7feaf1ae9f74a489c99c62442ffd5d Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 26 Jul 2022 17:11:07 -0400 Subject: [PATCH 01/10] add chunked mode --- src/ChainRulesCore.jl | 1 + src/config.jl | 17 +++++++++++++++++ src/rules.jl | 25 ++++++++++++++++++++++++- src/tangent_types/thunks.jl | 5 +++-- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f9eaf59f6..8d805fc46 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -32,6 +32,7 @@ include("projection.jl") include("config.jl") include("rules.jl") +include("chunked_rules.jl") include("rule_definition_tools.jl") include("ignore_derivatives.jl") diff --git a/src/config.jl b/src/config.jl index 04757e838..aa7471840 100644 --- a/src/config.jl +++ b/src/config.jl @@ -89,3 +89,20 @@ See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on [rule configurations and calling back into AD](@ref config) """ function rrule_via_ad end + + +abstract type ChunkedRuleCapability end +""" +HasChunkedMode + +This trait indicates that a `RuleConfig{>:HasChunkedMode}` can perform chunked AD. +""" +struct HasChunkedMode <: ChunkedRuleCapability end + +""" +NoChunkedMode + +This is the complement to [`HasChunkedMode`](@ref). To avoid ambiguities [`RuleConfig`]s +that do not support chunked AD should be `RuleConfig{>:NoChunkedMode}`. +""" +struct NoChunkedMode <: ChunkedRuleCapability end diff --git a/src/rules.jl b/src/rules.jl index d99e54a01..18b320ee9 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -78,6 +78,29 @@ function (::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args. return frule_kwfunc(kws, frule, args...) end +struct ProductTangent{P} + partials::P +end + +function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...) + frule((Δf, Δx...), f, args...) +end + +function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx::ProductTangent), f, args...) + frule((Δf, first(Δx)), args...)[1], ProductTangent(map(Δrow->frule((Δf, Δrow)[2], f, args...), Δx)) +end + +function rrule(::RuleConfig{>:HasChunkedMode}, f, args...) + y, back = rrule(args...) + return y, ApplyBack(back) +end + +struct ApplyBack{F}; back::F; end + +(a::ApplyBack)(dy) = a.back(dy) +(a::ApplyBack)(dy::ProductTangent) = ProductTangent(map(a.back, dy.partials)) # or some Tangent recursion? + + """ rrule([::RuleConfig,] f, x...) @@ -149,7 +172,7 @@ const NO_RRULE_DOC = """ This is an piece of infastructure supporting opting out of [`rrule`](@ref). It follows the signature for `rrule` exactly. A collection of type-tuples is stored in its method-table. -If something has this defined, it means that it must having a must also have a `rrule`, +If something has this defined, it means that it must having a must also have a `rrule`, defined that returns `nothing`. !!! warning "Do not overload no_rrule directly" diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 8baa006e8..958d082d4 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -135,8 +135,9 @@ Define a [`Thunk`](@ref) wrapping the `expr`, to lazily defer its evaluation. macro thunk(body) # Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined. # so we get useful stack traces if it errors. - func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body)) - return :(Thunk($(esc(func)))) + #func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body)) + #return :(Thunk($(esc(func)))) + return esc(body) end """ From 6c606c0b384414855032f933c3d347fd6c3693df Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 26 Jul 2022 17:15:28 -0400 Subject: [PATCH 02/10] Update src/config.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/config.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/config.jl b/src/config.jl index aa7471840..0bc3bc662 100644 --- a/src/config.jl +++ b/src/config.jl @@ -90,7 +90,6 @@ See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on """ function rrule_via_ad end - abstract type ChunkedRuleCapability end """ HasChunkedMode From 8a09e385bc6ee624b249b467cf593a472861d156 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 26 Jul 2022 17:15:32 -0400 Subject: [PATCH 03/10] Update src/rules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 18b320ee9..4ff01b0c3 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -87,7 +87,8 @@ function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...) end function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx::ProductTangent), f, args...) - frule((Δf, first(Δx)), args...)[1], ProductTangent(map(Δrow->frule((Δf, Δrow)[2], f, args...), Δx)) + return frule((Δf, first(Δx)), args...)[1], + ProductTangent(map(Δrow -> frule((Δf, Δrow)[2], f, args...), Δx)) end function rrule(::RuleConfig{>:HasChunkedMode}, f, args...) From 5e579e2afd455be9860881cb1e06f7d2df5f6d5a Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 26 Jul 2022 17:15:54 -0400 Subject: [PATCH 04/10] Update src/rules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 4ff01b0c3..9b8b09331 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -96,7 +96,9 @@ function rrule(::RuleConfig{>:HasChunkedMode}, f, args...) return y, ApplyBack(back) end -struct ApplyBack{F}; back::F; end +struct ApplyBack{F} + back::F +end (a::ApplyBack)(dy) = a.back(dy) (a::ApplyBack)(dy::ProductTangent) = ProductTangent(map(a.back, dy.partials)) # or some Tangent recursion? From 2739a6686506571b6d6e42b1ceb39b0522ade0e7 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 26 Jul 2022 17:15:59 -0400 Subject: [PATCH 05/10] Update src/rules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 9b8b09331..4043b1ddc 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -103,7 +103,6 @@ end (a::ApplyBack)(dy) = a.back(dy) (a::ApplyBack)(dy::ProductTangent) = ProductTangent(map(a.back, dy.partials)) # or some Tangent recursion? - """ rrule([::RuleConfig,] f, x...) From beda204b599c2332d569cc2554ac3eb9175b96fc Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 26 Jul 2022 17:16:04 -0400 Subject: [PATCH 06/10] Update src/rules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 4043b1ddc..a5681bf3f 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -83,7 +83,7 @@ struct ProductTangent{P} end function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...) - frule((Δf, Δx...), f, args...) + return frule((Δf, Δx...), f, args...) end function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx::ProductTangent), f, args...) From 3846220ffbd4647c7c34ba294be1ebec6483d5d9 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 27 Jul 2022 12:52:16 -0400 Subject: [PATCH 07/10] fixes --- src/ChainRulesCore.jl | 1 - src/rules.jl | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 8d805fc46..f9eaf59f6 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -32,7 +32,6 @@ include("projection.jl") include("config.jl") include("rules.jl") -include("chunked_rules.jl") include("rule_definition_tools.jl") include("ignore_derivatives.jl") diff --git a/src/rules.jl b/src/rules.jl index a5681bf3f..520e54445 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -86,9 +86,8 @@ function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...) return frule((Δf, Δx...), f, args...) end -function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx::ProductTangent), f, args...) - return frule((Δf, first(Δx)), args...)[1], - ProductTangent(map(Δrow -> frule((Δf, Δrow)[2], f, args...), Δx)) +function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx)::Tuple{Any,ProductTangent}, f, args...) + return frule((Δf, first(Δx)), args...)[1], ProductTangent(map(Δrow->frule((Δf, Δrow), f, args...)[2], Δx)) end function rrule(::RuleConfig{>:HasChunkedMode}, f, args...) From 9b888a6584f0abe660ef48a71a19571f904270fa Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 27 Jul 2022 12:52:55 -0400 Subject: [PATCH 08/10] formatting --- src/rules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 520e54445..e50d1552e 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -87,7 +87,9 @@ function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...) end function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx)::Tuple{Any,ProductTangent}, f, args...) - return frule((Δf, first(Δx)), args...)[1], ProductTangent(map(Δrow->frule((Δf, Δrow), f, args...)[2], Δx)) + fx = frule((Δf, first(Δx)), args...)[1] + dfx = ProductTangent(map(Δrow->frule((Δf, Δrow), f, args...)[2], Δx)) + return (fx, dfx) end function rrule(::RuleConfig{>:HasChunkedMode}, f, args...) From 24f1a57aa33a42301ecb9be7e96f11a0686fbc29 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 27 Jul 2022 12:55:28 -0400 Subject: [PATCH 09/10] Update src/rules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index e50d1552e..4b435385d 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -88,7 +88,7 @@ end function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx)::Tuple{Any,ProductTangent}, f, args...) fx = frule((Δf, first(Δx)), args...)[1] - dfx = ProductTangent(map(Δrow->frule((Δf, Δrow), f, args...)[2], Δx)) + dfx = ProductTangent(map(Δrow -> frule((Δf, Δrow), f, args...)[2], Δx)) return (fx, dfx) end From 6a81e52c0c82df19054c12daac508df356ac7ab9 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 28 Jul 2022 11:31:08 -0400 Subject: [PATCH 10/10] formatting --- src/ChainRulesCore.jl | 2 +- src/rules.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f9eaf59f6..bc58a4ef8 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -7,7 +7,7 @@ using Compat: hasfield, hasproperty export frule, rrule # core function # rule configurations -export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode +export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode, HasChunkedMode, NoChunkedMode export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented diff --git a/src/rules.jl b/src/rules.jl index 4b435385d..b003a2f97 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -86,13 +86,14 @@ function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...) return frule((Δf, Δx...), f, args...) end -function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx)::Tuple{Any,ProductTangent}, f, args...) +function frule(::RuleConfig{>:HasChunkedMode}, + (Δf, Δx)::Tuple{Any,ProductTangent}, f, args...) fx = frule((Δf, first(Δx)), args...)[1] dfx = ProductTangent(map(Δrow -> frule((Δf, Δrow), f, args...)[2], Δx)) return (fx, dfx) end -function rrule(::RuleConfig{>:HasChunkedMode}, f, args...) +function rrule(::RuleConfig{>:HasChunkedMode}, args...) y, back = rrule(args...) return y, ApplyBack(back) end