From 1a389a6afe10b660655e4c9f4a05481bd97caf95 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Tue, 1 Apr 2025 17:56:42 +0530 Subject: [PATCH 01/25] Handles backend switching for Mooncake using ChainRules --- .../DifferentiationInterfaceMooncakeExt.jl | 8 +++++++- .../differentiate_with.jl | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 52e742b05..867d1ec58 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -3,6 +3,7 @@ module DifferentiationInterfaceMooncakeExt using ADTypes: ADTypes, AutoMooncake import DifferentiationInterface as DI using Mooncake: + Mooncake, CoDual, Config, prepare_gradient_cache, @@ -10,7 +11,11 @@ using Mooncake: tangent_type, value_and_gradient!!, value_and_pullback!!, - zero_tangent + @from_rrule, + MinimalCtx, + NoFData + +using ChainRulesCore: ChainRulesCore, rrule DI.check_available(::AutoMooncake) = true @@ -26,5 +31,6 @@ mycopy(x) = deepcopy(x) include("onearg.jl") include("twoarg.jl") +include("differentiate_with.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl new file mode 100644 index 000000000..3de98490b --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -0,0 +1,15 @@ +function define_rule!(primal_func, primal_args) + return eval(:(@from_rrule MinimalCtx Tuple{$primal_func,$primal_args...})) 
+end + +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...) + primal_func = typeof(Mooncake.primal(dw)) + primal_args = typeof.(map(arg -> Mooncake.primal(arg), args)) + # use the DI.chainrule wrapper inside @from_rrule to create a custom rrule!! + + # macro evaluation in global scope with more specialized types (@fromrrule requires non generic types) + define_rule!(primal_func, primal_args) + + # Use the ChainRuleCore rrule mapping with backends, calling Mooncake rule!! that now wraps around that ChainRulesCore rrule. + return Base.invokelatest(Mooncake.rrule!!, CoDual(primal_func, dw.dx), args...) +end From 08b176a3e1702ca40b6bbe028ba4294c691ec0fa Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Wed, 2 Apr 2025 23:38:24 +0530 Subject: [PATCH 02/25] Mooncake Wrapper for substitute backends --- DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceMooncakeExt.jl | 6 +++-- .../differentiate_with.jl | 27 ++++++++++++------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index c14bca747..93b4bb0e6 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -38,7 +38,7 @@ DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGTPSAExt = "GTPSA" -DifferentiationInterfaceMooncakeExt = "Mooncake" +DifferentiationInterfaceMooncakeExt = ["ChainRulesCore", "Mooncake"] DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl 
b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 867d1ec58..90eb6cf9f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -11,9 +11,11 @@ using Mooncake: tangent_type, value_and_gradient!!, value_and_pullback!!, - @from_rrule, + zero_tangent, + @is_primitive, + zero_fcodual, MinimalCtx, - NoFData + NoRData using ChainRulesCore: ChainRulesCore, rrule diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 3de98490b..fd56e99c1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,15 +1,22 @@ -function define_rule!(primal_func, primal_args) - return eval(:(@from_rrule MinimalCtx Tuple{$primal_func,$primal_args...})) -end +@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:AbstractArray}} +@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:Number}} function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...) - primal_func = typeof(Mooncake.primal(dw)) - primal_args = typeof.(map(arg -> Mooncake.primal(arg), args)) - # use the DI.chainrule wrapper inside @from_rrule to create a custom rrule!! + primal_func = Mooncake.primal(dw) + primal_args = map(arg -> Mooncake.primal(arg), args) + + (; f, backend) = primal_func + y = f(primal_args...) 
+ + prep_same = DI.prepare_pullback_same_point_nokwarg( + Val(true), f, backend, primal_args..., (y,) + ) - # macro evaluation in global scope with more specialized types (@fromrrule requires non generic types) - define_rule!(primal_func, primal_args) + function pullback!!(dy) + tx = DI.pullback(f, prep_same, backend, primal_args, (dy,)) + args_rdata = map((x) -> (x, Mooncake.zero_rdata(x)), only(tx)) + return NoRData(), args_rdata... + end - # Use the ChainRuleCore rrule mapping with backends, calling Mooncake rule!! that now wraps around that ChainRulesCore rrule. - return Base.invokelatest(Mooncake.rrule!!, CoDual(primal_func, dw.dx), args...) + return zero_fcodual(y), pullback!! end From 1340d921adfae13ce2ed23d53888a9787063f758 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 10 Apr 2025 19:02:35 +0530 Subject: [PATCH 03/25] added rules --- DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceMooncakeExt.jl | 6 ++-- .../differentiate_with.jl | 36 ++++++++++++------- .../test/Back/DifferentiateWith/test.jl | 5 +-- 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 93b4bb0e6..c14bca747 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -38,7 +38,7 @@ DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGTPSAExt = "GTPSA" -DifferentiationInterfaceMooncakeExt = ["ChainRulesCore", "Mooncake"] +DifferentiationInterfaceMooncakeExt = "Mooncake" DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" diff --git 
a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 90eb6cf9f..c61969019 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -15,9 +15,9 @@ using Mooncake: @is_primitive, zero_fcodual, MinimalCtx, - NoRData - -using ChainRulesCore: ChainRulesCore, rrule + NoRData, + fdata, + primal DI.check_available(::AutoMooncake) = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index fd56e99c1..a7b86141b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,21 +1,31 @@ -@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:AbstractArray}} -@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:Number}} - -function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...) - primal_func = Mooncake.primal(dw) - primal_args = map(arg -> Mooncake.primal(arg), args) +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number} +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) + primal_func = primal(dw) + primal_x = primal(x) (; f, backend) = primal_func - y = f(primal_args...) 
+ y = f(primal_x) + + function pullback!!(dy) + tx = DI.pullback(f, backend, primal_x, (dy,)) + return NoRData(), only(tx) + end - prep_same = DI.prepare_pullback_same_point_nokwarg( - Val(true), f, backend, primal_args..., (y,) - ) + return zero_fcodual(y), pullback!! +end + +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) + primal_func = primal(dw) + primal_x = primal(x) + fdata_arg = fdata(x.dx) + (; f, backend) = primal_func + y = f(primal_x) function pullback!!(dy) - tx = DI.pullback(f, prep_same, backend, primal_args, (dy,)) - args_rdata = map((x) -> (x, Mooncake.zero_rdata(x)), only(tx)) - return NoRData(), args_rdata... + tx = DI.pullback(f, backend, primal_x, (dy,)) + fdata_arg .+= only(tx) + return NoRData(), NoRData() end return zero_fcodual(y), pullback!! diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index dbc41f548..9cf24525f 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -1,11 +1,12 @@ using Pkg -Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) +Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Zygote: Zygote +using Mooncake: Mooncake using Test LOGGING = get(ENV, "CI", "false") == "false" @@ -24,7 +25,7 @@ function differentiatewith_scenarios() end test_differentiation( - [AutoForwardDiff(), AutoZygote()], + [AutoForwardDiff(), AutoZygote(), AutoMooncake()], differentiatewith_scenarios(); excluded=SECOND_ORDER, logging=LOGGING, From 08de6df62c417c73839a3a9990c7dba8fc71b3f2 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 10 Apr 2025 19:15:58 +0530 Subject: [PATCH 04/25] config --- 
DifferentiationInterface/test/Back/DifferentiateWith/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 9cf24525f..d474dd517 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -25,7 +25,7 @@ function differentiatewith_scenarios() end test_differentiation( - [AutoForwardDiff(), AutoZygote(), AutoMooncake()], + [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], differentiatewith_scenarios(); excluded=SECOND_ORDER, logging=LOGGING, From 84f27c933673fbdca2ad60c673a2616477ba0fdc Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 10 Apr 2025 19:34:18 +0530 Subject: [PATCH 05/25] splatting for dy --- .../DifferentiationInterfaceMooncakeExt/differentiate_with.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index a7b86141b..4751adde7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -8,7 +8,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number y = f(primal_x) function pullback!!(dy) - tx = DI.pullback(f, backend, primal_x, (dy,)) + tx = DI.pullback(f, backend, primal_x, (dy...,)) return NoRData(), only(tx) end @@ -23,7 +23,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra y = f(primal_x) function pullback!!(dy) - tx = DI.pullback(f, backend, primal_x, (dy,)) + tx = DI.pullback(f, backend, primal_x, (dy...,)) fdata_arg .+= only(tx) return NoRData(), NoRData() end From 
2e952996359da52d4c00f919fa9b652c2cd357fb Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 10 Apr 2025 19:51:20 +0530 Subject: [PATCH 06/25] brackets --- .../differentiate_with.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 4751adde7..70a5e04cb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -8,8 +8,8 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number y = f(primal_x) function pullback!!(dy) - tx = DI.pullback(f, backend, primal_x, (dy...,)) - return NoRData(), only(tx) + tx = DI.pullback(f, backend, primal_x, (dy,)) + return (NoRData(), only(tx)) end return zero_fcodual(y), pullback!! @@ -23,9 +23,9 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra y = f(primal_x) function pullback!!(dy) - tx = DI.pullback(f, backend, primal_x, (dy...,)) + tx = DI.pullback(f, backend, primal_x, (dy,)) fdata_arg .+= only(tx) - return NoRData(), NoRData() + return (NoRData(), NoRData()) end return zero_fcodual(y), pullback!! 
From 13233e51d82666a6f231c96cf2a39a0608ff8ffc Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 11 Apr 2025 13:31:12 +0530 Subject: [PATCH 07/25] too easy --- .../differentiate_with.jl | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 70a5e04cb..9c3133a28 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -5,14 +5,21 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number primal_func = primal(dw) primal_x = primal(x) (; f, backend) = primal_func - y = f(primal_x) + y = zero_fcodual(f(primal_x)) + # output is a vector, so we need to use the vector pullback + function pullback!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + return NoRData(), only(tx) + end + + # output is a scalar, so we can use the scalar pullback function pullback!!(dy) tx = DI.pullback(f, backend, primal_x, (dy,)) - return (NoRData(), only(tx)) + return NoRData(), only(tx) end - return zero_fcodual(y), pullback!! + return y, pullback!! 
end function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) @@ -20,13 +27,23 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra primal_x = primal(x) fdata_arg = fdata(x.dx) (; f, backend) = primal_func - y = f(primal_x) + y = zero_fcodual(f(primal_x)) + + # output is a vector, so we need to use the vector pullback + function pullback!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + fdata_arg .+= only(tx) + return NoRData(), dy + end + # output is a scalar, so we can use the scalar pullback function pullback!!(dy) tx = DI.pullback(f, backend, primal_x, (dy,)) fdata_arg .+= only(tx) - return (NoRData(), NoRData()) + return NoRData(), NoRData() end - return zero_fcodual(y), pullback!! + # in case x is mutated when passed into f + x = CoDual(primal_x, x.dx) + return y, pullback!! end From 1e8df989718a295f0632e58dd2274ef14e5abff4 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sat, 12 Apr 2025 21:58:11 +0530 Subject: [PATCH 08/25] changes from reviews, Docs --- .../docs/src/explanation/backends.md | 2 +- .../docs/src/faq/differentiability.md | 2 +- .../differentiate_with.jl | 15 +++++++-------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index a3c55dd5c..bc1f6faed 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. 
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. -At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend. +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) or a [Mooncake.jl](https://github.com/compintell/Mooncake.jl)-compatible backend. ## Implementations diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index 2a89ded21..ec3eeeaab 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -111,4 +111,4 @@ There are, however, translation utilities: ### Backend switch Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object. \ No newline at end of file +Right now it only targets ForwardDiff.jl, ChainRulesCore.jl and Mooncake.jl but PRs are welcome to define Enzyme.jl rules for this object. 
\ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 9c3133a28..df07b67a1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,5 +1,4 @@ -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) primal_func = primal(dw) @@ -14,7 +13,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number end # output is a scalar, so we can use the scalar pullback - function pullback!!(dy) + function pullback!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) return NoRData(), only(tx) end @@ -28,22 +27,22 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra fdata_arg = fdata(x.dx) (; f, backend) = primal_func y = zero_fcodual(f(primal_x)) + # in case x is mutated in f calls + cp_primal_x = copy(primal_x) # output is a vector, so we need to use the vector pullback function pullback!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + tx = DI.pullback(f, backend, cp_primal_x, (fdata(y.dx),)) fdata_arg .+= only(tx) return NoRData(), dy end # output is a scalar, so we can use the scalar pullback - function pullback!!(dy) - tx = DI.pullback(f, backend, primal_x, (dy,)) + function pullback!!(dy::Number) + tx = DI.pullback(f, backend, cp_primal_x, (dy,)) fdata_arg .+= only(tx) return NoRData(), NoRData() end - # in case x is mutated when passed into f - x = CoDual(primal_x, x.dx) return y, pullback!! 
end From afdddd4d4ab0e9d7f5397f6c98c9b4dc4f3d8324 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sat, 19 Apr 2025 01:12:24 +0530 Subject: [PATCH 09/25] changes from reviews - 2 --- .../docs/src/explanation/backends.md | 2 +- .../docs/src/faq/differentiability.md | 2 +- .../differentiate_with.jl | 18 ++++++++---------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index bc1f6faed..7d67e3981 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. -At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) or a [Mooncake.jl](https://github.com/compintell/Mooncake.jl)-compatible backend. +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/compintell/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). 
## Implementations diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index ec3eeeaab..d3e51dd35 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -111,4 +111,4 @@ There are, however, translation utilities: ### Backend switch Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now it only targets ForwardDiff.jl, ChainRulesCore.jl and Mooncake.jl but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file +Right now, it only targets ForwardDiff.jl, Mooncake.jl, ChainRules.jl-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index df07b67a1..18114437e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -7,18 +7,18 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number y = zero_fcodual(f(primal_x)) # output is a vector, so we need to use the vector pullback - function pullback!!(dy::NoRData) + function pullback_array!!(dy::NoRData) tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) return NoRData(), only(tx) end # output is a scalar, so we can use the scalar pullback - function pullback!!(dy::Number) + function pullback_scalar!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) return NoRData(), only(tx) end - return y, pullback!! 
+ return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! end function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) @@ -27,22 +27,20 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra fdata_arg = fdata(x.dx) (; f, backend) = primal_func y = zero_fcodual(f(primal_x)) - # in case x is mutated in f calls - cp_primal_x = copy(primal_x) # output is a vector, so we need to use the vector pullback - function pullback!!(dy::NoRData) - tx = DI.pullback(f, backend, cp_primal_x, (fdata(y.dx),)) + function pullback_array!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) fdata_arg .+= only(tx) return NoRData(), dy end # output is a scalar, so we can use the scalar pullback - function pullback!!(dy::Number) - tx = DI.pullback(f, backend, cp_primal_x, (dy,)) + function pullback_scalar!!(dy::Number) + tx = DI.pullback(f, backend, primal_x, (dy,)) fdata_arg .+= only(tx) return NoRData(), NoRData() end - return y, pullback!! + return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! 
end From 7a07127fd1f72037145663f38a8c3c62fb6bef32 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 16 May 2025 18:14:52 +0530 Subject: [PATCH 10/25] changes from reviews-1 --- DifferentiationInterface/docs/src/explanation/backends.md | 4 ++-- .../docs/src/faq/differentiability.md | 8 ++++---- .../DifferentiationInterfaceMooncakeExt.jl | 1 + .../differentiate_with.jl | 2 ++ DifferentiationInterface/src/misc/differentiate_with.jl | 6 +++++- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 7d67e3981..a78b1f45f 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. -At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/compintell/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). 
+At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). ## Implementations @@ -177,7 +177,7 @@ For all operators, preparation generates an [executable function](https://docs.s ### Mooncake -For `pullback`, preparation [builds the reverse rule](https://github.com/compintell/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function. +For `pullback`, preparation [builds the reverse rule](https://github.com/chalk-lab/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function. ### Tracker diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index d3e51dd35..684fe236f 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -84,9 +84,9 @@ Note that its rule writing is very different from ChainRulesCore.jl due to the p ### Mooncake -[Mooncake.jl](https://github.com/compintell/Mooncake.jl) is a recent package which also handles a large subset of all Julia programs out-of-the-box. +[Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl) is a recent package which also handles a large subset of all Julia programs out-of-the-box. -Its [rule system](https://compintell.github.io/Mooncake.jl/dev/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with. +Its [rule system](https://chalk-lab.github.io/Mooncake.jl/dev/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with. ## A rule mayhem? 
@@ -106,9 +106,9 @@ There are, however, translation utilities: - from ChainRulesCore.jl to ForwardDiff.jl with [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) - from ChainRulesCore.jl to Enzyme.jl with [`Enzyme.@import_rrule`](https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_rrule-Tuple) -- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://compintell.github.io/Mooncake.jl/dev/utilities/tools_for_rules/#Using-ChainRules.jl) +- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://chalk-lab.github.io/Mooncake.jl/dev/utilities/tools_for_rules/#Using-ChainRules.jl) ### Backend switch Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now, it only targets ForwardDiff.jl, Mooncake.jl, ChainRules.jl-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file +Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. 
\ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index c61969019..82eeb7300 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -12,6 +12,7 @@ using Mooncake: value_and_gradient!!, value_and_pullback!!, zero_tangent, + rdata_type, @is_primitive, zero_fcodual, MinimalCtx, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 18114437e..cd309d5c2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -9,12 +9,14 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + @assert only(tx) isa rdata_type(typeof(x)) return NoRData(), only(tx) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) + @assert only(tx) isa rdata_type(typeof(x)) return NoRData(), only(tx) end diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index 83576c078..08953d9d7 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -13,9 +13,13 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be !!! 
warning `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments. - It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. + It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). +!!! warning + When using Mooncake as a substitute backend via `DifferentiateWith(f, AutoMooncake())`, the function `f` must not close over any active data. + As of now, we cannot differentiate with respect to parameters stored inside `f`. 
+ # Fields - `f`: the function in question, with signature `f(x)` From f3e436d02a15efdbcb8a800e2cdc87561677ee57 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 16 May 2025 18:34:52 +0530 Subject: [PATCH 11/25] conflicts --- DifferentiationInterface/docs/src/explanation/backends.md | 2 +- DifferentiationInterface/docs/src/faq/differentiability.md | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index a78b1f45f..09d24d736 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -177,7 +177,7 @@ For all operators, preparation generates an [executable function](https://docs.s ### Mooncake -For `pullback`, preparation [builds the reverse rule](https://github.com/chalk-lab/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function. +For `pullback`, preparation [builds the reverse rule](https://chalk-lab.github.io/Mooncake.jl/stable/understanding_mooncake/rule_system/) of the function. ### Tracker diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index 684fe236f..1d0912570 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -86,7 +86,7 @@ Note that its rule writing is very different from ChainRulesCore.jl due to the p [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl) is a recent package which also handles a large subset of all Julia programs out-of-the-box. -Its [rule system](https://chalk-lab.github.io/Mooncake.jl/dev/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with. 
+Its [rule system](https://chalk-lab.github.io/Mooncake.jl/stable/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with. ## A rule mayhem? @@ -106,9 +106,10 @@ There are, however, translation utilities: - from ChainRulesCore.jl to ForwardDiff.jl with [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) - from ChainRulesCore.jl to Enzyme.jl with [`Enzyme.@import_rrule`](https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_rrule-Tuple) -- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://chalk-lab.github.io/Mooncake.jl/dev/utilities/tools_for_rules/#Using-ChainRules.jl) +- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://chalk-lab.github.io/Mooncake.jl/stable/utilities/tools_for_rules/#Using-ChainRules.jl) ### Backend switch Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file + +Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. 
\ No newline at end of file From 6a0d93790540ca8f692e5fb8d1ff1a9572724223 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 16 May 2025 18:36:45 +0530 Subject: [PATCH 12/25] conflicts-2 --- DifferentiationInterface/docs/src/faq/differentiability.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index 1d0912570..af16c6a49 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -106,7 +106,7 @@ There are, however, translation utilities: - from ChainRulesCore.jl to ForwardDiff.jl with [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) - from ChainRulesCore.jl to Enzyme.jl with [`Enzyme.@import_rrule`](https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_rrule-Tuple) -- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://chalk-lab.github.io/Mooncake.jl/stable/utilities/tools_for_rules/#Using-ChainRules.jl) +- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://chalk-lab.github.io/Mooncake.jl/stable/utilities/defining_rules/#Using-ChainRules.jl) ### Backend switch From e543958d6b7c5c95d4a1cf742d9e1e55107fb551 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal <84859349+AstitvaAggarwal@users.noreply.github.com> Date: Fri, 16 May 2025 20:35:16 +0530 Subject: [PATCH 13/25] Update differentiate_with.jl --- .../DifferentiationInterfaceMooncakeExt/differentiate_with.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index cd309d5c2..ba1e8a33f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -9,14 +9,14 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) - @assert only(tx) isa rdata_type(typeof(x)) + @assert only(tx) isa rdata_type(typeof(primal_x)) return NoRData(), only(tx) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert only(tx) isa rdata_type(typeof(x)) + @assert only(tx) isa rdata_type(typeof(primal_x)) return NoRData(), only(tx) end From c63c9567298d0d0a341414f265e401c80a442f76 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 18 May 2025 16:54:26 +0530 Subject: [PATCH 14/25] typecheck for array rule. --- .../DifferentiationInterfaceMooncakeExt/differentiate_with.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index ba1e8a33f..41b8a9b42 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -33,6 +33,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + @assert only(tx) isa rdata_type(typeof(primal_x)) fdata_arg .+= only(tx) return NoRData(), dy end @@ -40,6 +41,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) tx = 
DI.pullback(f, backend, primal_x, (dy,)) + @assert only(tx) isa rdata_type(typeof(primal_x)) fdata_arg .+= only(tx) return NoRData(), NoRData() end From 36da036225b7187963f68cb3ad2a072520e0ff01 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 18 May 2025 17:27:05 +0530 Subject: [PATCH 15/25] assertion for array inputs --- .../DifferentiationInterfaceMooncakeExt/differentiate_with.jl | 4 ++-- DifferentiationInterface/test/Back/DifferentiateWith/test.jl | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 41b8a9b42..fb3a1df72 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -33,7 +33,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) - @assert only(tx) isa rdata_type(typeof(primal_x)) + @assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) fdata_arg .+= only(tx) return NoRData(), dy end @@ -41,7 +41,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert only(tx) isa rdata_type(typeof(primal_x)) + @assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) fdata_arg .+= only(tx) return NoRData(), NoRData() end diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index ce41b6e01..43e58b48e 100644 --- 
a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -13,7 +13,9 @@ LOGGING = get(ENV, "CI", "false") == "false" function differentiatewith_scenarios() bad_scens = # these closurified scenarios have mutation and type constraints - filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen + filter( + DIT.default_scenarios(; include_normal=false, include_closurified=true) + ) do scen DIT.function_place(scen) == :out end good_scens = map(bad_scens) do scen From c389a80b965b3603299c9e96f77d068a9cad6138 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 29 May 2025 18:25:41 +0530 Subject: [PATCH 16/25] extensive tests, diffwith for tuples --- .../differentiate_with.jl | 137 +++++++++++++++++- .../test/Back/DifferentiateWith/test.jl | 6 +- 2 files changed, 138 insertions(+), 5 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index fb3a1df72..d31f4476a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,6 +1,8 @@ -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} -function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) +function Mooncake.rrule!!( + dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{Union{<:Number,<:Tuple}} +) primal_func = primal(dw) primal_x = primal(x) (; f, backend) = primal_func @@ -20,7 +22,22 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number return NoRData(), only(tx) end - return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! 
+ # output is a Tuple, NTuple + function pullback_tuple!!(dy::Tuple) + tx = DI.pullback(f, backend, primal_x, (dy,)) + @assert only(tx) isa rdata_type(typeof(primal_x)) + return NoRData(), only(tx) + end + + pullback = if typeof(primal(y)) <: Number + pullback_scalar!! + elseif typeof(primal(y)) <: Array + pullback_array!! + else + pullback_tuple!! + end + + return y, pullback end function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) @@ -46,5 +63,117 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra return NoRData(), NoRData() end - return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! + # output is a Tuple, NTuple + function pullback_tuple!!(dy::Tuple) + tx = DI.pullback(f, backend, primal_x, (dy,)) + @assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) + fdata_arg .+= only(tx) + return NoRData(), NoRData() + end + + pullback = if typeof(primal(y)) <: Number + pullback_scalar!! + elseif typeof(primal(y)) <: Array + pullback_array!! + else + pullback_tuple!! + end + + return y, pullback +end + +function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) + return Any[], Any[] +end + +function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) + test_cases = reduce( + vcat, + map([Float64, Float32]) do P + return Any[ + (false, :stability_and_allocs, nothing, cosh, P(0.3)), + (false, :stability_and_allocs, nothing, sinh, P(0.3)), + (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)), + (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)), + (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)), + (false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)), + ] + end, + ) + push!(test_cases, (false, :stability, nothing, copy, randn(5, 4))) + push!(test_cases, ( + # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent. 
+ false, + :none, + nothing, + x -> +(x...), + randn(33), + )) + push!( + test_cases, + ( + false, + :none, + nothing, + ( + function (x) + rx = Ref(x) + return Base.pointerref( + Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1 + ) + end + ), + 5.0, + ), + ) + push!( + test_cases, + ( + false, + :none, + nothing, + x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), + rand(UInt8, 5), + ), + ) + push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, [1.0])) + push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[1.0])) + push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[[1.0]])) + push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctlz_int, 5)) + push!( + test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctpop_int, 5) + ) + push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.cttz_int, 5)) + push!( + test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0) + ) + push!( + test_cases, + (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0f0), + ) + push!(test_cases, (false, :stability, nothing, deepcopy, 5.0)) + push!(test_cases, (false, :stability, nothing, deepcopy, randn(5))) + push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1)) + push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.0f1)) + push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1)) + push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.0f1)) + push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1)) + push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.0f1)) + + # additional_test_set = Mooncake.tangent_test_cases() + # function is_valid(f) + # try + # isa(f([1.0, 2.0]), Union{<:Number,<:AbstractArray}) + # catch + # false + # end + # end + # for test in additional_test_set + # if is_valid(test[2]) + # push!(test_cases, test) + # end + # end + + memory = 
Any[] + return test_cases, memory end diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 43e58b48e..a27639b58 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -7,7 +7,7 @@ using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Zygote: Zygote using Mooncake: Mooncake -using Test +using StableRNGs, Test LOGGING = get(ENV, "CI", "false") == "false" @@ -30,3 +30,7 @@ test_differentiation( excluded=SECOND_ORDER, logging=LOGGING, ) + +@testset "new" begin + Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith)) +end \ No newline at end of file From b4fe0f83955377cfeac9434c01d41642b119eb50 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 29 May 2025 19:43:42 +0530 Subject: [PATCH 17/25] tests. --- .../differentiate_with.jl | 6 +++--- .../test/Back/DifferentiateWith/test.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index d31f4476a..5f7d68a3f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,7 +1,7 @@ @is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} function Mooncake.rrule!!( - dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{Union{<:Number,<:Tuple}} + dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}} ) primal_func = primal(dw) primal_x = primal(x) @@ -82,11 +82,11 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra return y, pullback end -function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) 
+function Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) return Any[], Any[] end -function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) +function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) test_cases = reduce( vcat, map([Float64, Float32]) do P diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index a27639b58..e111d5279 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -33,4 +33,4 @@ test_differentiation( @testset "new" begin Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith)) -end \ No newline at end of file +end From ec4b75d7e7bab55e84f26601fae5f0187c9ced88 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 1 Jun 2025 03:59:34 +0530 Subject: [PATCH 18/25] tests, inc primal handling --- DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceMooncakeExt.jl | 5 +- .../differentiate_with.jl | 224 ++++++++++-------- .../test/Back/DifferentiateWith/test.jl | 2 +- 4 files changed, 130 insertions(+), 103 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 486b91a5b..f4b479efb 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -72,7 +72,7 @@ JET = "0.9" JLArrays = "0.2.0" JuliaFormatter = "1,2" LinearAlgebra = "1" -Mooncake = "0.4.88" +Mooncake = "0.4.121" Pkg = "1" PolyesterForwardDiff = "0.1.2" Random = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 82eeb7300..bd7214186 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -13,11 +13,14 @@ using Mooncake: value_and_pullback!!, zero_tangent, rdata_type, + fdata, + rdata, + tangent_type, + NoTangent, @is_primitive, zero_fcodual, MinimalCtx, NoRData, - fdata, primal DI.check_available(::AutoMooncake) = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 5f7d68a3f..46d59d64c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,5 +1,6 @@ @is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} +# nested vectors, similar are not supported function Mooncake.rrule!!( dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}} ) @@ -10,31 +11,41 @@ function Mooncake.rrule!!( # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) - @assert only(tx) isa rdata_type(typeof(primal_x)) - return NoRData(), only(tx) + tx = DI.pullback(f, backend, primal_x, (y.dx,)) + @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) + return NoRData(), rdata(only(tx)) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert only(tx) isa rdata_type(typeof(primal_x)) - return NoRData(), only(tx) + @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) + return NoRData(), rdata(only(tx)) end # output is a Tuple, NTuple function pullback_tuple!!(dy::Tuple) tx = DI.pullback(f, backend, 
primal_x, (dy,)) - @assert only(tx) isa rdata_type(typeof(primal_x)) - return NoRData(), only(tx) + @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) + return NoRData(), rdata(only(tx)) end - pullback = if typeof(primal(y)) <: Number + # inputs are non Differentiable + function pullback_nodiff!!(dy::NoRData) + @assert tangent_type(typeof(primal(x))) <: NoTangent + return NoRData(), dy + end + + pullback = if tangent_type(typeof(primal(x))) <: NoTangent + pullback_nodiff!! + elseif typeof(primal(y)) <: Number pullback_scalar!! elseif typeof(primal(y)) <: Array pullback_array!! - else + elseif typeof(primal(y)) <: Tuple pullback_tuple!! + else + error("$(typeof(primal(y))) primal type currently not supported.") end return y, pullback @@ -43,14 +54,14 @@ end function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) primal_func = primal(dw) primal_x = primal(x) - fdata_arg = fdata(x.dx) + fdata_arg = x.dx (; f, backend) = primal_func y = zero_fcodual(f(primal_x)) # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) - @assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) + tx = DI.pullback(f, backend, primal_x, (y.dx,)) + @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) fdata_arg .+= only(tx) return NoRData(), dy end @@ -58,7 +69,7 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) + @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) fdata_arg .+= only(tx) return NoRData(), NoRData() end @@ -66,17 +77,27 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra # output is a 
Tuple, NTuple function pullback_tuple!!(dy::Tuple) tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) + @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) fdata_arg .+= only(tx) return NoRData(), NoRData() end - pullback = if typeof(primal(y)) <: Number + # inputs are non Differentiable + function pullback_nodiff!!(dy::NoRData) + @assert tangent_type(typeof(primal(x))) <: Vector{NoTangent} + return NoRData(), dy + end + + pullback = if tangent_type(typeof(primal(x))) <: Vector{NoTangent} + pullback_nodiff!! + elseif typeof(primal(y)) <: Number pullback_scalar!! - elseif typeof(primal(y)) <: Array + elseif typeof(primal(y)) <: AbstractArray pullback_array!! - else + elseif typeof(primal(y)) <: Tuple pullback_tuple!! + else + error("$(typeof(primal(y))) primal type currently not supported.") end return y, pullback @@ -89,90 +110,93 @@ end function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) test_cases = reduce( vcat, - map([Float64, Float32]) do P - return Any[ - (false, :stability_and_allocs, nothing, cosh, P(0.3)), - (false, :stability_and_allocs, nothing, sinh, P(0.3)), - (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)), - (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)), - (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)), - (false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)), - ] - end, + map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F + map([Float64, Float32]) do P + return Any[ + (false, :stability, nothing, F(cosh), P(0.3)), + (false, :stability, nothing, F(sinh), P(0.3)), + (false, :stability, nothing, F(Base.FastMath.exp10_fast), P(0.5)), + (false, :stability, nothing, F(Base.FastMath.exp2_fast), P(0.5)), + (false, :stability, nothing, F(Base.FastMath.exp_fast), P(5.0)), + (false, :none, nothing, F(copy), rand(Int32, 5)), + ] + 
end + end..., ) - push!(test_cases, (false, :stability, nothing, copy, randn(5, 4))) - push!(test_cases, ( - # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent. - false, - :none, - nothing, - x -> +(x...), - randn(33), - )) - push!( - test_cases, - ( - false, - :none, - nothing, - ( - function (x) - rx = Ref(x) - return Base.pointerref( - Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1 - ) - end - ), - 5.0, - ), - ) - push!( - test_cases, - ( - false, - :none, - nothing, - x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), - rand(UInt8, 5), - ), - ) - push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, [1.0])) - push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[1.0])) - push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[[1.0]])) - push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctlz_int, 5)) - push!( - test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctpop_int, 5) - ) - push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.cttz_int, 5)) - push!( - test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0) - ) - push!( - test_cases, - (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0f0), - ) - push!(test_cases, (false, :stability, nothing, deepcopy, 5.0)) - push!(test_cases, (false, :stability, nothing, deepcopy, randn(5))) - push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1)) - push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.0f1)) - push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1)) - push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.0f1)) - push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1)) - push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.0f1)) - - # additional_test_set = Mooncake.tangent_test_cases() - # function is_valid(f) - # try - # isa(f([1.0, 2.0]), 
Union{<:Number,<:AbstractArray}) - # catch - # false - # end - # end - # for test in additional_test_set - # if is_valid(test[2]) - # push!(test_cases, test) - # end - # end + + map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F + map([Float64, Float32]) do P + push!( + test_cases, + Any[ + (false, :stability, nothing, F(Base.FastMath.sincos), P(3.0)), + (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[P(1.0)]), + ]..., + ) + end + end + + map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F + push!( + test_cases, + Any[ + (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5), + (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5), + (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5), + ]..., + ) + end + + map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F + push!( + test_cases, + Any[ + (false, :stability, nothing, copy, randn(5, 4)), + ( + # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent. + false, + :none, + nothing, + F(x -> +(x...)), + randn(33), + ), + ( + false, + :none, + nothing, + (F( + function (x) + rx = Ref(x) + return Base.pointerref( + Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1 + ) + end, + )), + 5.0, + ), + (false, :none, nothing, F(Mooncake.__vec_to_tuple), [1.0]), + # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it! 
+ (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5), + (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5), + (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5), + ( + false, + :stability, + nothing, + F(Mooncake.IntrinsicsWrappers.abs_float), + 5.0f0, + ), + (false, :stability, nothing, F(deepcopy), 5.0), + (false, :stability, nothing, F(deepcopy), randn(5)), + (false, :stability_and_allocs, nothing, F(sin), 1.1), + (false, :stability_and_allocs, nothing, F(sin), 1.0f1), + (false, :stability_and_allocs, nothing, F(cos), 1.1), + (false, :stability_and_allocs, nothing, F(cos), 1.0f1), + (false, :stability_and_allocs, nothing, F(exp), 1.1), + (false, :stability_and_allocs, nothing, F(exp), 1.0f1), + ]..., + ) + end memory = Any[] return test_cases, memory diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index e111d5279..b9836f6e7 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -31,6 +31,6 @@ test_differentiation( logging=LOGGING, ) -@testset "new" begin +@testset "Mooncake tests" begin Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith)) end From 0f0b9fcb8caccd504b2fbd32aa7124f372298279 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sat, 7 Jun 2025 05:21:04 +0530 Subject: [PATCH 19/25] changes from reviews --- .../differentiate_with.jl | 147 +++++++++++++----- 1 file changed, 106 insertions(+), 41 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 46d59d64c..464767feb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,6 +1,8 @@ @is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} -# nested vectors, similar are not supported +# nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet ! +# This is because basis construction (DI.basis) does not have overloads for these types. +# For details, refer commented out test cases to see where the pullback creation fails. function Mooncake.rrule!!( dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}} ) @@ -45,7 +47,9 @@ function Mooncake.rrule!!( elseif typeof(primal(y)) <: Tuple pullback_tuple!! else - error("$(typeof(primal(y))) primal type currently not supported.") + error( + "For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.", + ) end return y, pullback @@ -97,7 +101,9 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra elseif typeof(primal(y)) <: Tuple pullback_tuple!! 
else - error("$(typeof(primal(y))) primal type currently not supported.") + error( + "For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.", + ) end return y, pullback @@ -113,40 +119,37 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F map([Float64, Float32]) do P return Any[ - (false, :stability, nothing, F(cosh), P(0.3)), - (false, :stability, nothing, F(sinh), P(0.3)), - (false, :stability, nothing, F(Base.FastMath.exp10_fast), P(0.5)), - (false, :stability, nothing, F(Base.FastMath.exp2_fast), P(0.5)), - (false, :stability, nothing, F(Base.FastMath.exp_fast), P(5.0)), - (false, :none, nothing, F(copy), rand(Int32, 5)), + # (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!) + # (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!) 
+ (false, :stability_and_allocs, nothing, F(cosh), P(0.3)), + (false, :stability_and_allocs, nothing, F(sinh), P(0.3)), + ( + false, + :stability_and_allocs, + nothing, + F(Base.FastMath.exp10_fast), + P(0.5), + ), + ( + false, + :stability_and_allocs, + nothing, + F(Base.FastMath.exp2_fast), + P(0.5), + ), + ( + false, + :stability_and_allocs, + nothing, + F(Base.FastMath.exp_fast), + P(5.0), + ), + (false, :stability, nothing, F(copy), rand(Int32, 5)), ] end end..., ) - map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F - map([Float64, Float32]) do P - push!( - test_cases, - Any[ - (false, :stability, nothing, F(Base.FastMath.sincos), P(3.0)), - (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[P(1.0)]), - ]..., - ) - end - end - - map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F - push!( - test_cases, - Any[ - (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5), - (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5), - (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5), - ]..., - ) - end - map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F push!( test_cases, @@ -155,14 +158,14 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff ( # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent. false, - :none, + :stability, nothing, F(x -> +(x...)), randn(33), ), ( false, - :none, + :stability, nothing, (F( function (x) @@ -174,19 +177,36 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff )), 5.0, ), - (false, :none, nothing, F(Mooncake.__vec_to_tuple), [1.0]), - # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it! 
- (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5), - (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5), - (false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5), + # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!) ( false, - :stability, + :stability_and_allocs, + nothing, + F(Mooncake.IntrinsicsWrappers.ctlz_int), + 5, + ), + ( + false, + :stability_and_allocs, + nothing, + F(Mooncake.IntrinsicsWrappers.ctpop_int), + 5, + ), + ( + false, + :stability_and_allocs, + nothing, + F(Mooncake.IntrinsicsWrappers.cttz_int), + 5, + ), + ( + false, + :stability_and_allocs, nothing, F(Mooncake.IntrinsicsWrappers.abs_float), 5.0f0, ), - (false, :stability, nothing, F(deepcopy), 5.0), + (false, :stability_and_allocs, nothing, F(deepcopy), 5.0), (false, :stability, nothing, F(deepcopy), randn(5)), (false, :stability_and_allocs, nothing, F(sin), 1.1), (false, :stability_and_allocs, nothing, F(sin), 1.0f1), @@ -198,6 +218,51 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff ) end + map([(x) -> DI.DifferentiateWith(x, DI.AutoForwardDiff())]) do F + map([Float64, Float32]) do P + push!( + test_cases, + Any[ + ( + false, + :stability_and_allocs, + nothing, + F(Base.FastMath.sincos), + P(3.0), + ), + (false, :none, nothing, F(Mooncake.__vec_to_tuple), [P(1.0)]), + ]..., + ) + end + + push!( + test_cases, + Any[ + ( + false, + :stability_and_allocs, + nothing, + F(Mooncake.IntrinsicsWrappers.ctlz_int), + 5, + ), + ( + false, + :stability_and_allocs, + nothing, + F(Mooncake.IntrinsicsWrappers.ctpop_int), + 5, + ), + ( + false, + :stability_and_allocs, + nothing, + F(Mooncake.IntrinsicsWrappers.cttz_int), + 5, + ), + ]..., + ) + end + memory = Any[] return test_cases, memory end From d94f146503b4368cbd75a4abd8b642453703a832 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 13 Jun 
2025 20:28:39 +0200 Subject: [PATCH 20/25] Apply suggestions from code review --- .../docs/src/faq/differentiability.md | 2 +- .../differentiate_with.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index af16c6a49..5845e80f5 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -112,4 +112,4 @@ There are, however, translation utilities: Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file +Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 464767feb..556b63537 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -40,11 +40,11 @@ function Mooncake.rrule!!( pullback = if tangent_type(typeof(primal(x))) <: NoTangent pullback_nodiff!! 
- elseif typeof(primal(y)) <: Number + elseif primal(y) isa Number pullback_scalar!! - elseif typeof(primal(y)) <: Array + elseif primal(y) <: AbstractArray pullback_array!! - elseif typeof(primal(y)) <: Tuple + elseif primal(y) <: Tuple pullback_tuple!! else error( @@ -94,11 +94,11 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra pullback = if tangent_type(typeof(primal(x))) <: Vector{NoTangent} pullback_nodiff!! - elseif typeof(primal(y)) <: Number + elseif primal(y) isa Number pullback_scalar!! - elseif typeof(primal(y)) <: AbstractArray + elseif primal(y) <: AbstractArray pullback_array!! - elseif typeof(primal(y)) <: Tuple + elseif primal(y) <: Tuple pullback_tuple!! else error( From c982f46917c4b57ae1590d5b586c7444f05f1074 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 13 Jun 2025 21:08:43 +0200 Subject: [PATCH 21/25] Simplify Mooncake rule tests, add ChainRules rule tests --- .../differentiate_with.jl | 2 +- .../differentiate_with.jl | 233 ++---------------- .../test/Back/DifferentiateWith/test.jl | 90 ++++++- 3 files changed, 106 insertions(+), 219 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 23f9c9b0c..292372b81 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,)) + prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff 
--git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 556b63537..52255c06e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,11 +1,23 @@ -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any} + +struct MooncakeDifferentiateWithError <: Exception + F::Type + X::Type + Y::Type + function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y} + return new(F, X, Y) + end +end + +function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) + return print( + io, + "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.", + ) +end -# nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet ! -# This is because basis construction (DI.basis) does not have overloads for these types. # For details, refer commented out test cases to see where the pullback creation fails. 
-function Mooncake.rrule!!( - dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}} -) +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) primal_func = primal(dw) primal_x = primal(x) (; f, backend) = primal_func @@ -25,31 +37,12 @@ function Mooncake.rrule!!( return NoRData(), rdata(only(tx)) end - # output is a Tuple, NTuple - function pullback_tuple!!(dy::Tuple) - tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) - return NoRData(), rdata(only(tx)) - end - - # inputs are non Differentiable - function pullback_nodiff!!(dy::NoRData) - @assert tangent_type(typeof(primal(x))) <: NoTangent - return NoRData(), dy - end - - pullback = if tangent_type(typeof(primal(x))) <: NoTangent - pullback_nodiff!! - elseif primal(y) isa Number + pullback = if primal(y) isa Number pullback_scalar!! - elseif primal(y) <: AbstractArray + elseif primal(y) isa AbstractArray pullback_array!! - elseif primal(y) <: Tuple - pullback_tuple!! else - error( - "For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.", - ) + throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) end return y, pullback @@ -78,191 +71,13 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra return NoRData(), NoRData() end - # output is a Tuple, NTuple - function pullback_tuple!!(dy::Tuple) - tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) - fdata_arg .+= only(tx) - return NoRData(), NoRData() - end - - # inputs are non Differentiable - function pullback_nodiff!!(dy::NoRData) - @assert tangent_type(typeof(primal(x))) <: Vector{NoTangent} - return NoRData(), dy - end - - pullback = if tangent_type(typeof(primal(x))) <: Vector{NoTangent} - pullback_nodiff!! 
- elseif primal(y) isa Number + pullback = if primal(y) isa Number pullback_scalar!! - elseif primal(y) <: AbstractArray + elseif primal(y) isa AbstractArray pullback_array!! - elseif primal(y) <: Tuple - pullback_tuple!! else - error( - "For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.", - ) + throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) end return y, pullback end - -function Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) - return Any[], Any[] -end - -function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith}) - test_cases = reduce( - vcat, - map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F - map([Float64, Float32]) do P - return Any[ - # (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!) - # (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!) - (false, :stability_and_allocs, nothing, F(cosh), P(0.3)), - (false, :stability_and_allocs, nothing, F(sinh), P(0.3)), - ( - false, - :stability_and_allocs, - nothing, - F(Base.FastMath.exp10_fast), - P(0.5), - ), - ( - false, - :stability_and_allocs, - nothing, - F(Base.FastMath.exp2_fast), - P(0.5), - ), - ( - false, - :stability_and_allocs, - nothing, - F(Base.FastMath.exp_fast), - P(5.0), - ), - (false, :stability, nothing, F(copy), rand(Int32, 5)), - ] - end - end..., - ) - - map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F - push!( - test_cases, - Any[ - (false, :stability, nothing, copy, randn(5, 4)), - ( - # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent. 
- false, - :stability, - nothing, - F(x -> +(x...)), - randn(33), - ), - ( - false, - :stability, - nothing, - (F( - function (x) - rx = Ref(x) - return Base.pointerref( - Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1 - ) - end, - )), - 5.0, - ), - # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!) - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.ctlz_int), - 5, - ), - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.ctpop_int), - 5, - ), - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.cttz_int), - 5, - ), - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.abs_float), - 5.0f0, - ), - (false, :stability_and_allocs, nothing, F(deepcopy), 5.0), - (false, :stability, nothing, F(deepcopy), randn(5)), - (false, :stability_and_allocs, nothing, F(sin), 1.1), - (false, :stability_and_allocs, nothing, F(sin), 1.0f1), - (false, :stability_and_allocs, nothing, F(cos), 1.1), - (false, :stability_and_allocs, nothing, F(cos), 1.0f1), - (false, :stability_and_allocs, nothing, F(exp), 1.1), - (false, :stability_and_allocs, nothing, F(exp), 1.0f1), - ]..., - ) - end - - map([(x) -> DI.DifferentiateWith(x, DI.AutoForwardDiff())]) do F - map([Float64, Float32]) do P - push!( - test_cases, - Any[ - ( - false, - :stability_and_allocs, - nothing, - F(Base.FastMath.sincos), - P(3.0), - ), - (false, :none, nothing, F(Mooncake.__vec_to_tuple), [P(1.0)]), - ]..., - ) - end - - push!( - test_cases, - Any[ - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.ctlz_int), - 5, - ), - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.ctpop_int), - 5, - ), - ( - false, - :stability_and_allocs, - nothing, - F(Mooncake.IntrinsicsWrappers.cttz_int), - 5, - ), - ]..., - ) - end - - memory = Any[] - return test_cases, memory -end diff --git 
a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index b9836f6e7..7922310e4 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -1,23 +1,41 @@ using Pkg -Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) +Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) +using ChainRulesTestUtils: ChainRulesTestUtils using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Zygote: Zygote using Mooncake: Mooncake -using StableRNGs, Test +using StableRNGs +using Test LOGGING = get(ENV, "CI", "false") == "false" +struct ADBreaker{F} + f::F +end + +function (adb::ADBreaker)(x::Number) + copyto!(Float64[0], x) # break ForwardDiff and Zygote + return adb.f(x) +end + +function (adb::ADBreaker)(x::AbstractArray) + copyto!(similar(x, Float64), x) # break ForwardDiff and Zygote + return adb.f(x) +end + function differentiatewith_scenarios() - bad_scens = # these closurified scenarios have mutation and type constraints - filter( - DIT.default_scenarios(; include_normal=false, include_closurified=true) - ) do scen - DIT.function_place(scen) == :out - end + outofplace_scens = filter(DIT.default_scenarios()) do scen + DIT.function_place(scen) == :out + end + # with bad_scens, everything would break + bad_scens = map(outofplace_scens) do scen + DIT.change_function(scen, ADBreaker(scen.f)) + end + # with good_scens, everything is fixed good_scens = map(bad_scens) do scen DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff())) end @@ -29,8 +47,62 @@ test_differentiation( differentiatewith_scenarios(); excluded=SECOND_ORDER, logging=LOGGING, + testset_name="DI tests", ) +@testset "ChainRules tests" begin + @testset for scen in 
filter(differentiatewith_scenarios()) do scen + DIT.operator(scen) == :pullback + end + ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4) + end +end; + @testset "Mooncake tests" begin - Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith)) + @testset for scen in filter(differentiatewith_scenarios()) do scen + DIT.operator(scen) == :pullback + end + Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true) + end +end; + +@testset "Mooncake errors" begin + MooncakeDifferentiateWithError = + Base.get_extension( + DifferentiationInterface, :DifferentiationInterfaceMooncakeExt + ).MooncakeDifferentiateWithError + + e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) + @test sprint(showerror, e) == + "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported." + + f_num2tup(x::Number) = (x,) + f_vec2tup(x::Vector) = (first(x),) + f_tup2num(x::Tuple{<:Number}) = only(x) + f_tup2vec(x::Tuple{<:Number}) = [only(x)] + + @test_throws MooncakeDifferentiateWithError pullback( + DifferentiateWith(f_num2tup, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + 1.0, + ((2.0,),), + ) + @test_throws MooncakeDifferentiateWithError pullback( + DifferentiateWith(f_vec2tup, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + [1.0], + ((2.0,),), + ) + @test_throws MethodError pullback( + DifferentiateWith(f_tup2num, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + (1.0,), + (2.0,), + ) + @test_throws MethodError pullback( + DifferentiateWith(f_tup2vec, AutoFiniteDiff()), + AutoMooncake(; config=nothing), + (1.0,), + ([2.0],), + ) end From 749fea5bf41aaea9e90c9e7dda19518effbbe3bd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 13 Jun 2025 21:47:28 +0200 Subject: [PATCH 22/25] Format --- DifferentiationInterface/test/Back/DifferentiateWith/test.jl | 4 +--- 1 file changed, 1 insertion(+), 3 
deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 7922310e4..d2bf57f88 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -68,9 +68,7 @@ end; @testset "Mooncake errors" begin MooncakeDifferentiateWithError = - Base.get_extension( - DifferentiationInterface, :DifferentiationInterfaceMooncakeExt - ).MooncakeDifferentiateWithError + Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) @test sprint(showerror, e) == From 9e5ecfd3ba6d152c08004061430494bd0ba06234 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 14 Jun 2025 07:25:47 +0200 Subject: [PATCH 23/25] Update differentiate_with.jl --- DifferentiationInterface/src/misc/differentiate_with.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index 08953d9d7..256d46f75 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -17,7 +17,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). !!! warning - When using Mooncake as a substitute backend via `DifferentiateWith(f, AutoMooncake())`. The function `f` must not close over any active data. + When using `DifferentiateWith(f, AutoSomething())`, the function `f` must not close over any active data. As of now, we cannot differentiate with respect to parameters stored inside `f`. 
# Fields From 1e85f17e35a35f2cdd02f768fd0f750a0556455d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 14 Jun 2025 14:42:55 +0200 Subject: [PATCH 24/25] Restrict to array of numbers --- .../DifferentiationInterfaceMooncakeExt/differentiate_with.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 52255c06e..bab1f2442 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -48,7 +48,9 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number return y, pullback end -function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) +function Mooncake.rrule!!( + dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} +) primal_func = primal(dw) primal_x = primal(x) fdata_arg = x.dx From ff5c4e26eead350812283beb9798c605731aeba5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 18 Jun 2025 17:35:53 +0200 Subject: [PATCH 25/25] Update DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl --- .../DifferentiationInterfaceMooncakeExt/differentiate_with.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index bab1f2442..3b4fb91c3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -16,7 +16,6 @@ function Base.showerror(io::IO, 
e::MooncakeDifferentiateWithError) ) end -# For details, refer commented out test cases to see where the pullback creation fails. function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) primal_func = primal(dw) primal_x = primal(x)