diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 2457df224..d133d75c0 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,6 +1,7 @@ module LinearSolveForwardDiffExt using LinearSolve +using LinearSolve: SciMLLinearSolveAlgorithm using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials @@ -36,8 +37,14 @@ const DualAbstractLinearProblem = Union{ LinearSolve.@concrete mutable struct DualLinearCache linear_cache dual_type + partials_A partials_b + partials_u + + dual_A + dual_b + dual_u end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) @@ -55,16 +62,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_cache = cache.linear_cache - partial_cache.u = dual_u0 - + cache.linear_cache.u = dual_u0 + # We can reuse the linear cache, because the same factorization will work for the partials. for i in eachindex(rhs_list) - partial_cache.b = rhs_list[i] - rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) + cache.linear_cache.b = rhs_list[i] + rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) end - # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to - partial_cache.b = primal_b + # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to + cache.linear_cache.b = primal_b partial_sols = rhs_list @@ -96,35 +102,25 @@ function xp_linsolve_rhs( b_list end -#= -function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) - return solve(prob, nothing, args...; kwargs...) -end - -function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; - assump = OperatorAssumptions(issquare(prob.A)), kwargs...) - return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) -end - -function SciMLBase.solve(prob::DualAbstractLinearProblem, - alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) - solve!(init(prob, alg, args...; kwargs...)) -end -=# - function linearsolve_dual_solution( u::Number, partials, dual_type) return dual_type(u, partials) end -function linearsolve_dual_solution( - u::AbstractArray, partials, dual_type) +function linearsolve_dual_solution(u::Number, partials, + dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} + # Handle single-level duals + return dual_type(u, partials) +end + +function linearsolve_dual_solution(u::AbstractArray, partials, + dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} + # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), - zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) + zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end -#= function SciMLBase.init( prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; @@ -138,7 +134,6 @@ function SciMLBase.init( assumptions = OperatorAssumptions(issquare(prob.A)), sensealg = LinearSolveAdjoint(), kwargs...) - (; A, b, u0, p) = prob new_A = nodual_value(A) new_b = nodual_value(b) @@ -147,7 +142,6 @@ function SciMLBase.init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) - #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) if get_dual_type(prob.A) !== nothing @@ -156,48 +150,71 @@ function SciMLBase.init( dual_type = get_dual_type(prob.b) end + alg isa LinearSolve.DefaultLinearSolver ? real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg + non_partial_cache = init( - primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, + primal_prob, real_alg, assumptions, args...; + alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b))) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) + solve!(cache, cache.alg, args...; kwargs...) +end + +function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) sol, partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) - dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + + if cache.dual_u isa AbstractArray + cache.dual_u[:] = dual_sol + else + cache.dual_u = dual_sol + end + return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) end -=# # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache -# Also "forwards" setproperty so that function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(DualLinearCache, sym) + setfield!(dc, sym, val) elseif hasfield(LinearSolve.LinearCache, sym) setproperty!(dc.linear_cache, sym, val) end + # Update the partials if setting A or b if sym === :A + setfield!(dc, :dual_A, val) setfield!(dc, :partials_A, partial_vals(val)) - elseif sym === :b + elseif sym === :b + setfield!(dc, :dual_b, val) setfield!(dc, :partials_b, partial_vals(val)) - else - setfield!(dc, sym, val) + elseif sym === :u + setfield!(dc, :dual_u, val) + setfield!(dc, :partials_u, partial_vals(val)) end end # "Forwards" getproperty to LinearCache if necessary function Base.getproperty(dc::DualLinearCache, sym::Symbol) - if hasfield(LinearSolve.LinearCache, sym) + if sym === :A + dc.dual_A + elseif sym === :b + dc.dual_b + elseif sym === :u + dc.dual_u + elseif hasfield(LinearSolve.LinearCache, sym) return getproperty(dc.linear_cache, sym) else return getfield(dc, sym) @@ -206,21 +223,26 @@ end -# Helper functions for Dual numbers -get_dual_type(x::Dual) = typeof(x) +# Enhanced helper functions for Dual numbers to handle recursion +get_dual_type(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = typeof(x) +get_dual_type(x::Dual{T, V, P}) where {T, V <: Dual, P} = typeof(x) get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) get_dual_type(x) = nothing -partial_vals(x::Dual) = ForwardDiff.partials(x) +# Add recursive handling for nested dual partials +partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.partials(x) +partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x) partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) partial_vals(x) = nothing +# Add recursive handling for nested dual values nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) +nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact +nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x) -function partials_to_list(partial_matrix::Vector) +function partials_to_list(partial_matrix::AbstractVector{T}) where {T} p = eachindex(first(partial_matrix)) [[partial[i] for partial in partial_matrix] for i in p] end @@ -228,9 +250,9 @@ end function partials_to_list(partial_matrix) p = length(first(partial_matrix)) m, n = size(partial_matrix) - res_list = fill(zeros(m, n), p) + res_list = fill(zeros(typeof(partial_matrix[1, 1][1]), m, n), p) for k in 1:p - res = zeros(m, n) + res = zeros(typeof(partial_matrix[1, 1][1]), m, n) for i in 1:m for j in 1:n res[i, j] = partial_matrix[i, j][k] @@ -243,3 +265,4 @@ end end + diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index eb66c64dc..a53a4cf0e 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -1,6 +1,7 @@ using LinearSolve using ForwardDiff using Test +using SparseArrays function h(p) (A = [p[1] p[2]+1 p[2]^3; @@ -23,12 +24,11 @@ krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES()) @test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9) - A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) backslash_x_p = A \ [6.0, 10.0, 25.0] prob = LinearProblem(A, [6.0, 10.0, 25.0]) -@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) +@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) @test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) _, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) @@ -48,6 +48,9 @@ new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1. cache.A = new_A cache.b = new_b +@test cache.A == new_A +@test cache.b == new_b + x_p = solve!(cache) backslash_x_p = new_A \ new_b @@ -61,6 +64,7 @@ cache = init(prob) new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.A = new_A +@test cache.A == new_A x_p = solve!(cache) backslash_x_p = new_A \ b @@ -75,8 +79,115 @@ cache = init(prob) _, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.b = new_b +@test cache.b == new_b x_p = solve!(cache) backslash_x_p = A \ new_b -@test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file +@test ≈(x_p, backslash_x_p, rtol = 1e-9) + +# Nested Duals +function h(p) + (A = [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]], + b = [p[1] + 1, p[2] * 2, p[1]^2]) +end + +A, b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0), + ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)]) + +prob = LinearProblem(A, b) +overload_x_p = solve(prob) + +original_x_p = A \ b + +@test ≈(overload_x_p, original_x_p, rtol = 1e-9) + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0), + ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 0.0, 1.0)]) + +cache.A = new_A +cache.b = new_b + +@test cache.A == new_A +@test cache.b == new_b + +function linprob_f(p) + A, b = h(p) + prob = LinearProblem(A, b) + solve(prob) +end + +function slash_f(p) + A, b = h(p) + A \ b +end + +@test ≈( + ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0])) + +@test ≈(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]), + ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0])) + +function g(p) + (A = [p[1] p[1]+1 p[1]^3; + 3*p[1] p[1]+5 p[1] * p[1]-4; + p[1]^2 9*p[1] p[1]], + b = [p[1] + 1, p[1] * 2, p[1]^2]) +end + +function slash_f_hes(p) + A, b = g(p) + x = A \ b + sum(x) +end + +function linprob_f_hes(p) + A, b = g(p) + prob = LinearProblem(A, b) + x = solve(prob) + sum(x) +end + +@test ≈(ForwardDiff.hessian(slash_f_hes, [5.0]), + ForwardDiff.hessian(linprob_f_hes, [5.0])) + +# Test aliasing +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.A = new_A +cache.b = new_b + +linu = [ForwardDiff.Dual(0.0, 0.0, 0.0), ForwardDiff.Dual(0.0, 0.0, 0.0), + ForwardDiff.Dual(0.0, 0.0, 0.0)] +cache.u = linu +x_p = solve!(cache) +backslash_x_p = new_A \ new_b + +@test linu == cache.u + +# Test Float Only solvers + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(sparse(A), sparse(b)) +overload_x_p = solve(prob, KLUFactorization()) +backslash_x_p = A \ b + +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(sparse(A), sparse(b)) +overload_x_p = solve(prob, UMFPACKFactorization()) +backslash_x_p = A \ b + +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file