From 1f735b31da44cfdd6b724fa98106019ddd7b8df7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 5 Jun 2025 19:59:54 +0200 Subject: [PATCH] feat: Add forward mode Mooncake --- Project.toml | 2 +- docs/src/index.md | 1 + src/ADTypes.jl | 1 + src/dense.jl | 25 +++++++++++++++++++++++++ test/dense.jl | 13 +++++++++++-- test/runtests.jl | 2 ++ 6 files changed, 41 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index ac4bea6..3d6d21e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit , Guillaume Dalle and contributors"] -version = "1.14.1" +version = "1.15.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 254eb68..eda25fa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -21,6 +21,7 @@ Algorithmic differentiation: ```@docs AutoForwardDiff AutoPolyesterForwardDiff +AutoMooncakeForward ``` Finite differences: diff --git a/src/ADTypes.jl b/src/ADTypes.jl index efc2b1a..2059f40 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -43,6 +43,7 @@ export AutoChainRules, AutoGTPSA, AutoModelingToolkit, AutoMooncake, + AutoMooncakeForward, AutoPolyesterForwardDiff, AutoReverseDiff, AutoSymbolics, diff --git a/src/dense.jl b/src/dense.jl index ff7b732..32786d5 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -294,6 +294,31 @@ end mode(::AutoMooncake) = ReverseMode() +""" + AutoMooncakeForward + +Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation in forward mode. + +Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + +!!! info + + This type was introduced when forward mode became available in Mooncake.jl. It was kept separate from [`AutoMooncake`](@ref) in order to avoid requiring a breaking release of ADTypes.jl. + +# Constructors + + AutoMooncakeForward(; config) + +# Fields + + - `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring of `Mooncake.Config` for more information. `AutoForwardMooncake(; config=nothing)` is equivalent to `AutoForwardMooncake(; config=Mooncake.Config())`, i.e. the default configuration. +""" +Base.@kwdef struct AutoMooncakeForward{Tconfig} <: AbstractADType + config::Tconfig +end + +mode(::AutoMooncakeForward) = ForwardMode() + """ AutoPolyesterForwardDiff{chunksize,T} diff --git a/test/dense.jl b/test/dense.jl index a3fa24c..a5d6612 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -71,7 +71,8 @@ end @test ad.absstep === nothing @test ad.dir - ad = AutoFiniteDiff(; fdtype = Val(:central), fdjtype = Val(:forward), relstep = 1e-3, absstep = 1e-4, dir = false) + ad = AutoFiniteDiff(; fdtype = Val(:central), fdjtype = Val(:forward), + relstep = 1e-3, absstep = 1e-4, dir = false) @test ad isa AbstractADType @test ad isa AutoFiniteDiff @test mode(ad) isa ForwardMode @@ -126,13 +127,21 @@ end end @testset "AutoMooncake" begin - ad = AutoMooncake(; config=nothing) + ad = AutoMooncake(; config = nothing) @test ad isa AbstractADType @test ad isa AutoMooncake @test mode(ad) isa ReverseMode @test ad.config === nothing end +@testset "AutoMooncakeForward" begin + ad = AutoMooncakeForward(; config = nothing) + @test ad isa AbstractADType + @test ad isa AutoMooncakeForward + @test mode(ad) isa ForwardMode + @test ad.config === nothing +end + @testset "AutoPolyesterForwardDiff" begin ad = AutoPolyesterForwardDiff() @test ad isa AbstractADType diff --git a/test/runtests.jl b/test/runtests.jl index 7bcd078..b4a5f5b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,6 +70,8 @@ function every_ad_with_options() AutoForwardDiff(chunksize = 3, tag = :tag), AutoGTPSA(), AutoGTPSA(descriptor = Val(:descriptor)), + AutoMooncake(; config = :config), + AutoMooncakeForward(; config = :config), AutoPolyesterForwardDiff(), AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), AutoReverseDiff(),