Skip to content

ForwardDiff Overload Fixes #629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
17b4364
don't set properties again
jClugstor Jul 2, 2025
0b7842a
make sure that when A, b, or u are accessed you get the Dual numbers
jClugstor Jul 2, 2025
b0d53ae
make sure cache u is updated
jClugstor Jul 2, 2025
e98a9b9
no infiltrate
jClugstor Jul 2, 2025
ad32c73
exclude nested duals
jClugstor Jul 3, 2025
16f5f0f
add Float32 support
jClugstor Jul 14, 2025
5dd9ce0
change to AbstractFloat
jClugstor Jul 14, 2025
43a0379
add recursive support for nested Duals
jClugstor Jul 14, 2025
f6db1ee
fix support for nested Duals
jClugstor Jul 14, 2025
c10a8bf
clean up
jClugstor Jul 14, 2025
7abb243
more clean up
jClugstor Jul 14, 2025
0717289
allow any number
jClugstor Jul 14, 2025
f829926
reuse list
jClugstor Jul 14, 2025
00431ad
add tests for nested duals
jClugstor Jul 14, 2025
6ac7bf1
no infiltrator
jClugstor Jul 14, 2025
8bae0f7
proper RAT indexing
jClugstor Jul 14, 2025
8e8b47e
all numbers
jClugstor Jul 14, 2025
308bc76
correct RAT index
jClugstor Jul 14, 2025
d9ed726
get rid of unecessary things
jClugstor Jul 15, 2025
fdc9774
get rid of log
jClugstor Jul 15, 2025
9c11cc2
add more tests
jClugstor Jul 16, 2025
96244fd
streamline
jClugstor Jul 16, 2025
762bfb1
make sure u is aliased
jClugstor Jul 16, 2025
248718f
add tests for sparse arrays and sparse solvers
jClugstor Jul 16, 2025
5497adc
make sure AbstractArrays of size 1 are accounted for
jClugstor Jul 16, 2025
e842eac
use AbstractVector, fix setproperty!
jClugstor Jul 16, 2025
1b1697a
add test for setting nested Duals
jClugstor Jul 16, 2025
2e003ee
fix test
jClugstor Jul 16, 2025
47d4d68
make sparse
jClugstor Jul 16, 2025
6323432
fix default solver
jClugstor Jul 16, 2025
b1ffa15
allow setting alg
jClugstor Jul 16, 2025
62127b8
Update ext/LinearSolveForwardDiffExt.jl
ChrisRackauckas Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 69 additions & 46 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LinearSolveForwardDiffExt

using LinearSolve
using LinearSolve: SciMLLinearSolveAlgorithm
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials
Expand Down Expand Up @@ -36,8 +37,14 @@ const DualAbstractLinearProblem = Union{
LinearSolve.@concrete mutable struct DualLinearCache
linear_cache
dual_type

partials_A
partials_b
partials_u

dual_A
dual_b
dual_u
end

function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
Expand All @@ -55,16 +62,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa

rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)

partial_cache = cache.linear_cache
partial_cache.u = dual_u0

cache.linear_cache.u = dual_u0
# We can reuse the linear cache, because the same factorization will work for the partials.
for i in eachindex(rhs_list)
partial_cache.b = rhs_list[i]
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u)
cache.linear_cache.b = rhs_list[i]
rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
end

# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
partial_cache.b = primal_b
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
cache.linear_cache.b = primal_b

partial_sols = rhs_list

