Skip to content

Commit 43a0379

Browse files
committed
add recursive support for nested Duals
1 parent 5dd9ce0 commit 43a0379

File tree

1 file changed

+90
-13
lines changed

1 file changed

+90
-13
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ const DualBLinearProblem = LinearProblem{
4545
} where {iip}
4646

4747
const DualAbstractLinearProblem = Union{
48-
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}
48+
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem}
4949

5050
LinearSolve.@concrete mutable struct DualLinearCache
5151
linear_cache
@@ -60,6 +60,36 @@ LinearSolve.@concrete mutable struct DualLinearCache
6060
dual_u
6161
end
6262

63+
# function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
64+
# # Solve the primal problem
65+
# dual_u0 = copy(cache.linear_cache.u)
66+
# sol = solve!(cache.linear_cache, alg, args...; kwargs...)
67+
# primal_b = copy(cache.linear_cache.b)
68+
# uu = sol.u
69+
70+
# primal_sol = deepcopy(sol)
71+
72+
# # Solves Dual partials separately
73+
# ∂_A = cache.partials_A
74+
# ∂_b = cache.partials_b
75+
76+
# rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
77+
78+
# cache.linear_cache.u = dual_u0
79+
# # We can reuse the linear cache, because the same factorization will work for the partials.
80+
# for i in eachindex(rhs_list)
81+
# cache.linear_cache.b = rhs_list[i]
82+
# rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
83+
# end
84+
85+
# # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
86+
# cache.linear_cache.b = primal_b
87+
88+
# partial_sols = rhs_list
89+
90+
# primal_sol, partial_sols
91+
# end
92+
6393
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
6494
# Solve the primal problem
6595
dual_u0 = copy(cache.linear_cache.u)
@@ -77,16 +107,17 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
77107

78108
cache.linear_cache.u = dual_u0
79109
# We can reuse the linear cache, because the same factorization will work for the partials.
110+
partial_sols = []
80111
for i in eachindex(rhs_list)
81112
cache.linear_cache.b = rhs_list[i]
82-
rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
113+
# For nested duals, the result of this solve might also be a dual number
114+
# which will be handled recursively by the same mechanism
115+
push!(partial_sols, copy(solve!(cache.linear_cache, alg, args...; kwargs...).u))
83116
end
84117

85-
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
118+
# Reset to the original `b` and `u`
86119
cache.linear_cache.b = primal_b
87120

88-
partial_sols = rhs_list
89-
90121
primal_sol, partial_sols
91122
end
92123

@@ -136,14 +167,55 @@ function linearsolve_dual_solution(
136167
return dual_type(u, partials)
137168
end
138169

139-
function linearsolve_dual_solution(
140-
u::AbstractArray, partials, dual_type)
170+
function linearsolve_dual_solution(u::Number, partials,
171+
dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P}
172+
# Handle single-level duals
173+
return dual_type(u, partials)
174+
end
175+
176+
# function linearsolve_dual_solution(
177+
# u::AbstractArray, partials, dual_type)
178+
# partials_list = RecursiveArrayTools.VectorOfArray(partials)
179+
# return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
180+
# zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
181+
# end
182+
183+
function linearsolve_dual_solution(u::AbstractArray, partials,
184+
dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P}
185+
# Handle single-level duals for arrays
141186
partials_list = RecursiveArrayTools.VectorOfArray(partials)
142187
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
143188
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
144189
end
145190

146-
#=
191+
192+
function linearsolve_dual_solution(
193+
u::Number, partials, dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P}
194+
# Handle nested duals - recursive case
195+
# For nested duals, u itself could be a dual number with its own partials
196+
inner_dual_type = V
197+
outer_tag_type = T
198+
199+
# Reconstruct the nested dual by first building the inner dual, then the outer one
200+
inner_dual = u # u is already a dual for the inner level
201+
202+
# Create outer dual with the inner dual as its value
203+
return Dual{outer_tag_type, typeof(inner_dual), P}(inner_dual, partials)
204+
end
205+
206+
function linearsolve_dual_solution(u::AbstractArray, partials,
207+
dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P}
208+
# Handle nested duals for arrays - recursive case
209+
inner_dual_type = V
210+
outer_tag_type = T
211+
212+
partials_list = RecursiveArrayTools.VectorOfArray(partials)
213+
214+
# For nested duals, each element of u could be a dual number with its own partials
215+
return map(((uᵢ, pᵢ),) -> Dual{outer_tag_type, typeof(uᵢ), P}(uᵢ, Partials(Tuple(pᵢ))),
216+
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
217+
end
218+
147219
function SciMLBase.init(
148220
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
149221
args...;
@@ -235,18 +307,23 @@ end
235307

236308

237309

238-
# Helper functions for Dual numbers
239-
get_dual_type(x::Dual) = typeof(x)
310+
# Enhanced helper functions for Dual numbers to handle recursion
311+
get_dual_type(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = typeof(x)
312+
get_dual_type(x::Dual{T, V, P}) where {T, V <: Dual, P} = typeof(x)
240313
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
241314
get_dual_type(x) = nothing
242315

243-
partial_vals(x::Dual) = ForwardDiff.partials(x)
316+
# Add recursive handling for nested dual partials
317+
partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.partials(x)
318+
partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x)
244319
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
245320
partial_vals(x) = nothing
246321

322+
# Add recursive handling for nested dual values
247323
nodual_value(x) = x
248-
nodual_value(x::Dual) = ForwardDiff.value(x)
249-
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
324+
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
325+
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
326+
nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x)
250327

251328

252329
function partials_to_list(partial_matrix::Vector)

0 commit comments

Comments
 (0)