Skip to content

Commit f6db1ee

Browse files
committed
fix support for nested Duals
1 parent 43a0379 commit f6db1ee

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4+
using LinearSolve: SciMLLinearSolveAlgorithm
45
using LinearAlgebra
56
using ForwardDiff
67
using ForwardDiff: Dual, Partials
@@ -92,7 +93,9 @@ end
9293

9394
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
9495
# Solve the primal problem
96+
#Main.@infiltrate
9597
dual_u0 = copy(cache.linear_cache.u)
98+
#Main.@infiltrate
9699
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
97100
primal_b = copy(cache.linear_cache.b)
98101
uu = sol.u
@@ -104,6 +107,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
104107
∂_b = cache.partials_b
105108

106109
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
110+
#Main.@infiltrate
107111

108112
cache.linear_cache.u = dual_u0
109113
# We can reuse the linear cache, because the same factorization will work for the partials.
@@ -152,8 +156,16 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
152156
end
153157
154158
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
155-
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
156-
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
159+
assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
160+
# Extract primal values
161+
primal_A = nodual_value(prob.A)
162+
primal_b = nodual_value(prob.b)
163+
164+
# Use the default algorithm selection based on primal values
165+
default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
166+
167+
# Solve with the selected algorithm
168+
return solve(prob, default_alg, args...; kwargs...)
157169
end
158170
159171
function SciMLBase.solve(prob::DualAbstractLinearProblem,
@@ -226,10 +238,10 @@ function SciMLBase.init(
226238
verbose::Bool = false,
227239
Pl = nothing,
228240
Pr = nothing,
229-
assumptions = OperatorAssumptions(issquare(prob.A)),
241+
assumptions = nothing,
230242
sensealg = LinearSolveAdjoint(),
231243
kwargs...)
232-
244+
@info "here!"
233245
(; A, b, u0, p) = prob
234246
new_A = nodual_value(A)
235247
new_b = nodual_value(b)
@@ -240,12 +252,14 @@ function SciMLBase.init(
240252

241253
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)
242254

255+
assumptions = OperatorAssumptions(issquare(primal_prob.A))
256+
243257
if get_dual_type(prob.A) !== nothing
244258
dual_type = get_dual_type(prob.A)
245259
elseif get_dual_type(prob.b) !== nothing
246260
dual_type = get_dual_type(prob.b)
247261
end
248-
262+
#Main.@infiltrate
249263
non_partial_cache = init(
250264
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
251265
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
@@ -254,10 +268,15 @@ function SciMLBase.init(
254268
end
255269

256270
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
271+
solve!(cache, cache.alg, args...; kwargs...)
272+
end
273+
274+
function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
275+
#Main.@infiltrate
257276
sol,
258277
partials = linearsolve_forwarddiff_solve(
259278
cache::DualLinearCache, cache.alg, args...; kwargs...)
260-
279+
#Main.@infiltrate
261280
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
262281

263282
cache.dual_u = dual_sol
@@ -334,9 +353,9 @@ end
334353
function partials_to_list(partial_matrix)
335354
p = length(first(partial_matrix))
336355
m, n = size(partial_matrix)
337-
res_list = fill(zeros(m, n), p)
356+
res_list = fill(zeros(typeof(partial_matrix[1, 1][1]), m, n), p)
338357
for k in 1:p
339-
res = zeros(m, n)
358+
res = zeros(typeof(partial_matrix[1, 1][1]), m, n)
340359
for i in 1:m
341360
for j in 1:n
342361
res[i, j] = partial_matrix[i, j][k]

src/default.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,10 @@ end
362362
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
363363
DefaultAlgorithmChoice.GenericLUFactorization))
364364
newex = quote
365-
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
365+
@info $(algchoice_to_alg(alg))
366+
alg = $(algchoice_to_alg(alg))
367+
#Main.@infiltrate
368+
sol = SciMLBase.solve!(cache, alg, args...; kwargs...)
366369
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
367370
## TODO: Add verbosity logging here about using the fallback
368371
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)

0 commit comments

Comments
 (0)