diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index b75d8eff5..d21d96e21 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -7,7 +7,7 @@ using Compat: hasfield, hasproperty export frule, rrule # core function # rule configurations -export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode +export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode, Reuseable, NotReuseable export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented diff --git a/src/config.jl b/src/config.jl index 04757e838..f81244799 100644 --- a/src/config.jl +++ b/src/config.jl @@ -64,6 +64,24 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar """ struct NoForwardsMode <: ForwardsModeCapability end +abstract type PullbackCapability end + +""" +NotReuseable + +This trait indicate that a pullback acquired by `RuleConfig{>:NotReuseable}` can only be called once. +So optimizations like reusing array buffers can be done in the pullback. +""" +struct NotReuseable <: PullbackCapability end + +""" +Reuseable + +This is the complement to [`NotReuseable`](@ref). If it is set then the pullback must return correct +result when being called multiple times. This is useful for computing jacobian. +""" +struct Reuseable <: PullbackCapability end + """ frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) diff --git a/test/config.jl b/test/config.jl index 58d943252..db09384cc 100644 --- a/test/config.jl +++ b/test/config.jl @@ -40,6 +40,9 @@ function ChainRulesCore.rrule_via_ad(config::MockBothConfig, f, args...; kws...) return f(args...; kws...), pullback_via_ad end +struct ReuseableConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode,Reuseable}} end +struct NotReuseableConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode,NotReuseable}} end + ############################## #define some functions for testing @@ -155,6 +158,39 @@ end @test rconfig.reverse_calls == [(identity, (32.1,))] end + @testset "pullback capability" begin + f(x) = x .* fill(2, size(x)) + function ChainRulesCore.rrule(::RuleConfig{>:NotReuseable}, ::typeof(f), x) + tmp = similar(x) + fill!(tmp, 2) + y = x .* tmp + function pullback(Ȳ) + tmp .*= Ȳ + ∂ = tmp + return (NoTangent(), ∂) + end + return y, pullback + end + + function ChainRulesCore.rrule(::RuleConfig{>:Reuseable}, ::typeof(f), x) + tmp = similar(x) + fill!(tmp, 2) + y = x .* tmp + function pullback(Ȳ) + ∂ = tmp .* Ȳ + return (NoTangent(), ∂) + end + return y, pullback + end + + reuseable_pullback = rrule(ReuseableConfig(), f, randn(3))[2] + @test reuseable_pullback([1.0, 2.0, 3.0])[2] == [2.0, 4.0, 6.0] + @test reuseable_pullback([1.0, 2.0, 3.0])[2] == [2.0, 4.0, 6.0] + notreuseable_pullback = rrule(NotReuseableConfig(), f, randn(3))[2] + @test notreuseable_pullback([1.0, 2.0, 3.0])[2] == [2.0, 4.0, 6.0] + @test notreuseable_pullback([1.0, 2.0, 3.0])[2] != [2.0, 4.0, 6.0] + end + @testset "RuleConfig broadcasts like a scaler" begin @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} end