1
1
module LinearSolveForwardDiffExt
2
2
3
3
using LinearSolve
4
+ using LinearSolve: SciMLLinearSolveAlgorithm
4
5
using LinearAlgebra
5
6
using ForwardDiff
6
7
using ForwardDiff: Dual, Partials
92
93
93
94
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
94
95
# Solve the primal problem
96
+ # Main.@infiltrate
95
97
dual_u0 = copy (cache. linear_cache. u)
98
+ # Main.@infiltrate
96
99
sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
97
100
primal_b = copy (cache. linear_cache. b)
98
101
uu = sol. u
@@ -104,6 +107,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
104
107
∂_b = cache. partials_b
105
108
106
109
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
110
+ # Main.@infiltrate
107
111
108
112
cache. linear_cache. u = dual_u0
109
113
# We can reuse the linear cache, because the same factorization will work for the partials.
@@ -152,8 +156,16 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
152
156
end
153
157
154
158
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
155
- assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
156
- return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
159
+ assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
160
+ # Extract primal values
161
+ primal_A = nodual_value(prob.A)
162
+ primal_b = nodual_value(prob.b)
163
+
164
+ # Use the default algorithm selection based on primal values
165
+ default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
166
+
167
+ # Solve with the selected algorithm
168
+ return solve(prob, default_alg, args...; kwargs...)
157
169
end
158
170
159
171
function SciMLBase.solve(prob::DualAbstractLinearProblem,
@@ -226,10 +238,10 @@ function SciMLBase.init(
226
238
verbose:: Bool = false ,
227
239
Pl = nothing ,
228
240
Pr = nothing ,
229
- assumptions = OperatorAssumptions ( issquare (prob . A)) ,
241
+ assumptions = nothing ,
230
242
sensealg = LinearSolveAdjoint (),
231
243
kwargs... )
232
-
244
+ @info " here! "
233
245
(; A, b, u0, p) = prob
234
246
new_A = nodual_value (A)
235
247
new_b = nodual_value (b)
@@ -240,12 +252,14 @@ function SciMLBase.init(
240
252
241
253
primal_prob = remake (prob; A = new_A, b = new_b, u0 = new_u0)
242
254
255
+ assumptions = OperatorAssumptions (issquare (primal_prob. A))
256
+
243
257
if get_dual_type (prob. A) != = nothing
244
258
dual_type = get_dual_type (prob. A)
245
259
elseif get_dual_type (prob. b) != = nothing
246
260
dual_type = get_dual_type (prob. b)
247
261
end
248
-
262
+ # Main.@infiltrate
249
263
non_partial_cache = init (
250
264
primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
251
265
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
@@ -254,10 +268,15 @@ function SciMLBase.init(
254
268
end
255
269
256
270
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
271
+ solve! (cache, cache. alg, args... ; kwargs... )
272
+ end
273
+
274
+ function SciMLBase. solve! (cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
275
+ # Main.@infiltrate
257
276
sol,
258
277
partials = linearsolve_forwarddiff_solve (
259
278
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
260
-
279
+ # Main.@infiltrate
261
280
dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
262
281
263
282
cache. dual_u = dual_sol
334
353
function partials_to_list (partial_matrix)
335
354
p = length (first (partial_matrix))
336
355
m, n = size (partial_matrix)
337
- res_list = fill (zeros (m, n), p)
356
+ res_list = fill (zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n), p)
338
357
for k in 1 : p
339
- res = zeros (m, n)
358
+ res = zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n)
340
359
for i in 1 : m
341
360
for j in 1 : n
342
361
res[i, j] = partial_matrix[i, j][k]
0 commit comments