@@ -8,18 +8,33 @@ jacobian_f(::Number, p::Number) = 1 / (2 * √p)
8
8
jacobian_f (u, p:: Number ) = one .(u) .* (1 / (2 * √ p))
9
9
jacobian_f (u, p:: AbstractArray ) = diagm (vec (@. 1 / (2 * √ p)))
10
10
11
- function solve_with (:: Val{iip } , u, alg) where {iip }
12
- f = if iip
11
+ function solve_with (:: Val{mode } , u, alg) where {mode }
12
+ f = if mode === : iip
13
13
solve_iip (p) = solve (NonlinearProblem (test_f!, u, p), alg). u
14
- else
14
+ elseif mode === :iip_cache
15
+ function solve_iip_init (p)
16
+ cache = SciMLBase. init (NonlinearProblem (test_f!, u, p), alg)
17
+ return SciMLBase. solve! (cache). u
18
+ end
19
+ elseif mode === :oop
15
20
solve_oop (p) = solve (NonlinearProblem (test_f, u, p), alg). u
21
+ elseif mode === :oop_cache
22
+ function solve_oop_init (p)
23
+ cache = SciMLBase. init (NonlinearProblem (test_f, u, p), alg)
24
+ return SciMLBase. solve! (cache). u
25
+ end
16
26
end
17
27
return f
18
28
end
19
29
20
- __can_inplace (:: Number ) = false
21
- __can_inplace (:: AbstractArray ) = true
22
- __can_inplace (:: StaticArray ) = false
30
+ __compatible (:: Any , :: Val{:oop} ) = true
31
+ __compatible (:: Any , :: Val{:oop_cache} ) = true
32
+ __compatible (:: Number , :: Val{:iip} ) = false
33
+ __compatible (:: AbstractArray , :: Val{:iip} ) = true
34
+ __compatible (:: StaticArray , :: Val{:iip} ) = false
35
+ __compatible (:: Number , :: Val{:iip_cache} ) = false
36
+ __compatible (:: AbstractArray , :: Val{:iip_cache} ) = true
37
+ __compatible (:: StaticArray , :: Val{:iip_cache} ) = false
23
38
24
39
__compatible (:: Any , :: Number ) = true
25
40
__compatible (:: Number , :: AbstractArray ) = false
@@ -32,37 +47,49 @@ __compatible(u::StaticArray, ::SciMLBase.AbstractNonlinearAlgorithm) = true
32
47
__compatible (u:: StaticArray , :: Union{CMINPACK, NLsolveJL} ) = false
33
48
__compatible (u, :: Nothing ) = true
34
49
50
+ __compatible (:: Any , :: Any ) = true
51
+ __compatible (:: CMINPACK , :: Val{:iip_cache} ) = false
52
+ __compatible (:: CMINPACK , :: Val{:oop_cache} ) = false
53
+ __compatible (:: NLsolveJL , :: Val{:iip_cache} ) = false
54
+ __compatible (:: NLsolveJL , :: Val{:oop_cache} ) = false
55
+
35
56
@testset " ForwardDiff.jl Integration: $(alg) " for alg in (NewtonRaphson (), TrustRegion (),
36
57
LevenbergMarquardt (), PseudoTransient (; alpha_initial = 10.0 ), Broyden (), Klement (),
37
58
DFSane (), nothing , NLsolveJL (), CMINPACK ())
38
59
us = (2.0 , @SVector [1.0 , 1.0 ], [1.0 , 1.0 ], ones (2 , 2 ), @SArray ones (2 , 2 ))
39
60
40
61
@testset " Scalar AD" begin
41
- for p in 1.0 : 0.1 : 100.0
42
- for u0 in us
43
- __compatible (u0, alg) || continue
44
- sol = solve (NonlinearProblem (test_f, u0, p), alg)
45
- if SciMLBase. successful_retcode (sol)
46
- gs = abs .(ForwardDiff. derivative (solve_with (Val {false} (), u0, alg), p))
47
- gs_true = abs .(jacobian_f (u0, p))
48
- if ! (isapprox (gs, gs_true, atol = 1e-5 ))
49
- @show sol. retcode, sol. u
50
- @error " ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg) " forwardiff_gradient= gs true_gradient= gs_true
51
- else
52
- @test abs .(gs)≈ abs .(gs_true) atol= 1e-5
53
- end
62
+ for p in 1.0 : 0.1 : 100.0 , u0 in us, mode in (:iip , :oop , :iip_cache , :oop_cache )
63
+ __compatible (u0, alg) || continue
64
+ __compatible (u0, Val (mode)) || continue
65
+ __compatible (alg, Val (mode)) || continue
66
+
67
+ sol = solve (NonlinearProblem (test_f, u0, p), alg)
68
+ if SciMLBase. successful_retcode (sol)
69
+ gs = abs .(ForwardDiff. derivative (solve_with (Val {mode} (), u0, alg), p))
70
+ gs_true = abs .(jacobian_f (u0, p))
71
+ if ! (isapprox (gs, gs_true, atol = 1e-5 ))
72
+ @show sol. retcode, sol. u
73
+ @error " ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg) " forwardiff_gradient= gs true_gradient= gs_true
74
+ else
75
+ @test abs .(gs)≈ abs .(gs_true) atol= 1e-5
54
76
end
55
77
end
56
78
end
57
79
end
58
80
59
81
@testset " Jacobian" begin
60
- for u0 in us, p in ([2.0 , 1.0 ], [2.0 1.0 ; 3.0 4.0 ])
82
+ for u0 in us, p in ([2.0 , 1.0 ], [2.0 1.0 ; 3.0 4.0 ]),
83
+ mode in (:iip , :oop , :iip_cache , :oop_cache )
84
+
61
85
__compatible (u0, p) || continue
62
86
__compatible (u0, alg) || continue
87
+ __compatible (u0, Val (mode)) || continue
88
+ __compatible (alg, Val (mode)) || continue
89
+
63
90
sol = solve (NonlinearProblem (test_f, u0, p), alg)
64
91
if SciMLBase. successful_retcode (sol)
65
- gs = abs .(ForwardDiff. jacobian (solve_with (Val {false } (), u0, alg), p))
92
+ gs = abs .(ForwardDiff. jacobian (solve_with (Val {mode } (), u0, alg), p))
66
93
gs_true = abs .(jacobian_f (u0, p))
67
94
if ! (isapprox (gs, gs_true, atol = 1e-5 ))
68
95
@show sol. retcode, sol. u
0 commit comments