From 17b436451493edd927326ed72c724184e994c5b5 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 09:58:19 -0400 Subject: [PATCH 01/32] don't set properties again --- ext/LinearSolveForwardDiffExt.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 2457df224..d2556f80d 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -190,8 +190,6 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) setfield!(dc, :partials_A, partial_vals(val)) elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) - else - setfield!(dc, sym, val) end end From 0b7842a5c289d1b01052cb0b38e53c57f7926aeb Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 14:34:31 -0400 Subject: [PATCH 02/32] make sure that when A, b, or u are accessed you get the Dual numbers --- ext/LinearSolveForwardDiffExt.jl | 41 ++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index d2556f80d..e2021cbb1 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -36,8 +36,14 @@ const DualAbstractLinearProblem = Union{ LinearSolve.@concrete mutable struct DualLinearCache linear_cache dual_type + partials_A partials_b + partials_u + + dual_A + dual_b + dual_u end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) @@ -55,16 +61,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_cache = cache.linear_cache - partial_cache.u = dual_u0 - + cache.linear_cache.u = dual_u0 + # We can reuse the linear cache, because the same factorization will work for the partials. for i in eachindex(rhs_list) - partial_cache.b = rhs_list[i] - rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) + cache.linear_cache.b = rhs_list[i] + rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) end - # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to - partial_cache.b = primal_b + # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to + cache.linear_cache.b = primal_b + Main.@infiltrate partial_sols = rhs_list @@ -147,7 +153,6 @@ function SciMLBase.init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) - #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) if get_dual_type(prob.A) !== nothing @@ -160,7 +165,7 @@ function SciMLBase.init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zero.(b)) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -169,6 +174,10 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + Main.@infiltrate + + cache.dual_u = dual_sol + Main.@infiltrate return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) @@ -176,7 +185,6 @@ end =# # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache -# Also "forwards" setproperty so that function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u @@ -188,14 +196,23 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # Update the partials if setting A or b if sym === :A setfield!(dc, :partials_A, partial_vals(val)) - elseif sym === :b + elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) + elseif sym === :u + Main.@infiltrate + setfield!(dc, :partials_u, partial_vals(val)) end end # "Forwards" getproperty to LinearCache if necessary function Base.getproperty(dc::DualLinearCache, sym::Symbol) - if hasfield(LinearSolve.LinearCache, sym) + if sym === :A + dc.dual_A + elseif sym === :b + dc.dual_b + elseif sym === :u + dc.dual_u + elseif hasfield(LinearSolve.LinearCache, sym) return getproperty(dc.linear_cache, sym) else return getfield(dc, sym) From b0d53ae4fa11d1b50bbea3e4c3e545346e93109a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 14:49:17 -0400 Subject: [PATCH 03/32] make sure cache u is updated --- ext/LinearSolveForwardDiffExt.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index e2021cbb1..310f1958c 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -189,9 +189,12 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(DualLinearCache, sym) + setfield!(dc,sym,val) elseif hasfield(LinearSolve.LinearCache, sym) setproperty!(dc.linear_cache, sym, val) end + # Update the partials if setting A or b if sym === :A @@ -199,7 +202,6 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) elseif sym === :u - Main.@infiltrate setfield!(dc, :partials_u, partial_vals(val)) end end @@ -258,3 +260,4 @@ end end + From e98a9b9e2fdc584b24bd74fd74280941f611b81a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 14:56:09 -0400 Subject: [PATCH 04/32] no infiltrate --- ext/LinearSolveForwardDiffExt.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 310f1958c..d1ce70909 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -70,7 +70,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to cache.linear_cache.b = primal_b - Main.@infiltrate partial_sols = rhs_list @@ -174,10 +173,9 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) - Main.@infiltrate cache.dual_u = dual_sol - Main.@infiltrate + return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) From ad32c73ac99b920bdd31a102f4584553d8ae623c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 3 Jul 2025 12:26:22 -0400 Subject: [PATCH 05/32] exclude nested duals --- ext/LinearSolveForwardDiffExt.jl | 32 +++++++++++++++++++++++--------- test/forwarddiff_overloads.jl | 2 +- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index d1ce70909..181f71555 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -7,31 +7,45 @@ using ForwardDiff: Dual, Partials using SciMLBase using RecursiveArrayTools -const DualLinearProblem = LinearProblem{ + +# Define type for non-nested dual numbers +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Float64 , P} + +# Define type for nested dual numbers +const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} + +const SingleDualLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray, Nothing}, iip, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Any +} where {iip} + +const NestedDualLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, + <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, <:Any -} where {iip, T, V, P} +} where {iip} const DualALinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, <:Union{Number, <:AbstractArray}, <:Any -} where {iip, T, V, P} +} where {iip} const DualBLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, <:Union{Number, <:AbstractArray}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, <:Any -} where {iip, T, V, P} +} where {iip} const DualAbstractLinearProblem = Union{ - DualLinearProblem, DualALinearProblem, DualBLinearProblem} + SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache linear_cache diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index eb66c64dc..d73ec1746 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -79,4 +79,4 @@ cache.b = new_b x_p = solve!(cache) backslash_x_p = A \ new_b -@test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file +@test ≈(x_p, backslash_x_p, rtol = 1e-9) From 16f5f0fbb3dad00918c0e7c7ab931beac61c32d6 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 09:29:14 -0400 Subject: [PATCH 06/32] add Float32 support --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 181f71555..1a5792128 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -9,7 +9,7 @@ using RecursiveArrayTools # Define type for non-nested dual numbers -const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Float64 , P} +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Union{Float64, Float32} , P} # Define type for nested dual numbers const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} From 5dd9ce0e58cefba056fd9c55ea14898d154d67f0 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 09:40:51 -0400 Subject: [PATCH 07/32] change to AbstractFloat --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 1a5792128..7524586ef 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -9,7 +9,7 @@ using RecursiveArrayTools # Define type for non-nested dual numbers -const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Union{Float64, Float32} , P} +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:AbstractFloat , P} # Define type for nested dual numbers const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} From 43a0379d9b35f2426b85d810560081e4e1c9f0fe Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 10:51:00 -0400 Subject: [PATCH 08/32] add recursive support for nested Duals --- ext/LinearSolveForwardDiffExt.jl | 103 +++++++++++++++++++++++++++---- 1 file changed, 90 insertions(+), 13 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 7524586ef..6328971f7 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -45,7 +45,7 @@ const DualBLinearProblem = LinearProblem{ } where {iip} const DualAbstractLinearProblem = Union{ - SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem} + SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache linear_cache @@ -60,6 +60,36 @@ LinearSolve.@concrete mutable struct DualLinearCache dual_u end +# function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) +# # Solve the primal problem +# dual_u0 = copy(cache.linear_cache.u) +# sol = solve!(cache.linear_cache, alg, args...; kwargs...) +# primal_b = copy(cache.linear_cache.b) +# uu = sol.u + +# primal_sol = deepcopy(sol) + +# # Solves Dual partials separately +# ∂_A = cache.partials_A +# ∂_b = cache.partials_b + +# rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) + +# cache.linear_cache.u = dual_u0 +# # We can reuse the linear cache, because the same factorization will work for the partials. +# for i in eachindex(rhs_list) +# cache.linear_cache.b = rhs_list[i] +# rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) +# end + +# # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to +# cache.linear_cache.b = primal_b + +# partial_sols = rhs_list + +# primal_sol, partial_sols +# end + function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem dual_u0 = copy(cache.linear_cache.u) @@ -77,16 +107,17 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa cache.linear_cache.u = dual_u0 # We can reuse the linear cache, because the same factorization will work for the partials. + partial_sols = [] for i in eachindex(rhs_list) cache.linear_cache.b = rhs_list[i] - rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) + # For nested duals, the result of this solve might also be a dual number + # which will be handled recursively by the same mechanism + push!(partial_sols, copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)) end - # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to + # Reset to the original `b` and `u` cache.linear_cache.b = primal_b - partial_sols = rhs_list - primal_sol, partial_sols end @@ -136,14 +167,55 @@ function linearsolve_dual_solution( return dual_type(u, partials) end -function linearsolve_dual_solution( - u::AbstractArray, partials, dual_type) +function linearsolve_dual_solution(u::Number, partials, + dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P} + # Handle single-level duals + return dual_type(u, partials) +end + +# function linearsolve_dual_solution( +# u::AbstractArray, partials, dual_type) +# partials_list = RecursiveArrayTools.VectorOfArray(partials) +# return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), +# zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) +# end + +function linearsolve_dual_solution(u::AbstractArray, partials, + dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P} + # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) end -#= + +function linearsolve_dual_solution( + u::Number, partials, dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P} + # Handle nested duals - recursive case + # For nested duals, u itself could be a dual number with its own partials + inner_dual_type = V + outer_tag_type = T + + # Reconstruct the nested dual by first building the inner dual, then the outer one + inner_dual = u # u is already a dual for the inner level + + # Create outer dual with the inner dual as its value + return Dual{outer_tag_type, typeof(inner_dual), P}(inner_dual, partials) +end + +function linearsolve_dual_solution(u::AbstractArray, partials, + dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P} + # Handle nested duals for arrays - recursive case + inner_dual_type = V + outer_tag_type = T + + partials_list = RecursiveArrayTools.VectorOfArray(partials) + + # For nested duals, each element of u could be a dual number with its own partials + return map(((uᵢ, pᵢ),) -> Dual{outer_tag_type, typeof(uᵢ), P}(uᵢ, Partials(Tuple(pᵢ))), + zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) +end + function SciMLBase.init( prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; @@ -235,18 +307,23 @@ end -# Helper functions for Dual numbers -get_dual_type(x::Dual) = typeof(x) +# Enhanced helper functions for Dual numbers to handle recursion +get_dual_type(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = typeof(x) +get_dual_type(x::Dual{T, V, P}) where {T, V <: Dual, P} = typeof(x) get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) get_dual_type(x) = nothing -partial_vals(x::Dual) = ForwardDiff.partials(x) +# Add recursive handling for nested dual partials +partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.partials(x) +partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x) partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) partial_vals(x) = nothing +# Add recursive handling for nested dual values nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) +nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact +nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x) function partials_to_list(partial_matrix::Vector) From f6db1eefd513f3cf1e049b8a393dd48c62c625be Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 14:15:57 -0400 Subject: [PATCH 09/32] fix support for nested Duals --- ext/LinearSolveForwardDiffExt.jl | 35 ++++++++++++++++++++++++-------- src/default.jl | 5 ++++- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 6328971f7..bb74a12b7 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,6 +1,7 @@ module LinearSolveForwardDiffExt using LinearSolve +using LinearSolve: SciMLLinearSolveAlgorithm using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials @@ -92,7 +93,9 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem + #Main.@infiltrate dual_u0 = copy(cache.linear_cache.u) + #Main.@infiltrate sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b) uu = sol.u @@ -104,6 +107,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa ∂_b = cache.partials_b rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) + #Main.@infiltrate cache.linear_cache.u = dual_u0 # We can reuse the linear cache, because the same factorization will work for the partials. @@ -152,8 +156,16 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) end function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; - assump = OperatorAssumptions(issquare(prob.A)), kwargs...) - return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) + assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...) + # Extract primal values + primal_A = nodual_value(prob.A) + primal_b = nodual_value(prob.b) + + # Use the default algorithm selection based on primal values + default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump) + + # Solve with the selected algorithm + return solve(prob, default_alg, args...; kwargs...) end function SciMLBase.solve(prob::DualAbstractLinearProblem, @@ -226,10 +238,10 @@ function SciMLBase.init( verbose::Bool = false, Pl = nothing, Pr = nothing, - assumptions = OperatorAssumptions(issquare(prob.A)), + assumptions = nothing, sensealg = LinearSolveAdjoint(), kwargs...) - + @info "here!" (; A, b, u0, p) = prob new_A = nodual_value(A) new_b = nodual_value(b) @@ -240,12 +252,14 @@ function SciMLBase.init( primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) + assumptions = OperatorAssumptions(issquare(primal_prob.A)) + if get_dual_type(prob.A) !== nothing dual_type = get_dual_type(prob.A) elseif get_dual_type(prob.b) !== nothing dual_type = get_dual_type(prob.b) end - + #Main.@infiltrate non_partial_cache = init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, @@ -254,10 +268,15 @@ function SciMLBase.init( end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) + solve!(cache, cache.alg, args...; kwargs...) +end + +function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) + #Main.@infiltrate sol, partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) - + #Main.@infiltrate dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) cache.dual_u = dual_sol @@ -334,9 +353,9 @@ end function partials_to_list(partial_matrix) p = length(first(partial_matrix)) m, n = size(partial_matrix) - res_list = fill(zeros(m, n), p) + res_list = fill(zeros(typeof(partial_matrix[1, 1][1]), m, n), p) for k in 1:p - res = zeros(m, n) + res = zeros(typeof(partial_matrix[1, 1][1]), m, n) for i in 1:m for j in 1:n res[i, j] = partial_matrix[i, j][k] diff --git a/src/default.jl b/src/default.jl index 5051d000e..24d621a49 100644 --- a/src/default.jl +++ b/src/default.jl @@ -362,7 +362,10 @@ end DefaultAlgorithmChoice.AppleAccelerateLUFactorization, DefaultAlgorithmChoice.GenericLUFactorization)) newex = quote - sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) + @info $(algchoice_to_alg(alg)) + alg = $(algchoice_to_alg(alg)) + #Main.@infiltrate + sol = SciMLBase.solve!(cache, alg, args...; kwargs...) if sol.retcode === ReturnCode.Failure && alg.safetyfallback ## TODO: Add verbosity logging here about using the fallback sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...) From c10a8bfb94374d7af8ab0ec92910f4a4c59e9c0d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 15:20:48 -0400 Subject: [PATCH 10/32] clean up --- ext/LinearSolveForwardDiffExt.jl | 1 - src/default.jl | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index bb74a12b7..c46699c91 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -241,7 +241,6 @@ function SciMLBase.init( assumptions = nothing, sensealg = LinearSolveAdjoint(), kwargs...) - @info "here!" (; A, b, u0, p) = prob new_A = nodual_value(A) new_b = nodual_value(b) diff --git a/src/default.jl b/src/default.jl index 24d621a49..5051d000e 100644 --- a/src/default.jl +++ b/src/default.jl @@ -362,10 +362,7 @@ end DefaultAlgorithmChoice.AppleAccelerateLUFactorization, DefaultAlgorithmChoice.GenericLUFactorization)) newex = quote - @info $(algchoice_to_alg(alg)) - alg = $(algchoice_to_alg(alg)) - #Main.@infiltrate - sol = SciMLBase.solve!(cache, alg, args...; kwargs...) + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) if sol.retcode === ReturnCode.Failure && alg.safetyfallback ## TODO: Add verbosity logging here about using the fallback sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...) From 7abb243f9d04d7d08f25daf4b46fa9938d7b954b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 16:13:07 -0400 Subject: [PATCH 11/32] more clean up --- ext/LinearSolveForwardDiffExt.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index c46699c91..5b04d65d9 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -93,9 +93,7 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem - #Main.@infiltrate dual_u0 = copy(cache.linear_cache.u) - #Main.@infiltrate sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b) uu = sol.u @@ -107,7 +105,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa ∂_b = cache.partials_b rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - #Main.@infiltrate cache.linear_cache.u = dual_u0 # We can reuse the linear cache, because the same factorization will work for the partials. @@ -185,13 +182,6 @@ function linearsolve_dual_solution(u::Number, partials, return dual_type(u, partials) end -# function linearsolve_dual_solution( -# u::AbstractArray, partials, dual_type) -# partials_list = RecursiveArrayTools.VectorOfArray(partials) -# return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), -# zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) -# end - function linearsolve_dual_solution(u::AbstractArray, partials, dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P} # Handle single-level duals for arrays @@ -258,7 +248,6 @@ function SciMLBase.init( elseif get_dual_type(prob.b) !== nothing dual_type = get_dual_type(prob.b) end - #Main.@infiltrate non_partial_cache = init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, @@ -271,11 +260,9 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) end function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) - #Main.@infiltrate sol, partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) - #Main.@infiltrate dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) cache.dual_u = dual_sol From 0717289a3b9da204f71d77e39dc9cba198c03c6d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 16:17:04 -0400 Subject: [PATCH 12/32] allow any number --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 5b04d65d9..77131b863 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -10,7 +10,7 @@ using RecursiveArrayTools # Define type for non-nested dual numbers -const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:AbstractFloat , P} +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P} # Define type for nested dual numbers const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} From f8299265819fb8dca159da6fb5bd12cf4df4d4ef Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 17:05:16 -0400 Subject: [PATCH 13/32] reuse list --- ext/LinearSolveForwardDiffExt.jl | 68 +++----------------------------- 1 file changed, 6 insertions(+), 62 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 77131b863..bcf699b52 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -61,36 +61,6 @@ LinearSolve.@concrete mutable struct DualLinearCache dual_u end -# function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) -# # Solve the primal problem -# dual_u0 = copy(cache.linear_cache.u) -# sol = solve!(cache.linear_cache, alg, args...; kwargs...) -# primal_b = copy(cache.linear_cache.b) -# uu = sol.u - -# primal_sol = deepcopy(sol) - -# # Solves Dual partials separately -# ∂_A = cache.partials_A -# ∂_b = cache.partials_b - -# rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - -# cache.linear_cache.u = dual_u0 -# # We can reuse the linear cache, because the same factorization will work for the partials. -# for i in eachindex(rhs_list) -# cache.linear_cache.b = rhs_list[i] -# rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) -# end - -# # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to -# cache.linear_cache.b = primal_b - -# partial_sols = rhs_list - -# primal_sol, partial_sols -# end - function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem dual_u0 = copy(cache.linear_cache.u) @@ -108,17 +78,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa cache.linear_cache.u = dual_u0 # We can reuse the linear cache, because the same factorization will work for the partials. - partial_sols = [] for i in eachindex(rhs_list) cache.linear_cache.b = rhs_list[i] - # For nested duals, the result of this solve might also be a dual number - # which will be handled recursively by the same mechanism - push!(partial_sols, copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)) + rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) end - # Reset to the original `b` and `u` + # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to cache.linear_cache.b = primal_b + partial_sols = rhs_list + primal_sol, partial_sols end @@ -147,30 +116,6 @@ function xp_linsolve_rhs( b_list end -#= -function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) - return solve(prob, nothing, args...; kwargs...) -end - -function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; - assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...) - # Extract primal values - primal_A = nodual_value(prob.A) - primal_b = nodual_value(prob.b) - - # Use the default algorithm selection based on primal values - default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump) - - # Solve with the selected algorithm - return solve(prob, default_alg, args...; kwargs...) -end - -function SciMLBase.solve(prob::DualAbstractLinearProblem, - alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) - solve!(init(prob, alg, args...; kwargs...)) -end -=# - function linearsolve_dual_solution( u::Number, partials, dual_type) return dual_type(u, partials) @@ -252,7 +197,7 @@ function SciMLBase.init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zero.(b)) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b))) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -264,14 +209,13 @@ function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) - + Main.@infiltrate cache.dual_u = dual_sol return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) end -=# # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) From 00431ad2c3a2468f1418380ee4e73258791cf84b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 17:06:48 -0400 Subject: [PATCH 14/32] add tests for nested duals --- ext/LinearSolveForwardDiffExt.jl | 2 +- test/forwarddiff_overloads.jl | 57 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index bcf699b52..cee715cf1 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -13,7 +13,7 @@ using RecursiveArrayTools const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P} # Define type for nested dual numbers -const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} +const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P} const SingleDualLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index d73ec1746..6730ceb76 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -80,3 +80,60 @@ x_p = solve!(cache) backslash_x_p = A \ new_b @test ≈(x_p, backslash_x_p, rtol = 1e-9) + +# Nested Duals +function h(p) + (A = [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]], + b = [p[1] + 1, p[2] * 2, p[1]^2]) +end + +A, b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0), + ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)]) + +prob = LinearProblem(A, b) +overload_x_p = solve(prob) + +original_x_p = A \ b + +≈(overload_x_p, original_x_p, rtol = 1e-9) + +function linprob_f(p) + A, b = h(p) + prob = LinearProblem(A, b) + solve(prob) +end + +function slash_f(p) + A, b = h(p) + A \ b +end + +≈(ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0])) + +≈(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]), + ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0])) + +function g(p) + (A = [p[1] p[1]+1 p[1]^3; + 3*p[1] p[1]+5 p[1] * p[1]-4; + p[1]^2 9*p[1] p[1]], + b = [p[1] + 1, p[1] * 2, p[1]^2]) +end + +function slash_f_hes(p) + A, b = g(p) + x = A \ b + sum(x) +end + +function linprob_f_hes(p) + A, b = g(p) + prob = LinearProblem(A, b) + x = solve(prob) + sum(x) +end + +≈(ForwardDiff.hessian(slash_f_hes, [5.0]), + ForwardDiff.hessian(linprob_f_hes, [5.0])) From 6ac7bf1a8ee427829b63c08f5539ba2f706e710f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 17:44:52 -0400 Subject: [PATCH 15/32] no infiltrator --- ext/LinearSolveForwardDiffExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index cee715cf1..0e9c1925c 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -209,7 +209,6 @@ function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) - Main.@infiltrate cache.dual_u = dual_sol return SciMLBase.build_linear_solution( From 8bae0f786d72cf6b6eeae16819e9c741ac12177b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 19:10:04 -0400 Subject: [PATCH 16/32] proper RAT indexing --- ext/LinearSolveForwardDiffExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 0e9c1925c..b52713422 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -10,7 +10,7 @@ using RecursiveArrayTools # Define type for non-nested dual numbers -const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P} +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:AbstractFloat , P} # Define type for nested dual numbers const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P} @@ -46,7 +46,7 @@ const DualBLinearProblem = LinearProblem{ } where {iip} const DualAbstractLinearProblem = Union{ - SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem} + SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}#, NestedDualLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache linear_cache @@ -132,7 +132,7 @@ function linearsolve_dual_solution(u::AbstractArray, partials, # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), - zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) + zip(u, partials_list.u[i, :] for i in 1:length(partials_list.u[1]))) end From 8e8b47e6bd3c534777f32eb37f31062bd1291652 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 19:14:14 -0400 Subject: [PATCH 17/32] all numbers --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index b52713422..9625a6fee 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -10,7 +10,7 @@ using RecursiveArrayTools # Define type for non-nested dual numbers -const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:AbstractFloat , P} +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P} # Define type for nested dual numbers const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P} From 308bc76652ff3e590fd8f0a51084db9d2d4e7ff3 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 14 Jul 2025 19:30:01 -0400 Subject: [PATCH 18/32] correct RAT index --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 9625a6fee..7579625ce 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -132,7 +132,7 @@ function linearsolve_dual_solution(u::AbstractArray, partials, # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), - zip(u, partials_list.u[i, :] for i in 1:length(partials_list.u[1]))) + zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end From d9ed7265a0904dd39434696ba1b6c69a9123d124 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 15 Jul 2025 14:01:32 -0400 Subject: [PATCH 19/32] get rid of unecessary things --- ext/LinearSolveForwardDiffExt.jl | 65 ++++++-------------------------- 1 file changed, 12 insertions(+), 53 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 7579625ce..4397b7df2 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -8,45 +8,31 @@ using ForwardDiff: Dual, Partials using SciMLBase using RecursiveArrayTools - -# Define type for non-nested dual numbers -const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P} - -# Define type for nested dual numbers -const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P} - -const SingleDualLinearProblem = LinearProblem{ +const DualLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, - <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, <:Any -} where {iip} - -const NestedDualLinearProblem = LinearProblem{ - <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, - <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, - <:Any -} where {iip} +} where {iip, T, V, P} const DualALinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, <:Union{Number, <:AbstractArray}, <:Any -} where {iip} +} where {iip, T, V, P} const DualBLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, <:Union{Number, <:AbstractArray}, - <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, <:Any -} where {iip} +} where {iip, T, V, P} const DualAbstractLinearProblem = Union{ - SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}#, NestedDualLinearProblem} + DualLinearProblem, DualALinearProblem, DualBLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache linear_cache @@ -63,6 +49,7 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem + @info "here" dual_u0 = copy(cache.linear_cache.u) sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b) @@ -122,47 +109,19 @@ function linearsolve_dual_solution( end function linearsolve_dual_solution(u::Number, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P} + dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} # Handle single-level duals return dual_type(u, partials) end function linearsolve_dual_solution(u::AbstractArray, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P} + dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end - -function linearsolve_dual_solution( - u::Number, partials, dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P} - # Handle nested duals - recursive case - # For nested duals, u itself could be a dual number with its own partials - inner_dual_type = V - outer_tag_type = T - - # Reconstruct the nested dual by first building the inner dual, then the outer one - inner_dual = u # u is already a dual for the inner level - - # Create outer dual with the inner dual as its value - return Dual{outer_tag_type, typeof(inner_dual), P}(inner_dual, partials) -end - -function linearsolve_dual_solution(u::AbstractArray, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P} - # Handle nested duals for arrays - recursive case - inner_dual_type = V - outer_tag_type = T - - partials_list = RecursiveArrayTools.VectorOfArray(partials) - - # For nested duals, each element of u could be a dual number with its own partials - return map(((uᵢ, pᵢ),) -> Dual{outer_tag_type, typeof(uᵢ), P}(uᵢ, Partials(Tuple(pᵢ))), - zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) -end - function SciMLBase.init( prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; From fdc9774d6eec8d2915935786be36c4005d904b15 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 15 Jul 2025 14:04:07 -0400 Subject: [PATCH 20/32] get rid of log --- ext/LinearSolveForwardDiffExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 4397b7df2..82d72debf 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -49,7 +49,6 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem - @info "here" dual_u0 = copy(cache.linear_cache.u) sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b) From 9c11cc28f15e731e7bfc6651834ccccc5f190cf7 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 10:10:17 -0400 Subject: [PATCH 21/32] add more tests --- test/forwarddiff_overloads.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 6730ceb76..2ff12bf79 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -61,6 +61,7 @@ cache = init(prob) new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.A = new_A +@test cache.A = new_A x_p = solve!(cache) backslash_x_p = new_A \ b @@ -75,6 +76,7 @@ cache = init(prob) _, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.b = new_b +@test cache.b == new_b x_p = solve!(cache) backslash_x_p = A \ new_b @@ -97,7 +99,7 @@ overload_x_p = solve(prob) original_x_p = A \ b -≈(overload_x_p, original_x_p, rtol = 1e-9) +@test ≈(overload_x_p, original_x_p, rtol = 1e-9) function linprob_f(p) A, b = h(p) @@ -110,9 +112,9 @@ function slash_f(p) A \ b end -≈(ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0])) +@test ≈(ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0])) -≈(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]), +@test ≈(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]), ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0])) function g(p) @@ -135,5 +137,5 @@ function linprob_f_hes(p) sum(x) end -≈(ForwardDiff.hessian(slash_f_hes, [5.0]), +@test ≈(ForwardDiff.hessian(slash_f_hes, [5.0]), ForwardDiff.hessian(linprob_f_hes, [5.0])) From 96244fdfd7edf432d2764930a9cde33d83191827 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 10:10:23 -0400 Subject: [PATCH 22/32] streamline --- ext/LinearSolveForwardDiffExt.jl | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 82d72debf..bd11d1553 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -176,23 +176,22 @@ end # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) - # If the property is A or b, also update it in the LinearCache - if sym === :A || sym === :b || sym === :u - setproperty!(dc.linear_cache, sym, nodual_value(val)) - elseif hasfield(DualLinearCache, sym) - setfield!(dc,sym,val) - elseif hasfield(LinearSolve.LinearCache, sym) - setproperty!(dc.linear_cache, sym, val) - end - - - # Update the partials if setting A or b if sym === :A + setproperty!(dc.linear_cache, sym, nodual_value(val)) + setfield!(dc, :dual_A, val) setfield!(dc, :partials_A, partial_vals(val)) elseif sym === :b + setproperty!(dc.linear_cache, sym, nodual_value(val)) + setfield!(dc, :dual_b, val) setfield!(dc, :partials_b, partial_vals(val)) elseif sym === :u + setproperty!(dc.linear_cache, sym, nodual_value(val)) + setfield!(dc, :dual_u, val) setfield!(dc, :partials_u, partial_vals(val)) + elseif hasfield(DualLinearCache, sym) + setfield!(dc,sym,val) + elseif hasfield(LinearSolve.LinearCache, sym) + setproperty!(dc.linear_cache, sym, val) end end From 762bfb18784e31f00284b8322b2827ec0a550717 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 10:30:59 -0400 Subject: [PATCH 23/32] make sure u is aliased --- ext/LinearSolveForwardDiffExt.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index bd11d1553..3c90f2ad4 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -167,7 +167,12 @@ function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) - cache.dual_u = dual_sol + + if cache.dual_u isa AbstractArray + cache.dual_u[:] = dual_sol + else + cache.dual_u = dual_sol + end return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats From 248718fcb4842ab78709d98d243ad9ebd18a1087 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 11:28:41 -0400 Subject: [PATCH 24/32] add tests for sparse arrays and sparse solvers --- test/forwarddiff_overloads.jl | 43 ++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 2ff12bf79..12988b67b 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -1,6 +1,7 @@ using LinearSolve using ForwardDiff using Test +using SparseArrays function h(p) (A = [p[1] p[2]+1 p[2]^3; @@ -48,6 +49,9 @@ new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1. cache.A = new_A cache.b = new_b +@test cache.A == new_A +@test cache.b == new_b + x_p = solve!(cache) backslash_x_p = new_A \ new_b @@ -61,7 +65,7 @@ cache = init(prob) new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.A = new_A -@test cache.A = new_A +@test cache.A == new_A x_p = solve!(cache) backslash_x_p = new_A \ b @@ -139,3 +143,40 @@ end @test ≈(ForwardDiff.hessian(slash_f_hes, [5.0]), ForwardDiff.hessian(linprob_f_hes, [5.0])) + + +# Test aliasing + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.A = new_A +cache.b = new_b + +linu = [ForwardDiff.Dual(0.0, 0.0, 0.0), ForwardDiff.Dual(0.0, 0.0, 0.0), + ForwardDiff.Dual(0.0, 0.0, 0.0)] +cache.u = linu +x_p = solve!(cache) +backslash_x_p = new_A \ new_b + +@test linu == cache.u + + +# Test Float Only solvers + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(sparse(A), sparse(b)) +overload_x_p = solve(prob, KLUFactorization()) +backslash_x_p = A \ b + +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +overload_x_p = solve(prob, UMFPACKFactorization()) +backslash_x_p = A \ b + +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file From 5497adc4ce3120964640fd5e226fe2d753a69a7a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 11:29:09 -0400 Subject: [PATCH 25/32] make sure AbstractArrays of size 1 are accounted for --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 3c90f2ad4..349e76914 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -236,7 +236,7 @@ nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inn nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x) -function partials_to_list(partial_matrix::Vector) +function partials_to_list(partial_matrix::AbstractArray{T, 1}) where {T} p = eachindex(first(partial_matrix)) [[partial[i] for partial in partial_matrix] for i in p] end From e842eac70e8d1413c0c5449a9091d8d574f68503 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 11:43:47 -0400 Subject: [PATCH 26/32] use AbstractVector, fix setproperty! --- ext/LinearSolveForwardDiffExt.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 349e76914..a44ca4173 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -181,22 +181,26 @@ end # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) - if sym === :A + # If the property is A or b, also update it in the LinearCache + if sym === :A || sym === :b || sym === :u setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(DualLinearCache, sym) + setfield!(dc, sym, val) + elseif hasfield(LinearSolve.LinearCache, sym) + setproperty!(dc.linear_cache, sym, val) + end + + + # Update the partials if setting A or b + if sym === :A setfield!(dc, :dual_A, val) setfield!(dc, :partials_A, partial_vals(val)) elseif sym === :b - setproperty!(dc.linear_cache, sym, nodual_value(val)) setfield!(dc, :dual_b, val) setfield!(dc, :partials_b, partial_vals(val)) elseif sym === :u - setproperty!(dc.linear_cache, sym, nodual_value(val)) setfield!(dc, :dual_u, val) setfield!(dc, :partials_u, partial_vals(val)) - elseif hasfield(DualLinearCache, sym) - setfield!(dc,sym,val) - elseif hasfield(LinearSolve.LinearCache, sym) - setproperty!(dc.linear_cache, sym, val) end end @@ -236,7 +240,7 @@ nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inn nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x) -function partials_to_list(partial_matrix::AbstractArray{T, 1}) where {T} +function partials_to_list(partial_matrix::AbstractVector{T}) where {T} p = eachindex(first(partial_matrix)) [[partial[i] for partial in partial_matrix] for i in p] end From 1b1697ac3a58123aa0a7ebf23570a538aa033693 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 12:09:37 -0400 Subject: [PATCH 27/32] add test for setting nested Duals --- test/forwarddiff_overloads.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 12988b67b..3b44de8ed 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -105,6 +105,19 @@ original_x_p = A \ b @test ≈(overload_x_p, original_x_p, rtol = 1e-9) +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0), + ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 0.0, 1.0)]) + +cache.A = new_A +cache.b = new_b + +@test cache.A == new_A +@test cache.b == new_b + + function linprob_f(p) A, b = h(p) prob = LinearProblem(A, b) From 2e003eec34e253f4681c6f9e4194379bc4dd0b6f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 12:52:35 -0400 Subject: [PATCH 28/32] fix test --- test/forwarddiff_overloads.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 3b44de8ed..20672190f 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -24,12 +24,11 @@ krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES()) @test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9) - A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) backslash_x_p = A \ [6.0, 10.0, 25.0] prob = LinearProblem(A, [6.0, 10.0, 25.0]) -@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) +@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) @test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) _, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) @@ -117,7 +116,6 @@ cache.b = new_b @test cache.A == new_A @test cache.b == new_b - function linprob_f(p) A, b = h(p) prob = LinearProblem(A, b) @@ -129,7 +127,8 @@ function slash_f(p) A \ b end -@test ≈(ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0])) +@test ≈( + ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0])) @test ≈(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]), ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0])) @@ -157,8 +156,8 @@ end @test ≈(ForwardDiff.hessian(slash_f_hes, [5.0]), ForwardDiff.hessian(linprob_f_hes, [5.0])) - # Test aliasing +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) cache = init(prob) @@ -175,7 +174,6 @@ backslash_x_p = new_A \ new_b @test linu == cache.u - # Test Float Only solvers A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) From 47d4d6875face1685c918d83365dfbc8fc1e7d26 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 13:06:11 -0400 Subject: [PATCH 29/32] make sparse --- test/forwarddiff_overloads.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 20672190f..a53a4cf0e 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -186,7 +186,7 @@ backslash_x_p = A \ b A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) -prob = LinearProblem(A, b) +prob = LinearProblem(sparse(A), sparse(b)) overload_x_p = solve(prob, UMFPACKFactorization()) backslash_x_p = A \ b From 6323432426515b9e451e360f10cf9210fc603f2c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 13:54:04 -0400 Subject: [PATCH 30/32] fix default solver --- ext/LinearSolveForwardDiffExt.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index a44ca4173..b1f2f274f 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -49,6 +49,7 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem + @info "here" dual_u0 = copy(cache.linear_cache.u) sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b) @@ -131,7 +132,7 @@ function SciMLBase.init( verbose::Bool = false, Pl = nothing, Pr = nothing, - assumptions = nothing, + assumptions = OperatorAssumptions(issquare(prob.A)), sensealg = LinearSolveAdjoint(), kwargs...) (; A, b, u0, p) = prob @@ -144,15 +145,15 @@ function SciMLBase.init( primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) - assumptions = OperatorAssumptions(issquare(primal_prob.A)) - if get_dual_type(prob.A) !== nothing dual_type = get_dual_type(prob.A) elseif get_dual_type(prob.b) !== nothing dual_type = get_dual_type(prob.b) end + Main.@infiltrate non_partial_cache = init( - primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, + primal_prob, LinearSolve.defaultalg(primal_prob.A, primal_prob.b, assumptions), args...; + alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b))) From b1ffa15a2c6fdbe0ca3a2ad20754f7dd5507eb1c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 16 Jul 2025 14:04:06 -0400 Subject: [PATCH 31/32] allow setting alg --- ext/LinearSolveForwardDiffExt.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index b1f2f274f..bd9d9ac85 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -150,9 +150,11 @@ function SciMLBase.init( elseif get_dual_type(prob.b) !== nothing dual_type = get_dual_type(prob.b) end - Main.@infiltrate + + alg isa LinearSolve.DefaultLinearSolver ? real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg + non_partial_cache = init( - primal_prob, LinearSolve.defaultalg(primal_prob.A, primal_prob.b, assumptions), args...; + primal_prob, real_alg, assumptions, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) From 62127b87943f328a92d09906bff4b2ab1af599f2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 17 Jul 2025 18:13:02 -0400 Subject: [PATCH 32/32] Update ext/LinearSolveForwardDiffExt.jl --- ext/LinearSolveForwardDiffExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index bd9d9ac85..d133d75c0 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -49,7 +49,6 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) # Solve the primal problem - @info "here" dual_u0 = copy(cache.linear_cache.u) sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b)