diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 013d9e7c8..cb43633ee 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. -At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend. +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). 
## Implementations diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index 197384c4f..5845e80f5 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -111,4 +111,5 @@ There are, however, translation utilities: ### Backend switch Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object. \ No newline at end of file + +Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), and [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. 
\ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 23f9c9b0c..292372b81 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,)) + prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 6253ea229..321378e23 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -3,6 +3,7 @@ module DifferentiationInterfaceMooncakeExt using ADTypes: ADTypes, AutoMooncake import DifferentiationInterface as DI using Mooncake: + Mooncake, CoDual, Config, prepare_gradient_cache, @@ -11,6 +12,16 @@ using Mooncake: value_and_gradient!!, value_and_pullback!!, zero_tangent, + rdata_type, + fdata, + rdata, + tangent_type, + NoTangent, + @is_primitive, + zero_fcodual, + MinimalCtx, + NoRData, + primal, _copy_output, _copy_to_output!! 
@@ -25,5 +36,6 @@ mycopy(x) = deepcopy(x) include("onearg.jl") include("twoarg.jl") +include("differentiate_with.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl new file mode 100644 index 000000000..3b4fb91c3 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -0,0 +1,84 @@ +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any} + +struct MooncakeDifferentiateWithError <: Exception + F::Type + X::Type + Y::Type + function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y} + return new(F, X, Y) + end +end + +function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) + return print( + io, + "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.", + ) +end + +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) + primal_func = primal(dw) + primal_x = primal(x) + (; f, backend) = primal_func + y = zero_fcodual(f(primal_x)) + + # output is a vector, so we need to use the vector pullback + function pullback_array!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (y.dx,)) + @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) + return NoRData(), rdata(only(tx)) + end + + # output is a scalar, so we can use the scalar pullback + function pullback_scalar!!(dy::Number) + tx = DI.pullback(f, backend, primal_x, (dy,)) + @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) + return NoRData(), rdata(only(tx)) + end + + pullback = if primal(y) isa Number + pullback_scalar!! + elseif primal(y) isa AbstractArray + pullback_array!! 
+ else + throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) + end + + return y, pullback +end + +function Mooncake.rrule!!( + dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} +) + primal_func = primal(dw) + primal_x = primal(x) + fdata_arg = x.dx + (; f, backend) = primal_func + y = zero_fcodual(f(primal_x)) + + # output is a vector, so we need to use the vector pullback + function pullback_array!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (y.dx,)) + @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) + fdata_arg .+= only(tx) + return NoRData(), dy + end + + # output is a scalar, so we can use the scalar pullback + function pullback_scalar!!(dy::Number) + tx = DI.pullback(f, backend, primal_x, (dy,)) + @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) + fdata_arg .+= only(tx) + return NoRData(), NoRData() + end + + pullback = if primal(y) isa Number + pullback_scalar!! + elseif primal(y) isa AbstractArray + pullback_array!! + else + throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) + end + + return y, pullback +end diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index 83576c078..256d46f75 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -13,9 +13,13 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be !!! warning `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments. - It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). 
Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. + It only makes these functions differentiable if the true backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or a backend that automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). +!!! warning + When using `DifferentiateWith(f, AutoSomething())`, the function `f` must not close over any active data. + As of now, we cannot differentiate with respect to parameters stored inside `f`. 
+ # Fields - `f`: the function in question, with signature `f(x)` diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index c8ea57c0b..d2bf57f88 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -1,20 +1,41 @@ using Pkg -Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) +Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) +using ChainRulesTestUtils: ChainRulesTestUtils using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Zygote: Zygote +using Mooncake: Mooncake +using StableRNGs using Test LOGGING = get(ENV, "CI", "false") == "false" +struct ADBreaker{F} + f::F +end + +function (adb::ADBreaker)(x::Number) + copyto!(Float64[0], x) # break ForwardDiff and Zygote + return adb.f(x) +end + +function (adb::ADBreaker)(x::AbstractArray) + copyto!(similar(x, Float64), x) # break ForwardDiff and Zygote + return adb.f(x) +end + function differentiatewith_scenarios() - bad_scens = # these closurified scenarios have mutation and type constraints - filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen - DIT.function_place(scen) == :out - end + outofplace_scens = filter(DIT.default_scenarios()) do scen + DIT.function_place(scen) == :out + end + # with bad_scens, everything would break + bad_scens = map(outofplace_scens) do scen + DIT.change_function(scen, ADBreaker(scen.f)) + end + # with good_scens, everything is fixed good_scens = map(bad_scens) do scen DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff())) end @@ -22,8 +43,64 @@ function differentiatewith_scenarios() end test_differentiation( - [AutoForwardDiff(), AutoZygote()], + [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], 
differentiatewith_scenarios(); excluded=SECOND_ORDER, logging=LOGGING, + testset_name="DI tests", ) + +@testset "ChainRules tests" begin + @testset for scen in filter(differentiatewith_scenarios()) do scen + DIT.operator(scen) == :pullback + end + ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4) + end +end; + +@testset "Mooncake tests" begin + @testset for scen in filter(differentiatewith_scenarios()) do scen + DIT.operator(scen) == :pullback + end + Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true) + end +end; + +@testset "Mooncake errors" begin + MooncakeDifferentiateWithError = + Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError + + e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) + @test sprint(showerror, e) == + "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported." + + f_num2tup(x::Number) = (x,) + f_vec2tup(x::Vector) = (first(x),) + f_tup2num(x::Tuple{<:Number}) = only(x) + f_tup2vec(x::Tuple{<:Number}) = [only(x)] + + @test_throws MooncakeDifferentiateWithError pullback( + DifferentiateWith(f_num2tup, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + 1.0, + ((2.0,),), + ) + @test_throws MooncakeDifferentiateWithError pullback( + DifferentiateWith(f_vec2tup, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + [1.0], + ((2.0,),), + ) + @test_throws MethodError pullback( + DifferentiateWith(f_tup2num, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + (1.0,), + (2.0,), + ) + @test_throws MethodError pullback( + DifferentiateWith(f_tup2vec, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + (1.0,), + ([2.0],), + ) +end