Skip to content

Commit 7e06e45

Browse files
committed
Use allowed_setindex! for cache
1 parent 4ad919a commit 7e06e45

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

src/SparseDiffTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import ADTypes: AbstractADType, AutoSparseZygote, AbstractSparseForwardMode,
1414
import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
1515
# Array Packages
1616
using ArrayInterface, SparseArrays
17-
import ArrayInterface: matrix_colors
17+
import ArrayInterface: matrix_colors, allowed_setindex!
1818
import StaticArrays
1919
import StaticArrays: StaticArray, SArray, MArray, Size
2020
# Others

src/differentiation/vecjac_products.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@ function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
1515
return du
1616
end
1717

18+
# Special Non-Allocating case for StaticArrays
19+
function num_vecjac(f::F, x::SArray, v::SArray, f0 = nothing) where {F}
20+
f0 === nothing ? (_f0 = f(x)) : (_f0 = f0)
21+
vv = reshape(v, axes(_f0))
22+
T = eltype(x)
23+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
24+
du = zeros(typeof(x))
25+
for i in 1:length(x)
26+
cache = Base.setindex(x, x[i] + ϵ, i)
27+
f0 = f(cache)
28+
du = Base.setindex(du, (((f0 .- _f0) ./ ϵ)' * vv), i)
29+
end
30+
return du
31+
end
32+
1833
function num_vecjac(f::F, x, v, f0 = nothing) where {F}
1934
f0 === nothing ? (_f0 = f(x)) : (_f0 = f0)
2035
vv = reshape(v, axes(_f0))
@@ -25,10 +40,10 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F}
2540
cache = similar(x)
2641
copyto!(cache, x)
2742
for i in 1:length(x)
28-
cache[i] += ϵ
43+
cache = allowed_setindex!(cache, x[i] + ϵ, i)
2944
f0 = f(cache)
30-
cache[i] = x[i]
31-
du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1]
45+
cache = allowed_setindex!(cache, x[i], i)
46+
du = allowed_setindex!(du, (((f0 .- _f0) ./ ϵ)' * vv)[1], i)
3247
end
3348
return vec(du)
3449
end

test/test_vecjac_products.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using SparseDiffTools, Zygote
1+
using SparseDiffTools, Zygote, ForwardDiff
22
using LinearAlgebra, Test
3+
using StaticArrays
34

45
using Random
56
Random.seed!(123)
@@ -170,3 +171,17 @@ L = VecJac(f3_iip, copy(x); autodiff = AutoFiniteDiff(), fu = copy(y))
170171
L = VecJac(f3_oop, copy(x); autodiff = AutoZygote())
171172
@test size(L) == (length(x), length(y))
172173
@test L * y Zygote.jacobian(f3_oop, copy(x))[1]' * y
174+
175+
@info "Testing StaticArrays"
176+
177+
const A_sa = rand(SMatrix{4, 4, Float32})
178+
_f_sa(x) = A_sa * (x .^ 2)
179+
180+
u = rand(SVector{4, Float32})
181+
v = rand(SVector{4, Float32})
182+
183+
J = ForwardDiff.jacobian(_f_sa, u)
184+
Jᵀv_true = J' * v
185+
186+
@test num_vecjac(_f_sa, u, v) isa SArray
187+
@test num_vecjac(_f_sa, u, v)Jᵀv_true atol=1e-3

0 commit comments

Comments
 (0)