1
1
module LinearSolveForwardDiffExt
2
2
3
3
using LinearSolve
4
+ using LinearSolve: SciMLLinearSolveAlgorithm
4
5
using LinearAlgebra
5
6
using ForwardDiff
6
7
using ForwardDiff: Dual, Partials
92
93
93
94
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
94
95
# Solve the primal problem
96
+ # Main.@infiltrate
95
97
dual_u0 = copy (cache. linear_cache. u)
98
+ # Main.@infiltrate
96
99
sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
97
100
primal_b = copy (cache. linear_cache. b)
98
101
uu = sol. u
@@ -104,6 +107,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
104
107
∂_b = cache. partials_b
105
108
106
109
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
110
+ # Main.@infiltrate
107
111
108
112
cache. linear_cache. u = dual_u0
109
113
# We can reuse the linear cache, because the same factorization will work for the partials.
@@ -151,8 +155,16 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
151
155
end
152
156
153
157
function SciMLBase. solve (prob:: DualAbstractLinearProblem , :: Nothing , args... ;
154
- assump = OperatorAssumptions (issquare (prob. A)), kwargs... )
155
- return solve (prob, LinearSolve. defaultalg (prob. A, prob. b, assump), args... ; kwargs... )
158
+ assump = OperatorAssumptions (issquare (nodual_value (prob. A))), kwargs... )
159
+ # Extract primal values
160
+ primal_A = nodual_value (prob. A)
161
+ primal_b = nodual_value (prob. b)
162
+
163
+ # Use the default algorithm selection based on primal values
164
+ default_alg = LinearSolve. defaultalg (primal_A, primal_b, assump)
165
+
166
+ # Solve with the selected algorithm
167
+ return solve (prob, default_alg, args... ; kwargs... )
156
168
end
157
169
158
170
function SciMLBase. solve (prob:: DualAbstractLinearProblem ,
@@ -224,10 +236,10 @@ function SciMLBase.init(
224
236
verbose:: Bool = false ,
225
237
Pl = nothing ,
226
238
Pr = nothing ,
227
- assumptions = OperatorAssumptions ( issquare (prob . A)) ,
239
+ assumptions = nothing ,
228
240
sensealg = LinearSolveAdjoint (),
229
241
kwargs... )
230
-
242
+ @info " here! "
231
243
(; A, b, u0, p) = prob
232
244
new_A = nodual_value (A)
233
245
new_b = nodual_value (b)
@@ -238,12 +250,14 @@ function SciMLBase.init(
238
250
239
251
primal_prob = remake (prob; A = new_A, b = new_b, u0 = new_u0)
240
252
253
+ assumptions = OperatorAssumptions (issquare (primal_prob. A))
254
+
241
255
if get_dual_type (prob. A) != = nothing
242
256
dual_type = get_dual_type (prob. A)
243
257
elseif get_dual_type (prob. b) != = nothing
244
258
dual_type = get_dual_type (prob. b)
245
259
end
246
-
260
+ # Main.@infiltrate
247
261
non_partial_cache = init (
248
262
primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
249
263
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
@@ -252,10 +266,15 @@ function SciMLBase.init(
252
266
end
253
267
254
268
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
269
+ solve! (cache, cache. alg, args... ; kwargs... )
270
+ end
271
+
272
+ function SciMLBase. solve! (cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
273
+ # Main.@infiltrate
255
274
sol,
256
275
partials = linearsolve_forwarddiff_solve (
257
276
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
258
-
277
+ # Main.@infiltrate
259
278
dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
260
279
261
280
cache. dual_u = dual_sol
331
350
function partials_to_list (partial_matrix)
332
351
p = length (first (partial_matrix))
333
352
m, n = size (partial_matrix)
334
- res_list = fill (zeros (m, n), p)
353
+ res_list = fill (zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n), p)
335
354
for k in 1 : p
336
- res = zeros (m, n)
355
+ res = zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n)
337
356
for i in 1 : m
338
357
for j in 1 : n
339
358
res[i, j] = partial_matrix[i, j][k]
0 commit comments