Skip to content

Commit 16679fe

Browse files
Merge pull request #242 from vpuri3/resize
overload Base.resize!
2 parents 8d9a112 + 953afc3 commit 16679fe

File tree

4 files changed

+51
-0
lines changed

4 files changed

+51
-0
lines changed

src/differentiation/jaches_products.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,16 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
228228
L.vecprod!(dv, L.f, L.u, v, L.cache...)
229229
end
230230

231+
function Base.resize!(L::FwdModeAutoDiffVecProd, n::Integer)
232+
233+
static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
234+
resize!(L.u, n)
235+
236+
for v in L.cache
237+
resize!(v, n)
238+
end
239+
end
240+
231241
function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
232242
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
233243
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff

src/differentiation/vecjac_products.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ function (L::RevModeAutoDiffVecProd{ad, true, false})(dv, v, p, t) where {ad}
8787
L.vecprod!(dv, L.f, L.u, v, L.cache...)
8888
end
8989

90+
function Base.resize!(L::RevModeAutoDiffVecProd, n::Integer)
91+
92+
static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
93+
resize!(L.u, n)
94+
95+
for v in L.cache
96+
resize!(v, n)
97+
end
98+
end
99+
90100
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
91101
kwargs...)
92102
vecprod, vecprod! = if autodiff isa AutoFiniteDiff

test/test_jaches_products.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ function _h(dy, x)
3232
FiniteDiff.finite_difference_gradient!(dy, _g, x)
3333
end
3434

35+
f2(x) = 2x
36+
f2(y, x) = (copy!(y, x); lmul!(2, y); y)
37+
3538
# Make functions state-dependent for operator tests
3639

3740
include("update_coeffs_testutils.jl")
@@ -138,6 +141,17 @@ L = JacVec(f, copy(x), 1.0, 1.0)
138141
L = JacVec(f, copy(x), 1.0, 1.0; tag = MyTag())
139142
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}
140143

144+
# Resize test
145+
for M in (100, 400)
146+
L = JacVec(f2, copy(x), 1.0, 1.0)
147+
resize!(L, M)
148+
_x = resize!(copy(x), M)
149+
_u = rand(M)
150+
151+
@test L * _u auto_jacvec(f2, _x, _u)
152+
_v = zeros(M); @test mul!(_v, L, _u) auto_jacvec(f2, _x, _u)
153+
end
154+
141155
@info "HesVec"
142156

143157
L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff())

test/test_vecjac_products.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,21 @@ actual_jac = Zygote.jacobian(f, v)[1]
6363
# Test that x and v were not mutated
6464
@test x x0
6565
@test v v0
66+
67+
# Resize test
68+
f2(x) = 2x
69+
f2(y, x) = (copy!(y, x); lmul!(2, y); y)
70+
71+
for M in (100, 400)
72+
L = VecJac(f2, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
73+
resize!(L, M)
74+
_x = resize!(copy(x), M)
75+
_u = rand(M)
76+
77+
J2 = Zygote.jacobian(f2, _x)[1]
78+
79+
@test L * _u J2' * _u rtol=1e-6
80+
_v = zeros(M); @test mul!(_v, L, _u) J2' * _u rtol=1e-6
81+
end
82+
6683
#

0 commit comments

Comments
 (0)