Skip to content

Commit 197d2e6

Browse files
committed
Add TaylorDiff
1 parent d468741 commit 197d2e6

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export AutoChainRules,
4747
AutoReverseDiff,
4848
AutoSymbolics,
4949
AutoTapir,
50+
AutoTaylorDiff,
5051
AutoTracker,
5152
AutoZygote
5253
@public AbstractMode

src/dense.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,36 @@ function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize
197197
print(io, ")")
198198
end
199199

200+
"""
201+
AutoTaylorDiff{order}
202+
203+
Struct used to select the [TaylorDiff.jl](https://github.com/JuliaDiff/TaylorDiff.jl) backend for automatic differentiation.
204+
205+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
206+
207+
# Constructors
208+
209+
AutoTaylorDiff(; order = 1)
210+
AutoTaylorDiff{order}()
211+
212+
# Type parameters
213+
214+
- `order`: the order of the Taylor-mode automatic differentiation
215+
"""
216+
struct AutoTaylorDiff{order} <: AbstractADType end
217+
218+
function AutoTaylorDiff(; order = 1)
219+
return AutoTaylorDiff{order}()
220+
end
221+
222+
mode(::AutoTaylorDiff) = ForwardMode()
223+
224+
function Base.show(io::IO, ::AutoTaylorDiff{order}) where {order}
225+
print(io, AutoTaylorDiff, "(")
226+
print(io, "tag=", repr(order; context = io))
227+
print(io, ")")
228+
end
229+
200230
"""
201231
AutoGTPSA{D}
202232

test/dense.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,18 @@ end
182182
@test !ad.safe_mode
183183
end
184184

185+
@testset "AutoTaylorDiff" begin
186+
ad = AutoTaylorDiff{2}()
187+
@test ad isa AbstractADType
188+
@test ad isa AutoTaylorDiff{2}
189+
@test mode(ad) isa ForwardMode
190+
191+
ad = AutoTaylorDiff()
192+
@test ad isa AbstractADType
193+
@test ad isa AutoTaylorDiff{1}
194+
@test mode(ad) isa ForwardMode
195+
end
196+
185197
@testset "AutoTracker" begin
186198
ad = AutoTracker()
187199
@test ad isa AbstractADType

0 commit comments

Comments
 (0)