Expand Down Expand Up @@ -96,35 +102,25 @@ function xp_linsolve_rhs(
b_list
end

#=
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
return solve(prob, nothing, args...; kwargs...)
end

function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
end

function SciMLBase.solve(prob::DualAbstractLinearProblem,
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
solve!(init(prob, alg, args...; kwargs...))
end
=#

function linearsolve_dual_solution(
u::Number, partials, dual_type)
return dual_type(u, partials)
end

function linearsolve_dual_solution(
u::AbstractArray, partials, dual_type)
function linearsolve_dual_solution(u::Number, partials,
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
# Handle single-level duals
return dual_type(u, partials)
end

function linearsolve_dual_solution(u::AbstractArray, partials,
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
# Handle single-level duals for arrays
partials_list = RecursiveArrayTools.VectorOfArray(partials)
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
end

#=
function SciMLBase.init(
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
args...;
Expand All @@ -138,7 +134,6 @@ function SciMLBase.init(
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)

(; A, b, u0, p) = prob
new_A = nodual_value(A)
new_b = nodual_value(b)
Expand All @@ -147,7 +142,6 @@ function SciMLBase.init(
∂_A = partial_vals(A)
∂_b = partial_vals(b)

#primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)

if get_dual_type(prob.A) !== nothing
Expand All @@ -156,48 +150,71 @@ function SciMLBase.init(
dual_type = get_dual_type(prob.b)
end

alg isa LinearSolve.DefaultLinearSolver ? real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg

non_partial_cache = init(
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
primal_prob, real_alg, assumptions, args...;
alias = alias, abstol = abstol, reltol = reltol,
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
sensealg = sensealg, u0 = new_u0, kwargs...)
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b)))
end

function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
solve!(cache, cache.alg, args...; kwargs...)
end

function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
sol,
partials = linearsolve_forwarddiff_solve(
cache::DualLinearCache, cache.alg, args...; kwargs...)

dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)

if cache.dual_u isa AbstractArray
cache.dual_u[:] = dual_sol
else
cache.dual_u = dual_sol
end

return SciMLBase.build_linear_solution(
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
)
end
=#

# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
# Also "forwards" setproperty so that
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# If the property is A or b, also update it in the LinearCache
if sym === :A || sym === :b || sym === :u
setproperty!(dc.linear_cache, sym, nodual_value(val))
elseif hasfield(DualLinearCache, sym)
setfield!(dc, sym, val)
elseif hasfield(LinearSolve.LinearCache, sym)
setproperty!(dc.linear_cache, sym, val)
end


# Update the partials if setting A or b
if sym === :A
setfield!(dc, :dual_A, val)
setfield!(dc, :partials_A, partial_vals(val))
elseif sym === :b
elseif sym === :b
setfield!(dc, :dual_b, val)
setfield!(dc, :partials_b, partial_vals(val))
else
setfield!(dc, sym, val)
elseif sym === :u
setfield!(dc, :dual_u, val)
setfield!(dc, :partials_u, partial_vals(val))
end
end

# "Forwards" getproperty to LinearCache if necessary
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
if hasfield(LinearSolve.LinearCache, sym)
if sym === :A
dc.dual_A
elseif sym === :b
dc.dual_b
elseif sym === :u
dc.dual_u
elseif hasfield(LinearSolve.LinearCache, sym)
return getproperty(dc.linear_cache, sym)
else
return getfield(dc, sym)
Expand All @@ -206,31 +223,36 @@ end



# Helper functions for Dual numbers
get_dual_type(x::Dual) = typeof(x)
# Enhanced helper functions for Dual numbers to handle recursion
get_dual_type(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = typeof(x)
get_dual_type(x::Dual{T, V, P}) where {T, V <: Dual, P} = typeof(x)
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
get_dual_type(x) = nothing

partial_vals(x::Dual) = ForwardDiff.partials(x)
# Add recursive handling for nested dual partials
partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.partials(x)
partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
partial_vals(x) = nothing

# Add recursive handling for nested dual values
nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x)


function partials_to_list(partial_matrix::Vector)
function partials_to_list(partial_matrix::AbstractVector{T}) where {T}
p = eachindex(first(partial_matrix))
[[partial[i] for partial in partial_matrix] for i in p]
end

function partials_to_list(partial_matrix)
p = length(first(partial_matrix))
m, n = size(partial_matrix)
res_list = fill(zeros(m, n), p)
res_list = fill(zeros(typeof(partial_matrix[1, 1][1]), m, n), p)
for k in 1:p
res = zeros(m, n)
res = zeros(typeof(partial_matrix[1, 1][1]), m, n)
for i in 1:m
for j in 1:n
res[i, j] = partial_matrix[i, j][k]
Expand All @@ -243,3 +265,4 @@ end


end

117 changes: 114 additions & 3 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using LinearSolve
using ForwardDiff
using Test
using SparseArrays

function h(p)
(A = [p[1] p[2]+1 p[2]^3;
Expand All @@ -23,12 +24,11 @@ krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES())

@test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9)


A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
backslash_x_p = A \ [6.0, 10.0, 25.0]
prob = LinearProblem(A, [6.0, 10.0, 25.0])

@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9)
@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9)
@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9)

_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
Expand All @@ -48,6 +48,9 @@ new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.
cache.A = new_A
cache.b = new_b

@test cache.A == new_A
@test cache.b == new_b

x_p = solve!(cache)
backslash_x_p = new_A \ new_b

Expand All @@ -61,6 +64,7 @@ cache = init(prob)

new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A
@test cache.A == new_A

x_p = solve!(cache)
backslash_x_p = new_A \ b
Expand All @@ -75,8 +79,115 @@ cache = init(prob)

_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.b = new_b
@test cache.b == new_b

x_p = solve!(cache)
backslash_x_p = A \ new_b

@test ≈(x_p, backslash_x_p, rtol = 1e-9)
@test ≈(x_p, backslash_x_p, rtol = 1e-9)

# Nested Duals
function h(p)
(A = [p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2] * p[1]-4;
p[2]^2 9*p[1] p[2]],
b = [p[1] + 1, p[2] * 2, p[1]^2])
end

A, b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0),
ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)])

prob = LinearProblem(A, b)
overload_x_p = solve(prob)

original_x_p = A \ b

@test ≈(overload_x_p, original_x_p, rtol = 1e-9)

prob = LinearProblem(A, b)
cache = init(prob)

new_A, new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0),
ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 0.0, 1.0)])

cache.A = new_A
cache.b = new_b

@test cache.A == new_A
@test cache.b == new_b

function linprob_f(p)
A, b = h(p)
prob = LinearProblem(A, b)
solve(prob)
end

function slash_f(p)
A, b = h(p)
A \ b
end

@test ≈(
ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0]))

@test ≈(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]),
ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0]))

function g(p)
(A = [p[1] p[1]+1 p[1]^3;
3*p[1] p[1]+5 p[1] * p[1]-4;
p[1]^2 9*p[1] p[1]],
b = [p[1] + 1, p[1] * 2, p[1]^2])
end

function slash_f_hes(p)
A, b = g(p)
x = A \ b
sum(x)
end

function linprob_f_hes(p)
A, b = g(p)
prob = LinearProblem(A, b)
x = solve(prob)
sum(x)
end

@test ≈(ForwardDiff.hessian(slash_f_hes, [5.0]),
ForwardDiff.hessian(linprob_f_hes, [5.0]))

# Test aliasing
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)

new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A
cache.b = new_b

linu = [ForwardDiff.Dual(0.0, 0.0, 0.0), ForwardDiff.Dual(0.0, 0.0, 0.0),
ForwardDiff.Dual(0.0, 0.0, 0.0)]
cache.u = linu
x_p = solve!(cache)
backslash_x_p = new_A \ new_b

@test linu == cache.u

# Test Float Only solvers

A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(sparse(A), sparse(b))
overload_x_p = solve(prob, KLUFactorization())
backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(sparse(A), sparse(b))
overload_x_p = solve(prob, UMFPACKFactorization())
backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)
Loading