Skip to content

Commit e842eac

Browse files
committed
use AbstractVector, fix setproperty!
1 parent 5497adc commit e842eac

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,22 +181,26 @@ end
181181

182182
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
183183
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
184-
if sym === :A
184+
# If the property is A or b, also update it in the LinearCache
185+
if sym === :A || sym === :b || sym === :u
185186
setproperty!(dc.linear_cache, sym, nodual_value(val))
187+
elseif hasfield(DualLinearCache, sym)
188+
setfield!(dc, sym, val)
189+
elseif hasfield(LinearSolve.LinearCache, sym)
190+
setproperty!(dc.linear_cache, sym, val)
191+
end
192+
193+
194+
# Update the partials if setting A or b
195+
if sym === :A
186196
setfield!(dc, :dual_A, val)
187197
setfield!(dc, :partials_A, partial_vals(val))
188198
elseif sym === :b
189-
setproperty!(dc.linear_cache, sym, nodual_value(val))
190199
setfield!(dc, :dual_b, val)
191200
setfield!(dc, :partials_b, partial_vals(val))
192201
elseif sym === :u
193-
setproperty!(dc.linear_cache, sym, nodual_value(val))
194202
setfield!(dc, :dual_u, val)
195203
setfield!(dc, :partials_u, partial_vals(val))
196-
elseif hasfield(DualLinearCache, sym)
197-
setfield!(dc,sym,val)
198-
elseif hasfield(LinearSolve.LinearCache, sym)
199-
setproperty!(dc.linear_cache, sym, val)
200204
end
201205
end
202206

@@ -236,7 +240,7 @@ nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inn
236240
nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x)
237241

238242

239-
function partials_to_list(partial_matrix::AbstractArray{T, 1}) where {T}
243+
function partials_to_list(partial_matrix::AbstractVector{T}) where {T}
240244
p = eachindex(first(partial_matrix))
241245
[[partial[i] for partial in partial_matrix] for i in p]
242246
end

0 commit comments

Comments
 (0)