@@ -45,7 +45,7 @@ const DualBLinearProblem = LinearProblem{
45
45
} where {iip}
46
46
47
47
const DualAbstractLinearProblem = Union{
48
- SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}
48
+ SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem }
49
49
50
50
LinearSolve. @concrete mutable struct DualLinearCache
51
51
linear_cache
@@ -60,6 +60,36 @@ LinearSolve.@concrete mutable struct DualLinearCache
60
60
dual_u
61
61
end
62
62
63
+ # function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
64
+ # # Solve the primal problem
65
+ # dual_u0 = copy(cache.linear_cache.u)
66
+ # sol = solve!(cache.linear_cache, alg, args...; kwargs...)
67
+ # primal_b = copy(cache.linear_cache.b)
68
+ # uu = sol.u
69
+
70
+ # primal_sol = deepcopy(sol)
71
+
72
+ # # Solves Dual partials separately
73
+ # ∂_A = cache.partials_A
74
+ # ∂_b = cache.partials_b
75
+
76
+ # rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
77
+
78
+ # cache.linear_cache.u = dual_u0
79
+ # # We can reuse the linear cache, because the same factorization will work for the partials.
80
+ # for i in eachindex(rhs_list)
81
+ # cache.linear_cache.b = rhs_list[i]
82
+ # rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
83
+ # end
84
+
85
+ # # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
86
+ # cache.linear_cache.b = primal_b
87
+
88
+ # partial_sols = rhs_list
89
+
90
+ # primal_sol, partial_sols
91
+ # end
92
+
63
93
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
64
94
# Solve the primal problem
65
95
dual_u0 = copy (cache. linear_cache. u)
@@ -77,16 +107,17 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
77
107
78
108
cache. linear_cache. u = dual_u0
79
109
# We can reuse the linear cache, because the same factorization will work for the partials.
110
+ partial_sols = []
80
111
for i in eachindex (rhs_list)
81
112
cache. linear_cache. b = rhs_list[i]
82
- rhs_list[i] = copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u)
113
+ # For nested duals, the result of this solve might also be a dual number
114
+ # which will be handled recursively by the same mechanism
115
+ push! (partial_sols, copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u))
83
116
end
84
117
85
- # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
118
+ # Reset to the original `b` and `u`
86
119
cache. linear_cache. b = primal_b
87
120
88
- partial_sols = rhs_list
89
-
90
121
primal_sol, partial_sols
91
122
end
92
123
@@ -136,14 +167,55 @@ function linearsolve_dual_solution(
136
167
return dual_type (u, partials)
137
168
end
138
169
139
- function linearsolve_dual_solution (
140
- u:: AbstractArray , partials, dual_type)
170
+ function linearsolve_dual_solution (u:: Number , partials,
171
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
172
+ # Handle single-level duals
173
+ return dual_type (u, partials)
174
+ end
175
+
176
+ # function linearsolve_dual_solution(
177
+ # u::AbstractArray, partials, dual_type)
178
+ # partials_list = RecursiveArrayTools.VectorOfArray(partials)
179
+ # return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
180
+ # zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
181
+ # end
182
+
183
+ function linearsolve_dual_solution (u:: AbstractArray , partials,
184
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
185
+ # Handle single-level duals for arrays
141
186
partials_list = RecursiveArrayTools. VectorOfArray (partials)
142
187
return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
143
188
zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
144
189
end
145
190
146
- #=
191
+
192
+ function linearsolve_dual_solution (
193
+ u:: Number , partials, dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
194
+ # Handle nested duals - recursive case
195
+ # For nested duals, u itself could be a dual number with its own partials
196
+ inner_dual_type = V
197
+ outer_tag_type = T
198
+
199
+ # Reconstruct the nested dual by first building the inner dual, then the outer one
200
+ inner_dual = u # u is already a dual for the inner level
201
+
202
+ # Create outer dual with the inner dual as its value
203
+ return Dual {outer_tag_type, typeof(inner_dual), P} (inner_dual, partials)
204
+ end
205
+
206
+ function linearsolve_dual_solution (u:: AbstractArray , partials,
207
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
208
+ # Handle nested duals for arrays - recursive case
209
+ inner_dual_type = V
210
+ outer_tag_type = T
211
+
212
+ partials_list = RecursiveArrayTools. VectorOfArray (partials)
213
+
214
+ # For nested duals, each element of u could be a dual number with its own partials
215
+ return map (((uᵢ, pᵢ),) -> Dual {outer_tag_type, typeof(uᵢ), P} (uᵢ, Partials (Tuple (pᵢ))),
216
+ zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
217
+ end
218
+
147
219
function SciMLBase. init (
148
220
prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm ,
149
221
args... ;
@@ -235,18 +307,23 @@ end
235
307
236
308
237
309
238
- # Helper functions for Dual numbers
239
- get_dual_type (x:: Dual ) = typeof (x)
310
+ # Enhanced helper functions for Dual numbers to handle recursion
311
+ get_dual_type (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = typeof (x)
312
+ get_dual_type (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = typeof (x)
240
313
get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
241
314
get_dual_type (x) = nothing
242
315
243
- partial_vals (x:: Dual ) = ForwardDiff. partials (x)
316
+ # Add recursive handling for nested dual partials
317
+ partial_vals (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. partials (x)
318
+ partial_vals (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = ForwardDiff. partials (x)
244
319
partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
245
320
partial_vals (x) = nothing
246
321
322
+ # Add recursive handling for nested dual values
247
323
nodual_value (x) = x
248
- nodual_value (x:: Dual ) = ForwardDiff. value (x)
249
- nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
324
+ nodual_value (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. value (x)
325
+ nodual_value (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = x. value # Keep the inner dual intact
326
+ nodual_value (x:: AbstractArray{<:Dual} ) = map (nodual_value, x)
250
327
251
328
252
329
function partials_to_list (partial_matrix:: Vector )
0 commit comments