Skip to content

Explicit Taylor solvers #2620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions lib/OrdinaryDiffEqTaylorSeries/LICENSE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
The OrdinaryDiffEq.jl package is licensed under the MIT "Expat" License:

> Copyright (c) 2016-2020: ChrisRackauckas, Yingbo Ma, Julia Computing Inc, and
> other contributors:
>
> https://github.com/SciML/OrdinaryDiffEq.jl/graphs/contributors
>
> Permission is hereby granted, free of charge, to any person obtaining a copy
> of this software and associated documentation files (the "Software"), to deal
> in the Software without restriction, including without limitation the rights
> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
> copies of the Software, and to permit persons to whom the Software is
> furnished to do so, subject to the following conditions:
>
> The above copyright notice and this permission notice shall be included in all
> copies or substantial portions of the Software.
>
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
> SOFTWARE.
51 changes: 51 additions & 0 deletions lib/OrdinaryDiffEqTaylorSeries/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name = "OrdinaryDiffEqTaylorSeries"
uuid = "9c7f1690-dd92-42a3-8318-297ee24d8d39"
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
version = "1.1.0"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[compat]
DiffEqBase = "6.152.2"
DiffEqDevTools = "2.44.4"
FastBroadcast = "0.3.5"
LinearAlgebra = "<0.0.1, 1"
MuladdMacro = "0.2.4"
OrdinaryDiffEqCore = "1.1"
PrecompileTools = "1.2.1"
Preferences = "1.4.3"
Random = "<0.0.1, 1"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2.2"
SafeTestsets = "0.1.0"
SciMLBase = "2.72.2"
Static = "1.1.1"
Symbolics = "6.28.0"
TaylorDiff = "0.3.1"
Test = "<0.0.1, 1"
TruncatedStacktraces = "1.4.0"
julia = "1.10"

[extras]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "ODEProblemLibrary"]
60 changes: 60 additions & 0 deletions lib/OrdinaryDiffEqTaylorSeries/src/OrdinaryDiffEqTaylorSeries.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Sub-package module for the explicit Taylor-series solvers. It extends the
# OrdinaryDiffEqCore solver interface (`alg_cache`, `initialize!`, `perform_step!`, ...)
# and exports `ExplicitTaylor2`, `ExplicitTaylor` and `DAETS`.
module OrdinaryDiffEqTaylorSeries

# Names imported (not `using`-ed) so that the included files can add methods to them.
# NOTE(review): `perform_step!` appears twice in this list; the duplicate is harmless.
import OrdinaryDiffEqCore: alg_order, alg_stability_size, explicit_rk_docstring,
OrdinaryDiffEqAdaptiveAlgorithm, OrdinaryDiffEqMutableCache,
alg_cache,
OrdinaryDiffEqConstantCache, @fold, trivial_limiter!,
constvalue, @unpack, perform_step!, calculate_residuals, @cache,
calculate_residuals!, _ode_interpolant, _ode_interpolant!,
CompiledFloats, @OnDemandTableauExtract, initialize!,
perform_step!, OrdinaryDiffEqAlgorithm,
CompositeAlgorithm, _ode_addsteps!, copyat_or_push!,
AutoAlgSwitch, get_fsalfirstlast,
full_cache, DerivativeOrderNotPossibleError
import Static: False
import MuladdMacro: @muladd
import FastBroadcast: @..
import RecursiveArrayTools: recursivefill!, recursive_unitless_bottom_eltype
import LinearAlgebra: norm
using TruncatedStacktraces
# TaylorDiff supplies the Taylor number types; Symbolics is used by `build_jet`
# (alg_utils.jl) to trace the ODE right-hand side symbolically.
using TaylorDiff, Symbolics
using TaylorDiff: make_seed, get_coefficient, append_coefficient, flatten
import DiffEqBase: @def
import OrdinaryDiffEqCore

using Reexport
@reexport using DiffEqBase

# Include order matters: alg_utils.jl and the cache/step files dispatch on the algorithm
# types, which are presumably defined in algorithms.jl (not visible here — confirm).
include("algorithms.jl")
include("alg_utils.jl")
include("TaylorSeries_caches.jl")
include("TaylorSeries_perform_step.jl")

