Skip to content

Commit f5f1cc4

Browse files
Merge pull request #2305 from oscardssmith/os/optimize-StaticW
optimize StaticWOperator by using lu to allow saving the factorization
2 parents ba80b06 + 9474b35 commit f5f1cc4

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ import FunctionWrappersWrappers
1414
using DiffEqBase
1515

1616
import LinearAlgebra
17-
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm
17+
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu
18+
import LinearAlgebra: LowerTriangular, UpperTriangular
1819
import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros
1920

2021
import InteractiveUtils
2122
import ArrayInterface
2223

2324
import StaticArrayInterface
25+
import StaticArrays
2426
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA,
2527
StaticMatrix
2628

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
11
const ROSENBROCK_INV_CUTOFF = 7 # https://github.com/SciML/OrdinaryDiffEq.jl/pull/1539
22

3-
struct StaticWOperator{isinv, T} <: AbstractSciMLOperator{T}
3+
struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T}
44
W::T
5+
F::F
56
function StaticWOperator(W::T, callinv = true) where {T}
6-
isinv = size(W, 1) <= ROSENBROCK_INV_CUTOFF
7+
n = size(W, 1)
8+
isinv = n <= ROSENBROCK_INV_CUTOFF
79

10+
F = if isinv && callinv
11+
# this should be in ArrayInterface but can't be for silly reasons
12+
# doing to how StaticArrays and StaticArraysCore are split up
13+
StaticArrays.LU(LowerTriangular(W), UpperTriangular(W), SVector{n}(1:n))
14+
else
15+
lu(W, check=false)
16+
end
817
# when constructing W for the first time for the type
918
# inv(W) can be singular
1019
_W = if isinv && callinv
1120
inv(W)
1221
else
1322
W
1423
end
15-
new{isinv, T}(_W)
24+
new{isinv, T, typeof(F)}(_W, F)
1625
end
1726
end
1827
isinv(W::StaticWOperator{S}) where {S} = S
19-
Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.W \ v
28+
Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.F \ v
2029

2130
function calc_tderivative!(integrator, cache, dtd1, repeat_step)
2231
@inbounds begin

0 commit comments

Comments
 (0)