Skip to content

Commit cf0eccd

Browse files
committed
fix support for nested Duals
1 parent 2d7700d commit cf0eccd

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4+
using LinearSolve: SciMLLinearSolveAlgorithm
45
using LinearAlgebra
56
using ForwardDiff
67
using ForwardDiff: Dual, Partials
@@ -92,7 +93,9 @@ end
9293

9394
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
9495
# Solve the primal problem
96+
#Main.@infiltrate
9597
dual_u0 = copy(cache.linear_cache.u)
98+
#Main.@infiltrate
9699
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
97100
primal_b = copy(cache.linear_cache.b)
98101
uu = sol.u
@@ -104,6 +107,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
104107
∂_b = cache.partials_b
105108

106109
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
110+
#Main.@infiltrate
107111

108112
cache.linear_cache.u = dual_u0
109113
# We can reuse the linear cache, because the same factorization will work for the partials.
@@ -151,8 +155,16 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
151155
end
152156

153157
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
154-
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
155-
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
158+
assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
159+
# Extract primal values
160+
primal_A = nodual_value(prob.A)
161+
primal_b = nodual_value(prob.b)
162+
163+
# Use the default algorithm selection based on primal values
164+
default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
165+
166+
# Solve with the selected algorithm
167+
return solve(prob, default_alg, args...; kwargs...)
156168
end
157169

158170
function SciMLBase.solve(prob::DualAbstractLinearProblem,
@@ -224,10 +236,10 @@ function SciMLBase.init(
224236
verbose::Bool = false,
225237
Pl = nothing,
226238
Pr = nothing,
227-
assumptions = OperatorAssumptions(issquare(prob.A)),
239+
assumptions = nothing,
228240
sensealg = LinearSolveAdjoint(),
229241
kwargs...)
230-
242+
@info "here!"
231243
(; A, b, u0, p) = prob
232244
new_A = nodual_value(A)
233245
new_b = nodual_value(b)
@@ -238,12 +250,14 @@ function SciMLBase.init(
238250

239251
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)
240252

253+
assumptions = OperatorAssumptions(issquare(primal_prob.A))
254+
241255
if get_dual_type(prob.A) !== nothing
242256
dual_type = get_dual_type(prob.A)
243257
elseif get_dual_type(prob.b) !== nothing
244258
dual_type = get_dual_type(prob.b)
245259
end
246-
260+
#Main.@infiltrate
247261
non_partial_cache = init(
248262
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
249263
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
@@ -252,10 +266,15 @@ function SciMLBase.init(
252266
end
253267

254268
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
269+
solve!(cache, cache.alg, args...; kwargs...)
270+
end
271+
272+
function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
273+
#Main.@infiltrate
255274
sol,
256275
partials = linearsolve_forwarddiff_solve(
257276
cache::DualLinearCache, cache.alg, args...; kwargs...)
258-
277+
#Main.@infiltrate
259278
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
260279

261280
cache.dual_u = dual_sol
@@ -331,9 +350,9 @@ end
331350
function partials_to_list(partial_matrix)
332351
p = length(first(partial_matrix))
333352
m, n = size(partial_matrix)
334-
res_list = fill(zeros(m, n), p)
353+
res_list = fill(zeros(typeof(partial_matrix[1, 1][1]), m, n), p)
335354
for k in 1:p
336-
res = zeros(m, n)
355+
res = zeros(typeof(partial_matrix[1, 1][1]), m, n)
337356
for i in 1:m
338357
for j in 1:n
339358
res[i, j] = partial_matrix[i, j][k]

src/default.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,10 @@ end
362362
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
363363
DefaultAlgorithmChoice.GenericLUFactorization))
364364
newex = quote
365-
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
365+
@info $(algchoice_to_alg(alg))
366+
alg = $(algchoice_to_alg(alg))
367+
#Main.@infiltrate
368+
sol = SciMLBase.solve!(cache, alg, args...; kwargs...)
366369
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
367370
## TODO: Add verbosity logging here about using the fallback
368371
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)

0 commit comments

Comments
 (0)