import PrecompileTools
import Preferences

# Precompilation workload: solve the Lorenz test problem with `ExplicitTaylor2`.
PrecompileTools.@compile_workload begin
lorenz = OrdinaryDiffEqCore.lorenz
lorenz_oop = OrdinaryDiffEqCore.lorenz_oop
solver_list = [ExplicitTaylor2()]
prob_list = []

# Problems are only added when the "PrecompileNoSpecialize" preference is set
# (default `false`); otherwise `prob_list` stays empty and the loop below is a no-op.
# NOTE(review): `lorenz_oop` is bound above but never used in this workload.
if Preferences.@load_preference("PrecompileNoSpecialize", false)
push!(prob_list,
ODEProblem{true, SciMLBase.NoSpecialize}(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0)))
push!(prob_list,
ODEProblem{true, SciMLBase.NoSpecialize}(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0),
Float64[]))
end

# Solve each problem and evaluate the returned solution object at t = 5.0.
for prob in prob_list, solver in solver_list
solve(prob, solver)(5.0)
end

# Drop the references once the workload has run.
prob_list = nothing
solver_list = nothing
end

# NOTE(review): `DAETS` is exported but none of the code visible in this PR defines it —
# confirm that algorithms.jl provides it.
export ExplicitTaylor2, ExplicitTaylor, DAETS

end
88 changes: 88 additions & 0 deletions lib/OrdinaryDiffEqTaylorSeries/src/TaylorSeries_caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Mutable (in-place) cache for `ExplicitTaylor2`; built by the `Val{true}` method of
# `alg_cache` for this algorithm.
@cache struct ExplicitTaylor2Cache{
uType, rateType, uNoUnitsType, StageLimiter, StepLimiter,
Thread} <: OrdinaryDiffEqMutableCache
# Current and previous state, stored as passed to `alg_cache`.
u::uType
uprev::uType
# Rate buffers (`zero(rate_prototype)`); `initialize!` exposes them as
# `integrator.k[1:3]`. `perform_step!` writes `k1` and `k2` only.
k1::rateType
k2::rateType
k3::rateType
# Scratch arrays allocated as `zero(u)`; the in-place `perform_step!` for this cache
# unpacks them but does not use them.
utilde::uType
tmp::uType
# Unitless buffer (`similar(u, uEltypeNoUnits)`) filled with `false` at construction.
atmp::uNoUnitsType
# Copied from the algorithm object (`alg.stage_limiter!`, `alg.step_limiter!`,
# `alg.thread`).
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
end

