Skip to content

Commit 4ad919a

Browse files
committed
Fix num_vecjac
1 parent 1156e34 commit 4ad919a

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.15.0"
4+
version = "2.15.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/SparseDiffToolsZygoteExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ end
4545

4646
### Jac, Hes products
4747

48-
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) where {F}
48+
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
49+
cache3 = similar(v)) where {F}
4950
g = let f = f
5051
(dx, x) -> dx .= first(Zygote.gradient(f, x))
5152
end

src/differentiation/jaches_products.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ function auto_jacvec(f, x, v)
3232
vec(partials.(vec(f(y)), 1))
3333
end
3434

35-
function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v);
36-
compute_f0 = true)
35+
function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v),
36+
cache3 = similar(v); compute_f0 = true)
3737
vv = reshape(v, axes(x))
3838
compute_f0 && (f(cache1, x))
3939
T = eltype(x)
@@ -134,7 +134,8 @@ function autonum_hesvec(f, x, v)
134134
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
135135
end
136136

137-
function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v))
137+
function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v),
138+
cache3 = similar(v))
138139
T = eltype(x)
139140
# Should it be min? max? mean?
140141
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))

src/differentiation/vecjac_products.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v);
2-
compute_f0 = true) where {F}
1+
function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
2+
cache3 = similar(x); compute_f0 = true) where {F}
33
compute_f0 && (f(cache1, x))
44
T = eltype(x)
55
# Should it be min? max? mean?
@@ -22,10 +22,11 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F}
2222
# Should it be min? max? mean?
2323
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
2424
du = similar(x)
25-
cache = copy(x)
25+
cache = similar(x)
26+
copyto!(cache, x)
2627
for i in 1:length(x)
2728
cache[i] += ϵ
28-
f0 = f(x)
29+
f0 = f(cache)
2930
cache[i] = x[i]
3031
du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1]
3132
end
@@ -93,7 +94,7 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
9394
end
9495

9596
function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F}
96-
cache = (similar(fu), similar(fu), similar(fu))
97+
cache = (similar(fu), similar(fu), similar(u))
9798
pullback = nothing
9899
return AutoDiffVJP(f, u, cache, autodiff, pullback)
99100
end

0 commit comments

Comments
 (0)