Skip to content

Commit 86917a5

Browse files
committed
Remove the static arrays special casing
1 parent 3c32525 commit 86917a5

File tree

7 files changed

+30
-32
lines changed

7 files changed

+30
-32
lines changed

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1818
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1919
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2020
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
21+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2122
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2223

2324
[weakdeps]
2425
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2526
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
26-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2727
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2828
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2929

3030
[extensions]
3131
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
3232
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
33-
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
3433
SimpleNonlinearSolveTrackerExt = "Tracker"
3534
SimpleNonlinearSolveZygoteExt = "Zygote"
3635

@@ -40,7 +39,7 @@ AllocCheck = "0.1.1"
4039
Aqua = "0.8"
4140
ArrayInterface = "7.9"
4241
CUDA = "5.2"
43-
ChainRulesCore = "1.22"
42+
ChainRulesCore = "1.23"
4443
ConcreteStructs = "0.2.3"
4544
DiffEqBase = "6.149"
4645
DiffResults = "1.1"
@@ -59,13 +58,14 @@ PrecompileTools = "1.2"
5958
Random = "1.10"
6059
ReTestItems = "1.23"
6160
Reexport = "1.2"
62-
ReverseDiff = "1.15"
61+
ReverseDiff = "1.15.3"
6362
SciMLBase = "2.37.0"
6463
SciMLSensitivity = "7.58"
64+
Setfield = "1.1.1"
6565
StaticArrays = "1.9"
6666
StaticArraysCore = "1.4.2"
6767
Test = "1.10"
68-
Tracker = "0.2.32"
68+
Tracker = "0.2.33"
6969
Zygote = "0.6.69"
7070
julia = "1.10"
7171

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveStaticArraysExt.jl

Lines changed: 0 additions & 7 deletions
This file was deleted.

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
2424
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
2525
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
2626
build_solution, isinplace, _unwrap_val
27+
using Setfield: @set!
2728
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2829
end
2930

lib/SimpleNonlinearSolve/src/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,12 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
109109
end
110110
else
111111
# For small problems, nesting ForwardDiff is actually quite fast
112-
_f = Base.Fix2(prob.f, newprob.p)
113112
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
114113
# TODO: Remove once DI has the value_and_pullback_split defined
115-
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p)
114+
_F = @closure (u, p) -> begin
115+
_f = Base.Fix2(prob.f, p)
116+
return __zygote_compute_nlls_vjp(_f, u, p)
117+
end
116118
else
117119
_F = @closure (u, p) -> begin
118120
_f = Base.Fix2(prob.f, p)

lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
7777
α_1 = one(T)
7878
f_1 = fx_norm
7979

80-
history_f_k = if x isa SArray ||
81-
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
82-
ones(SVector{M, T}) * fx_norm
83-
else
84-
fill(fx_norm, M)
85-
end
80+
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
81+
__history_vec(fx_norm, Val(M))
8682

8783
# Generate the cache
8884
@bb x_cache = similar(x)
@@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
150146
# Store function value
151147
if history_f_k isa SVector
152148
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
149+
elseif history_f_k isa NTuple
150+
@set! history_f_k[mod1(k, M)] = fx_norm_new
153151
else
154152
history_f_k[mod1(k, M)] = fx_norm_new
155153
end
@@ -158,3 +156,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
158156

159157
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
160158
end
159+
160+
@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}
161+
M 11 && return :(fill(fx_norm, M)) # Julia can't specialize here
162+
return :(ntuple(Returns(fx_norm), $(M)))
163+
end

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,15 @@ function value_and_jacobian(
3232

3333
if isinplace(prob)
3434
if cache isa HasAnalyticJacobian
35-
prob.f.jac(J, x, p)
35+
prob.f.jac(J, x, prob.p)
3636
f(y, x)
37-
else
38-
DI.jacobian!(f, y, J, ad, x, cache)
37+
return y, J
3938
end
40-
return y, J
39+
return DI.value_and_jacobian!(f, y, J, ad, x, cache)
4140
else
4241
cache isa HasAnalyticJacobian && return f(x), prob.f.jac(x, prob.p)
4342
J === nothing && return DI.value_and_jacobian(f, ad, x, cache)
44-
y, _ = DI.value_and_jacobian!(f, J, ad, x, cache)
43+
y, J = DI.value_and_jacobian!(f, J, ad, x, cache)
4544
return y, J
4645
end
4746
end
@@ -63,8 +62,9 @@ end
6362
function compute_jacobian_and_hessian(
6463
ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F}
6564
if x isa Number
66-
df = @closure x -> DI.derivative(f, ad, x)
67-
return f(x), df(x), DI.derivative(df, ad, x)
65+
H = DI.second_derivative(f, ad, x)
66+
v, J = DI.value_and_derivative(f, ad, x)
67+
return v, J, H
6868
end
6969

7070
if isinplace(prob)

lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,14 @@ end
3434

3535
export quadratic_f, quadratic_f!, quadratic_f2, newton_fails, TERMINATION_CONDITIONS,
3636
benchmark_nlsolve_oop, benchmark_nlsolve_iip
37-
3837
end
3938

4039
@testitem "First Order Methods" setup=[RootfindingTesting] tags=[:core] begin
4140
@testset "$(alg)" for alg in (SimpleNewtonRaphson,
4241
SimpleTrustRegion,
4342
(args...; kwargs...) -> SimpleTrustRegion(
4443
args...; nlsolve_update_rule = Val(true), kwargs...))
45-
@testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in (
44+
@testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in (
4645
AutoFiniteDiff(), AutoForwardDiff(), AutoPolyesterForwardDiff())
4746
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
4847
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
@@ -59,7 +58,7 @@ end
5958
end
6059
end
6160

62-
@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
61+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
6362
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
6463

6564
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -79,7 +78,7 @@ end
7978
end
8079
end
8180

82-
@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
81+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
8382
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
8483

8584
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -104,7 +103,7 @@ end
104103
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
105104
end
106105

107-
@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
106+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
108107
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
109108

110109
probN = NonlinearProblem(quadratic_f, u0, 2.0)

0 commit comments

Comments
 (0)