Skip to content

Commit bbd1f6e

Browse files
Merge pull request #282 from avik-pal/ap/test
Fix master
2 parents 61930e3 + 7e06e45 commit bbd1f6e

File tree

6 files changed

+48
-15
lines changed

6 files changed

+48
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.15.0"
4+
version = "2.15.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/SparseDiffToolsZygoteExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ end
4545

4646
### Jac, Hes products
4747

48-
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) where {F}
48+
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
49+
cache3 = similar(v)) where {F}
4950
g = let f = f
5051
(dx, x) -> dx .= first(Zygote.gradient(f, x))
5152
end

src/SparseDiffTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import ADTypes: AbstractADType, AutoSparseZygote, AbstractSparseForwardMode,
1414
import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
1515
# Array Packages
1616
using ArrayInterface, SparseArrays
17-
import ArrayInterface: matrix_colors
17+
import ArrayInterface: matrix_colors, allowed_setindex!
1818
import StaticArrays
1919
import StaticArrays: StaticArray, SArray, MArray, Size
2020
# Others

src/differentiation/jaches_products.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ function auto_jacvec(f, x, v)
3232
vec(partials.(vec(f(y)), 1))
3333
end
3434

35-
function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v);
36-
compute_f0 = true)
35+
function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v),
36+
cache3 = similar(v); compute_f0 = true)
3737
vv = reshape(v, axes(x))
3838
compute_f0 && (f(cache1, x))
3939
T = eltype(x)
@@ -134,7 +134,8 @@ function autonum_hesvec(f, x, v)
134134
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
135135
end
136136

137-
function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v))
137+
function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v),
138+
cache3 = similar(v))
138139
T = eltype(x)
139140
# Should it be min? max? mean?
140141
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))

src/differentiation/vecjac_products.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v);
2-
compute_f0 = true) where {F}
1+
function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
2+
cache3 = similar(x); compute_f0 = true) where {F}
33
compute_f0 && (f(cache1, x))
44
T = eltype(x)
55
# Should it be min? max? mean?
@@ -15,19 +15,35 @@ function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), c
1515
return du
1616
end
1717

18+
# Special Non-Allocating case for StaticArrays
19+
function num_vecjac(f::F, x::SArray, v::SArray, f0 = nothing) where {F}
20+
f0 === nothing ? (_f0 = f(x)) : (_f0 = f0)
21+
vv = reshape(v, axes(_f0))
22+
T = eltype(x)
23+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
24+
du = zeros(typeof(x))
25+
for i in 1:length(x)
26+
cache = Base.setindex(x, x[i] + ϵ, i)
27+
f0 = f(cache)
28+
du = Base.setindex(du, (((f0 .- _f0) ./ ϵ)' * vv), i)
29+
end
30+
return du
31+
end
32+
1833
function num_vecjac(f::F, x, v, f0 = nothing) where {F}
1934
f0 === nothing ? (_f0 = f(x)) : (_f0 = f0)
2035
vv = reshape(v, axes(_f0))
2136
T = eltype(x)
2237
# Should it be min? max? mean?
2338
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
2439
du = similar(x)
25-
cache = copy(x)
40+
cache = similar(x)
41+
copyto!(cache, x)
2642
for i in 1:length(x)
27-
cache[i] += ϵ
28-
f0 = f(x)
29-
cache[i] = x[i]
30-
du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1]
43+
cache = allowed_setindex!(cache, x[i] + ϵ, i)
44+
f0 = f(cache)
45+
cache = allowed_setindex!(cache, x[i], i)
46+
du = allowed_setindex!(du, (((f0 .- _f0) ./ ϵ)' * vv)[1], i)
3147
end
3248
return vec(du)
3349
end
@@ -93,7 +109,7 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
93109
end
94110

95111
function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F}
96-
cache = (similar(fu), similar(fu), similar(fu))
112+
cache = (similar(fu), similar(fu), similar(u))
97113
pullback = nothing
98114
return AutoDiffVJP(f, u, cache, autodiff, pullback)
99115
end

test/test_vecjac_products.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using SparseDiffTools, Zygote
1+
using SparseDiffTools, Zygote, ForwardDiff
22
using LinearAlgebra, Test
3+
using StaticArrays
34

45
using Random
56
Random.seed!(123)
@@ -170,3 +171,17 @@ L = VecJac(f3_iip, copy(x); autodiff = AutoFiniteDiff(), fu = copy(y))
170171
L = VecJac(f3_oop, copy(x); autodiff = AutoZygote())
171172
@test size(L) == (length(x), length(y))
172173
@test L * y Zygote.jacobian(f3_oop, copy(x))[1]' * y
174+
175+
@info "Testing StaticArrays"
176+
177+
const A_sa = rand(SMatrix{4, 4, Float32})
178+
_f_sa(x) = A_sa * (x .^ 2)
179+
180+
u = rand(SVector{4, Float32})
181+
v = rand(SVector{4, Float32})
182+
183+
J = ForwardDiff.jacobian(_f_sa, u)
184+
Jᵀv_true = J' * v
185+
186+
@test num_vecjac(_f_sa, u, v) isa SArray
187+
@test num_vecjac(_f_sa, u, v)Jᵀv_true atol=1e-3

0 commit comments

Comments
 (0)