diff --git a/docs/src/index.md b/docs/src/index.md index 254eb68..eda25fa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -21,6 +21,7 @@ Algorithmic differentiation: ```@docs AutoForwardDiff AutoPolyesterForwardDiff +AutoMooncakeForward ``` Finite differences: diff --git a/src/ADTypes.jl b/src/ADTypes.jl index efc2b1a..2059f40 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -43,6 +43,7 @@ export AutoChainRules, AutoGTPSA, AutoModelingToolkit, AutoMooncake, + AutoMooncakeForward, AutoPolyesterForwardDiff, AutoReverseDiff, AutoSymbolics, diff --git a/src/dense.jl b/src/dense.jl index b7e152a..9ad94dd 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -294,6 +294,31 @@ end mode(::AutoMooncake) = ReverseMode() +""" + AutoMooncakeForward + +Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation in forward mode. + +Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + +!!! info + + This type was introduced when forward mode became available in Mooncake.jl. It was kept separate from [`AutoMooncake`](@ref) in order to avoid requiring a breaking release of ADTypes.jl. + +# Constructors + + AutoMooncakeForward(; config=nothing) + +# Fields + + - `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring of `Mooncake.Config` for more information. `AutoForwardMooncake(; config=nothing)` is equivalent to `AutoForwardMooncake(; config=Mooncake.Config())`, i.e. the default configuration. +""" +Base.@kwdef struct AutoMooncakeForward{Tconfig} <: AbstractADType + config::Tconfig = nothing +end + +mode(::AutoMooncakeForward) = ForwardMode() + """ AutoPolyesterForwardDiff{chunksize,T} diff --git a/test/dense.jl b/test/dense.jl index fa43edf..307403d 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -136,6 +136,16 @@ end @test ad.config === nothing end +@testset "AutoMooncakeForward" begin + ad = AutoMooncakeForward(; config = :config) + @test ad isa AbstractADType + @test ad isa AutoMooncakeForward + @test mode(ad) isa ForwardMode + @test ad.config === :config + ad = AutoMooncakeForward() + @test ad.config === nothing +end + @testset "AutoPolyesterForwardDiff" begin ad = AutoPolyesterForwardDiff() @test ad isa AbstractADType diff --git a/test/runtests.jl b/test/runtests.jl index 7bcd078..b4a5f5b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,6 +70,8 @@ function every_ad_with_options() AutoForwardDiff(chunksize = 3, tag = :tag), AutoGTPSA(), AutoGTPSA(descriptor = Val(:descriptor)), + AutoMooncake(; config = :config), + AutoMooncakeForward(; config = :config), AutoPolyesterForwardDiff(), AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), AutoReverseDiff(),