# Build the in-place (`Val{true}`) cache for `ExplicitTaylor2`.
#
# Allocates three zeroed rate buffers shaped like `rate_prototype`, two zeroed scratch
# arrays shaped like `u`, and a unitless residual buffer filled with `false`. The
# limiters and threading option are copied from `alg`.
function alg_cache(alg::ExplicitTaylor2, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    # Unitless residual storage, zero-initialised through `false`.
    residuals = similar(u, uEltypeNoUnits)
    recursivefill!(residuals, false)

    rate1 = zero(rate_prototype)
    rate2 = zero(rate_prototype)
    rate3 = zero(rate_prototype)
    err_scratch = zero(u)
    work_scratch = zero(u)

    return ExplicitTaylor2Cache(
        u, uprev, rate1, rate2, rate3, err_scratch, work_scratch, residuals,
        alg.stage_limiter!, alg.step_limiter!, alg.thread)
end
# Stateless cache used by `ExplicitTaylor2` for out-of-place problems; it has no fields
# and exists only as a dispatch tag for `initialize!` / `perform_step!`.
struct ExplicitTaylor2ConstantCache <: OrdinaryDiffEqConstantCache end
# Build the out-of-place (`Val{false}`) cache for `ExplicitTaylor2`. Nothing needs to
# be allocated, so every argument other than the dispatch tags is ignored.
function alg_cache(alg::ExplicitTaylor2, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    return ExplicitTaylor2ConstantCache()
end
# The FSAL mechanism is not used by this solver. The interface still requires a
# (first, last) pair, so the same `k1` buffer is returned for both slots.
function get_fsalfirstlast(cache::ExplicitTaylor2Cache, u)
    placeholder = cache.k1
    return (placeholder, placeholder)
end

# Mutable (in-place) cache for the order-`P` solver `ExplicitTaylor{P}`; built by the
# `Val{true}` method of `alg_cache` for that algorithm.
@cache struct ExplicitTaylorCache{
P, jetType, uType, taylorType, uNoUnitsType, StageLimiter, StepLimiter,
Thread} <: OrdinaryDiffEqMutableCache
# Taylor order carried in the type domain.
order::Val{P}
# In-place jet function (second value returned by `build_jet`); `perform_step!`
# calls it as `jet(utaylor, uprev, t)`.
jet::jetType
# Current and previous state, stored as passed to `alg_cache`.
u::uType
uprev::uType
# Taylor container seeded with `make_seed(u, zero(u), Val(P))`; overwritten by `jet`
# every step, and its coefficients 1..P are exposed as `integrator.k[1:P]`.
utaylor::taylorType
# Receives `get_coefficient(utaylor, P) * dt^(P + 1)` when the step is adaptive.
utilde::uType
# Scratch array (`zero(u)`); unpacked but not used by `perform_step!`.
tmp::uType
# Unitless residual buffer written by `calculate_residuals!`.
atmp::uNoUnitsType
# Copied from the algorithm object.
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
end

# Build the in-place (`Val{true}`) cache for `ExplicitTaylor{P}`.
#
# `build_jet` is asked for a jet of order `P` over `length(u)` state variables; only its
# second return value (the in-place variant) is stored. The Taylor container is seeded
# from `u` with zero higher coefficients, and the remaining buffers are zeroed.
function alg_cache(alg::ExplicitTaylor{P}, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{true}) where {P, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    order = Val(P)
    _, inplace_jet = build_jet(f, p, order, length(u))
    series = TaylorDiff.make_seed(u, zero(u), order)

    err_scratch = zero(u)
    # Unitless residual storage, zero-initialised through `false`.
    residuals = similar(u, uEltypeNoUnits)
    recursivefill!(residuals, false)
    work_scratch = zero(u)

    return ExplicitTaylorCache(
        order, inplace_jet, u, uprev, series, err_scratch, work_scratch, residuals,
        alg.stage_limiter!, alg.step_limiter!, alg.thread)
end

# Cache used by `ExplicitTaylor{P}` for out-of-place problems.
struct ExplicitTaylorConstantCache{P, jetType} <: OrdinaryDiffEqConstantCache
# Taylor order carried in the type domain.
order::Val{P}
# Out-of-place jet function from `build_jet`; `perform_step!` calls it as
# `jet(uprev, t)`.
jet::jetType
end
# Build the out-of-place (`Val{false}`) cache for `ExplicitTaylor{P}`.
#
# For a non-array state `build_jet` is called without a length and its result is stored
# directly. For an array state the length is passed and the first of the two returned
# values (the out-of-place variant) is kept.
function alg_cache(::ExplicitTaylor{P}, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{false}) where {P, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    order = Val(P)
    if !(u isa AbstractArray)
        return ExplicitTaylorConstantCache(order, build_jet(f, p, order))
    end
    oop_jet, _ = build_jet(f, p, order, length(u))
    return ExplicitTaylorConstantCache(order, oop_jet)
end

# The FSAL mechanism is not used by this solver. The interface still requires a
# (first, last) pair, so the cache's `u` array is returned for both slots.
function get_fsalfirstlast(cache::ExplicitTaylorCache, u)
    placeholder = cache.u
    return (placeholder, placeholder)
end
100 changes: 100 additions & 0 deletions lib/OrdinaryDiffEqTaylorSeries/src/TaylorSeries_perform_step.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
using TaylorDiff: TaylorDiff, extract_derivative, extract_derivative!

# Pack a value followed by its higher-order coefficients into a TaylorDiff container.
# Array arguments produce a `TaylorArray` whose value is the first argument and whose
# partials are the remaining ones.
@inline function make_taylor(all::Vararg{X, P}) where {P, X <: AbstractArray}
    value, partials = Base.first(all), Base.tail(all)
    return TaylorArray(value, partials)
end

# Any other homogeneous argument list is wrapped whole into a `TaylorScalar`.
@inline function make_taylor(all::Vararg{X, P}) where {P, X}
    return TaylorScalar(all)
end

# Prepare the integrator's `k` storage for the out-of-place `ExplicitTaylor2` solver:
# three uninitialised slots of the same container type as the existing `integrator.k`.
# NOTE(review): the matching `perform_step!` only assigns slots 1 and 2 — confirm that
# slot 3 is never read.
function initialize!(integrator, cache::ExplicitTaylor2ConstantCache)
    integrator.kshortsize = 3
    KContainer = typeof(integrator.k)
    integrator.k = KContainer(undef, integrator.kshortsize)
end

# One out-of-place step of the second-order Taylor method:
#
#     u = uprev + dt * k1 + dt^2 / 2 * k2
#
# `k1` is `f` at the previous state. `k2` is the first partial of `f` evaluated on
# first-order Taylor inputs, seeded with (`uprev`, `k1`) for the state and (`t`, 1) for
# time. Both are stored in `integrator.k`, and the function counter is advanced by 3.
@muladd function perform_step!(
        integrator, cache::ExplicitTaylor2ConstantCache, repeat_step = false)
    @unpack t, dt, uprev, f, p = integrator

    k1 = f(uprev, p, t)

    # Propagate first-order Taylor numbers through `f` to obtain the next coefficient.
    state_jet = make_taylor(uprev, k1)
    time_jet = TaylorScalar{1}(t, one(t))
    k2 = f(state_jet, p, time_jet).partials[1]

    integrator.k[1] = k1
    integrator.k[2] = k2
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 3)

    unew = @.. uprev + dt * k1 + dt^2 / 2 * k2
    integrator.u = unew
end

# Prepare the integrator's `k` storage for the in-place `ExplicitTaylor2` solver: resize
# it to three entries and point them at the cache's rate buffers `k1`, `k2`, `k3`.
function initialize!(integrator, cache::ExplicitTaylor2Cache)
    integrator.kshortsize = 3
    resize!(integrator.k, integrator.kshortsize)
    for (slot, buffer) in enumerate((cache.k1, cache.k2, cache.k3))
        integrator.k[slot] = buffer
    end
    return nothing
end

# One in-place step of the second-order Taylor method:
#
#     u .= uprev + dt * k1 + dt^2 / 2 * k2
#
# `f` is called twice: once to fill `k1`, and once on Taylor-number views so that the
# next coefficient lands in `k2`. The function counter is advanced by 3.
@muladd function perform_step!(integrator, cache::ExplicitTaylor2Cache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
    @unpack k1, k2 = cache

    # Plain evaluation of the right-hand side into `k1`.
    f(k1, uprev, p, t)

    # Wrap the existing buffers as Taylor containers. The output container pairs
    # (`k1`, `k2`), so the second call to `f` writes its results into those buffers.
    state_jet = make_taylor(uprev, k1)
    time_jet = TaylorScalar{1}(t, one(t))
    rate_jet = make_taylor(k1, k2)
    f(rate_jet, state_jet, p, time_jet)

    @.. u = uprev + dt * k1 + dt^2 / 2 * k2
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 3)
    return nothing
end

# Prepare the integrator's `k` storage for the out-of-place order-`P` solver: `P`
# uninitialised slots of the same container type as the existing `integrator.k`.
function initialize!(integrator, cache::ExplicitTaylorConstantCache{P}) where {P}
    integrator.kshortsize = P
    KContainer = typeof(integrator.k)
    integrator.k = KContainer(undef, P)
end

# One out-of-place step of the order-`P` Taylor method.
#
# The cached jet is evaluated at (`uprev`, `t`), and each component's Taylor polynomial
# is then evaluated at `dt` to give the new state. When the step is adaptive, the `P`-th
# coefficient scaled by `dt^(P + 1)` is used as the error term that sets
# `integrator.EEst`. The function counter is advanced by `P + 1`.
@muladd function perform_step!(
        integrator, cache::ExplicitTaylorConstantCache{P}, repeat_step = false) where {P}
    @unpack t, dt, uprev = integrator

    series = cache.jet(uprev, t)
    unew = map(series) do component
        evaluate_polynomial(component, dt)
    end

    if integrator.opts.adaptive
        err_term = TaylorDiff.get_coefficient(series, P) * dt^(P + 1)
        residuals = calculate_residuals(err_term, uprev, unew, integrator.opts.abstol,
            integrator.opts.reltol, integrator.opts.internalnorm, t)
        integrator.EEst = integrator.opts.internalnorm(residuals, t)
    end

    OrdinaryDiffEqCore.increment_nf!(integrator.stats, P + 1)
    integrator.u = unew
end

# Prepare the integrator's `k` storage for the in-place order-`P` solver: resize it to
# `P` entries and fill entry `i` with coefficient `i` of the cached Taylor container.
function initialize!(integrator, cache::ExplicitTaylorCache{P}) where {P}
    integrator.kshortsize = P
    resize!(integrator.k, P)
    series = cache.utaylor
    for order in 1:P
        integrator.k[order] = get_coefficient(series, order)
    end
    return nothing
end

# One in-place step of the order-`P` Taylor method.
#
# The cached jet fills `utaylor` from (`uprev`, `t`); each entry's Taylor polynomial is
# then evaluated at `dt` and written into `u`. When the step is adaptive, the `P`-th
# coefficient scaled by `dt^(P + 1)` is written to `utilde` and turned into
# `integrator.EEst`. The function counter is advanced by `P + 1`. Returns `nothing`.
@muladd function perform_step!(
integrator, cache::ExplicitTaylorCache{P}, repeat_step = false) where {P}
# NOTE(review): `f`, `p` and `tmp` are unpacked but not used below.
@unpack t, dt, uprev, u, f, p = integrator
@unpack jet, utaylor, utilde, tmp, atmp, thread = cache

# In-place jet evaluation: overwrites `utaylor`.
jet(utaylor, uprev, t)
# `u` is indexed with the indices of `utaylor`.
for i in eachindex(utaylor)
u[i] = @inline evaluate_polynomial(utaylor[i], dt)
end
if integrator.opts.adaptive
# `broadcast=false thread=thread` are options of the `@..` macro; the statement being
# broadcast is the assignment to `utilde`.
@.. broadcast=false thread=thread utilde=TaylorDiff.get_coefficient(utaylor, P) *
dt^(P + 1)
calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, P + 1)
return nothing
end
55 changes: 55 additions & 0 deletions lib/OrdinaryDiffEqTaylorSeries/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Methods extending `OrdinaryDiffEqCore.alg_order` / `alg_stability_size`.
# `ExplicitTaylor2` has fixed order 2.
alg_order(::ExplicitTaylor2) = 2
alg_stability_size(alg::ExplicitTaylor2) = 1

# For the generic solver the order is the type parameter `P`.
alg_order(::ExplicitTaylor{P}) where {P} = P
alg_stability_size(alg::ExplicitTaylor) = 1

# Process-wide memo of jets built by `build_jet`, keyed by the ODE function's identity.
# Each value is a list of `(order, p, jet)` tuples. Declared `const` so that the binding
# cannot be rebound and accesses to it are not dynamically typed; the dictionary itself
# remains mutable.
const JET_CACHE = IdDict()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also thread safety?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can I make this thread-safe? Do I need to switch to a different data structure?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you're going to use a global cache, you need to synchronize mutation to it. Alternatively, is there a way to make the cache local?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When rewriting an entry you are guaranteed to store the same function, right? So I don't think it's unsafe. The worst that can happen is that you compile more times than you need to, but the result would still be correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. If you have concurrent modification of an IdDict without synchronization, that's a data race (which is UB). You can get torn writes, or other arbitrarily wrong results getting written.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay torn writes would be bad. So it just needs a lock on write?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the best approach would be to make the cache non-global. As long as you have a separate cache per function (which should be totally fine) there won't be multithreading to cause problems.


# Convenience entry point: lift the in-place flag `iip` from the `ODEFunction` type
# parameter into a `Val` and forward to the worker method.
function build_jet(f::ODEFunction{iip}, p, order, length = nothing) where {iip}
    inplace_tag = Val{iip}()
    return build_jet(f, inplace_tag, p, order, length)
end

# Serialises every read and write of the global `JET_CACHE`. Unsynchronised concurrent
# mutation of an `IdDict` is a data race, so all accesses go through this lock.
const JET_CACHE_LOCK = ReentrantLock()

# Build (or fetch from `JET_CACHE`) the order-`P` Taylor jet of the ODE right-hand side.
#
# Arguments
# - `f`: the ODE function, called as `f(du, u, p, t)` when `iip` is true and as
#   `f(u, p, t)` otherwise.
# - `::Val{iip}`: whether `f` is in-place.
# - `p`: parameters passed through to `f`; also part of the cache key (compared with `==`).
# - `order::Val{P}`: Taylor order.
# - `length`: number of state variables, or `nothing` for a scalar state. Must be an
#   `Integer` when `iip` is true.
#
# Returns whatever `Symbolics.build_function` returns for the traced Taylor expansion.
# Callers in this package destructure that as `(out_of_place, in_place)` when `length`
# is given, and use it directly as a function for a scalar state.
#
# The symbolic trace runs outside the lock, so concurrent callers are not serialised on
# compilation. Two tasks racing on the same key may both build and both store a jet;
# the duplicate entry is harmless because lookups return the first match.
function build_jet(f, ::Val{iip}, p, order::Val{P}, length = nothing) where {P, iip}
    # Locked lookup. The hit is wrapped in `Some` to distinguish it from a miss.
    cached = lock(JET_CACHE_LOCK) do
        entries = get(JET_CACHE, f, nothing)
        entries === nothing && return nothing
        index = findfirst(x -> x[1] == order && x[2] == p, entries)
        return index === nothing ? nothing : Some(entries[index][3])
    end
    cached === nothing || return something(cached)

    @variables t0::Real
    u0 = isnothing(length) ? Symbolics.variable(:u0) : Symbolics.variables(:u0, 1:length)
    if iip
        @assert length isa Integer
        f0 = similar(u0)
        f(f0, u0, p, t0)
    else
        f0 = f(u0, p, t0)
    end
    # Start from the first-order expansion (u0, f0) and extend it one order at a time:
    # push the current expansion through `f`, then divide the highest coefficient of the
    # result by the new order to get the next coefficient of `u`.
    u = TaylorDiff.make_seed(u0, f0, Val(1))
    for index in 2:P
        t = TaylorScalar{index - 1}(t0, one(t0))
        if iip
            fu = similar(u)
            f(fu, u, p, t)
        else
            fu = f(u, p, t)
        end
        d = get_coefficient(fu, index - 1) / index
        u = append_coefficient(u, d)
    end
    jet = build_function(u, u0, t0; expression = Val(false), cse = true)

    # Locked insert; `get!` creates the per-function list on first use.
    lock(JET_CACHE_LOCK) do
        push!(get!(() -> Any[], JET_CACHE, f), (order, p, jet))
    end
    return jet
end

# evaluate using Qin Jiushao's algorithm
# Evaluate the Taylor polynomial stored in `t` at the point `z` using Qin Jiushao's
# (Horner's) algorithm.
#
# Why `@generated` (from the PR review discussion): like the built-in `@evalpoly` macro,
# the Horner recurrence is unrolled at compile time from the type parameter `P`, so no
# loop runs at run time.
#
# The emitted code reads coefficients `1 .. P + 1` of `flatten(t)` and is marked inline.
@generated function evaluate_polynomial(t::TaylorScalar{T, P}, z) where {T, P}
    # Build the nested expression from the innermost (highest-order) term outwards:
    # c[1] + z * (c[2] + z * ( ... + z * c[P + 1]))
    horner = :(coeffs[$(P + 1)])
    for k in reverse(1:P)
        horner = :(coeffs[$k] + z * $horner)
    end
    return quote
        $(Expr(:meta, :inline))
        coeffs = flatten(t)
        $horner
    end
end
Loading
Loading