Skip to content

Commit 46d5f64

Browse files
committed
use kwarg VJP_input for recomputing pullback
1 parent e0fc9c4 commit 46d5f64

File tree

3 files changed

+144
-65
lines changed

3 files changed

+144
-65
lines changed

ext/SparseDiffToolsZygote.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,33 @@ function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)
9191
AutoDiffVJP(f, u, cache, autodiff, pullback)
9292
end
9393

94-
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoZygote}
95-
@set! L.f = update_coefficients(L.f, u, p, t)
96-
@set! L.u = u
97-
@set! L.pullback = Zygote.pullback(L.f, u)
94+
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
95+
) where{AD <: AutoZygote}
96+
97+
if !isnothing(VJP_input)
98+
@set! L.u = VJP_input
99+
end
100+
101+
@set! L.f = update_coefficients(L.f, L.u, p, t)
102+
@set! L.pullback = Zygote.pullback(L.f, L.u)
98103
end
99104

100-
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoZygote}
101-
update_coefficients!(L.f, u, p, t)
102-
copy!(L.u, u)
103-
L.pullback = Zygote.pullback(L.f, u)
105+
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
106+
) where{AD <: AutoZygote}
107+
108+
if !isnothing(VJP_input)
109+
copy!(L.u, VJP_input)
110+
end
111+
112+
update_coefficients!(L.f, L.u, p, t)
113+
L.pullback = Zygote.pullback(L.f, L.u)
114+
104115
L
105116
end
106117

107118
# Interpret the call as df/du' * v
108-
function (L::AutoDiffVJP{AD})(v, p, t) where{AD <: AutoZygote}
119+
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where{AD <: AutoZygote}
120+
# ignore VJP_input as pullback was computed in update_coefficients(...)
109121

110122
y, back = L.pullback
111123
V = reshape(v, size(y))
@@ -114,13 +126,15 @@ function (L::AutoDiffVJP{AD})(v, p, t) where{AD <: AutoZygote}
114126
end
115127

116128
# prefer non in-place method
117-
function (L::AutoDiffVJP{AD, IIP, true})(dv, v, p, t) where {AD <: AutoZygote, IIP}
129+
function (L::AutoDiffVJP{AD, IIP, true})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote, IIP}
130+
# ignore VJP_input as pullback was computed in update_coefficients!(...)
131+
118132
_dv = L(v, p, t)
119133
copy!(dv, _dv)
120134
end
121135

122-
function (L::AutoDiffVJP{AD, true, false})(dv, v, p, t) where {AD <: AutoZygote}
123-
SparseDiffTools.auto_vecjac!(dv, L.f, L.u, v, L.cache...)
136+
# In-place call for an operator whose `f` only has an in-place method (IIP = true,
# OOP = false). The Zygote path in this file differentiates `L.f` via
# `Zygote.pullback(L.f, L.u)`, i.e. an out-of-place `f(u)`, so this combination is
# unsupported. Throw instead of logging: `@error` only emits a log record and would
# return `nothing` without ever writing to `dv`.
#
# `VJP_input` is accepted only so the keyword signature matches the other call methods.
function (L::AutoDiffVJP{AD, true, false})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote}
    msg = "Zygote requires an out of place method with signature f(u)."
    throw(ArgumentError(msg))
end
125139

126140
end # module

src/differentiation/vecjac_products.jl

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,34 @@ end
3939

4040
"""
4141
VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
42+
43+
Returns SciMLOperators.FunctionOperator which computes vector-jacobian
44+
product `df/du' * v`.
45+
46+
```
47+
L = VecJac(f, u)
48+
49+
L * v # = df/du' * v
50+
mul!(w, L, v) # = df/du' * v
51+
52+
L(v, p, t; VJP_input = w) # = df/dw' * v
53+
L(x, v, p, t; VJP_input = w) # = df/dw' * v
54+
```
4255
"""
4356
function VecJac(f, u::AbstractArray, p = nothing, t = nothing;
4457
autodiff = AutoFiniteDiff(), kwargs...)
4558

4659
L = _vecjac(f, u, autodiff)
4760
IIP, OOP = get_iip_oop(L)
4861

62+
if isa(autodiff, AutoZygote) & !OOP
63+
msg = "Zygote requires an out of place method with signature f(u)."
64+
throw(ArgumentError(msg))
65+
end
66+
4967
FunctionOperator(L, u, u; isinplace = IIP, outofplace = OOP,
50-
p = p, t = t, islinear = true, kwargs...)
68+
p = p, t = t, islinear = true,
69+
accepted_kwargs = (:VJP_input,), kwargs...)
5170
end
5271

5372
function _vecjac(f, u, autodiff::AutoFiniteDiff)
@@ -59,10 +78,15 @@ function _vecjac(f, u, autodiff::AutoFiniteDiff)
5978
end
6079

6180
mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd
81+
""" Compute VJP of `f` at `u`, applied to vector `v`: `df/du' * v` """
6282
f::F
83+
""" input to `f` """
6384
u::U
85+
""" Cache for num_vecjac! when autodiff isa AutoFiniteDiff """
6486
cache::C
87+
""" Type of automatic differentiation algorithm """
6588
autodiff::AD
89+
""" stores the result of Zygote.pullback for AutoZygote """
6690
pullback::PB
6791

6892
function AutoDiffVJP(f, u, cache, autodiff, pullback)
@@ -93,23 +117,36 @@ function get_iip_oop(::AutoDiffVJP{AD, IIP, OOP}) where{AD, IIP, OOP}
93117
IIP, OOP
94118
end
95119

96-
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoFiniteDiff}
97-
@set! L.f = update_coefficients(L.f, u, p, t)
98-
@set! L.u = u
120+
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
121+
) where{AD <: AutoFiniteDiff}
122+
123+
if !isnothing(VJP_input)
124+
@set! L.u = VJP_input
125+
end
126+
127+
@set! L.f = update_coefficients(L.f, L.u, p, t)
99128
end
100129

101-
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoFiniteDiff}
102-
update_coefficients!(L.f, u, p, t)
103-
copy!(L.u, u)
130+
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
131+
) where{AD <: AutoFiniteDiff}
132+
133+
if !isnothing(VJP_input)
134+
copy!(L.u, VJP_input)
135+
end
136+
137+
update_coefficients!(L.f, L.u, p, t)
138+
104139
L
105140
end
106141

107142
# Interpret the call as df/du' * v
108-
function (L::AutoDiffVJP{AD})(v, p, t) where{AD <: AutoFiniteDiff}
143+
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing,) where{AD <: AutoFiniteDiff}
144+
# ignore VJP_input as L.u was set in update_coefficients(...)
109145
num_vecjac(L.f, L.u, v)
110146
end
111147

112-
function (L::AutoDiffVJP{AD})(dv, v, p, t) where{AD <: AutoFiniteDiff}
148+
function (L::AutoDiffVJP{AD})(dv, v, p, t; VJP_input = nothing,) where{AD <: AutoFiniteDiff}
149+
# ignore VJP_input as L.u was set in update_coefficients!(...)
113150
num_vecjac!(dv, L.f, L.u, v, L.cache...)
114151
end
115152

test/test_vecjac_products.jl

Lines changed: 72 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ Random.seed!(123)
66
N = 300
77

88
# Use Float32 since Zygote defaults to Float32
9-
x = rand(Float32, N)
9+
x1 = rand(Float32, N)
10+
x2 = rand(Float32, N)
11+
1012
v = rand(Float32, N)
1113

1214
# Save original values of x and v to make sure they are not ever mutated
13-
x0 = copy(x)
14-
v0 = copy(v)
15+
_x1 = copy(x1)
16+
_x2 = copy(x2)
17+
_v = copy(v)
1518

1619
a, b = rand(Float32, 2)
1720

@@ -23,81 +26,106 @@ _f(x) = A * (x .^ 2)
2326
include("update_coeffs_testutils.jl")
2427
f = WrapFunc(_f, 1.0f0, 1.0f0)
2528

26-
@test auto_vecjac(f, x, v) Zygote.jacobian(f, x)[1]' * v
27-
@test auto_vecjac!(zero(x), f, x, v) auto_vecjac(f, x, v)
28-
@test num_vecjac!(zero(x), f, copy(x), v) num_vecjac(f, copy(x), v)
29-
@test auto_vecjac(f, x, v) num_vecjac(f, copy(x), copy(v)) rtol = 1e-2
29+
@test auto_vecjac(f, x1, v) ≈ Zygote.jacobian(f, x1)[1]' * v
30+
@test auto_vecjac!(zero(x1), f, x1, v) ≈ auto_vecjac(f, x1, v)
31+
@test num_vecjac!(zero(x1), f, copy(x1), v) ≈ num_vecjac(f, copy(x1), v)
32+
@test auto_vecjac(f, x1, v) ≈ num_vecjac(f, copy(x1), copy(v)) rtol = 1e-2
3033

3134
# Compute Jacobian via Zygote
3235

3336
@info "VecJac AutoZygote"
3437

35-
L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
38+
p, t = rand(Float32, 2)
39+
L = VecJac(f, copy(x1), p, t; autodiff = AutoZygote())
40+
update_coefficients!(L, v, p, t)
3641

37-
Jx = Zygote.jacobian(f, x)[1]
38-
Jv = Zygote.jacobian(f, v)[1]
42+
update_coefficients!(f, v, p, t)
43+
J1 = Zygote.jacobian(f, x1)[1]
44+
J2 = Zygote.jacobian(f, x2)[1]
3945

40-
@test L * x Jx' * x
41-
@test L * v Jx' * v
42-
y=zero(x); @test mul!(y, L, v) Jx' * v
43-
y=zero(x); @test mul!(y, L, v) Jx' * v
46+
# test operator application
47+
@test L * v ≈ J1' * v
48+
@test L(v, p, t) ≈ J1' * v
49+
y=zeros(N); @test mul!(y, L, v) ≈ J1' * v
50+
y=zeros(N); @test L(y, v, p, t) ≈ J1' * v
4451

45-
@test L(x, 1.0f0, 1.0f0) Jx' * x
46-
y=zero(x); @test L(y, x, 1.0f0, 1.0f0) Jx' * x
47-
@test L(v, 1.0f0, 1.0f0) Jv' * v
48-
y=zero(v); @test L(y, v, 1.0f0, 1.0f0) Jv' * v
52+
# use kwarg VJP_input = x2
53+
@test L(v, p, t; VJP_input = x2) ≈ J2' * v
54+
y=zeros(N); @test L(y, v, p, t; VJP_input = x2) ≈ J2' * v
4955

50-
update_coefficients!(L, v, 3.0, 4.0)
56+
# update_coefficients
57+
p, t = rand(Float32, 2)
58+
# Re-point the operator at `x2` through `VJP_input`, the only keyword the VecJac
# `update_coefficients` methods in this diff declare; `JVP_input` is not one of them.
L = update_coefficients(L, v, p, t; VJP_input = x2)
5159

52-
Jx = Zygote.jacobian(f, x)[1]
53-
Jv = Zygote.jacobian(f, v)[1]
60+
update_coefficients!(f, v, p, t)
61+
J1 = Zygote.jacobian(f, x1)[1]
62+
J2 = Zygote.jacobian(f, x2)[1]
5463

55-
@test L * x Jv' * x
56-
@test L * v Jv' * v
57-
y=zero(x); @test mul!(y, L, v) Jv' * v
58-
y=zero(x); @test mul!(y, L, v) Jv' * v
64+
# @show p, t
65+
# @show f.p, f.t
66+
# @show L.op.f.p, L.op.f.t
5967

60-
@test L(x, 3.0f0, 4.0f0) Jx' * x
61-
y=zero(x); @test L(y, x, 3.0f0, 4.0f0) Jx' * x
62-
@test L(v, 3.0f0, 4.0f0) Jv' * v
63-
y=zero(v); @test L(y, v, 3.0f0, 4.0f0) Jv' * v
68+
@test L * v ≈ J2' * v
69+
@test L(v, p, t) ≈ J2' * v
70+
y=zeros(N); @test mul!(y, L, v) ≈ J2' * v
71+
y=zeros(N); @test L(y, v, p, t) ≈ J2' * v
72+
73+
# use kwarg VJP_input = x1
74+
@test L(v, p, t; VJP_input = x1) ≈ J1' * v
75+
y=zeros(N); @test L(y, v, p, t; VJP_input = x1) ≈ J1' * v
6476

6577
@info "VecJac AutoFiniteDiff"
6678

67-
L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoFiniteDiff())
79+
p, t = rand(Float32, 2)
80+
L = VecJac(f, copy(x1), 1.0f0, 1.0f0; autodiff = AutoFiniteDiff())
81+
update_coefficients!(L, v, p, t)
82+
update_coefficients!(f, v, p, t)
83+
84+
@test L * v ≈ num_vecjac(f, copy(x1), v)
85+
@test L(v, p, t) ≈ num_vecjac(f, copy(x1), v)
86+
y=zeros(N); @test mul!(y, L, v) ≈ num_vecjac(f, copy(x1), v)
87+
y=zeros(N); @test L(y, v, p, t) ≈ num_vecjac(f, copy(x1), v)
6888

69-
@test L * x num_vecjac(f, copy(x), x)
70-
@test L * v num_vecjac(f, copy(x), v)
71-
y=zero(x); @test mul!(y, L, v) num_vecjac(f, copy(x), v)
89+
# use kwarg VJP_input = x2
90+
@test L(v, p, t; VJP_input = x2) ≈ num_vecjac(f, copy(x2), v)
91+
y=zeros(N); @test L(y, v, p, t; VJP_input = x2) ≈ num_vecjac(f, copy(x2), v)
7292

73-
update_coefficients!(L, v, 3.0, 4.0)
74-
@test mul!(y, L, x) num_vecjac(f, copy(v), x)
75-
_y = copy(y); @test mul!(y, L, x, a, b) a * num_vecjac(f,copy(v),x) + b * _y
93+
# update_coefficients
94+
p, t = rand(Float32, 2)
95+
# Re-point the operator at `x2` through `VJP_input`, the only keyword the VecJac
# `update_coefficients` methods in this diff declare; `JVP_input` is not one of them.
L = update_coefficients(L, v, p, t; VJP_input = x2)
96+
update_coefficients!(f, v, p, t)
7697

77-
update_coefficients!(f, v, 5.0, 6.0)
78-
@test L(y, v, 5.0, 6.0) num_vecjac(f, copy(v), v)
98+
@test L * v ≈ num_vecjac(f, copy(x2), v)
99+
@test L(v, p, t) ≈ num_vecjac(f, copy(x2), v)
100+
y=zeros(N); @test mul!(y, L, v) ≈ num_vecjac(f, copy(x2), v)
101+
y=zeros(N); @test L(y, v, p, t) ≈ num_vecjac(f, copy(x2), v)
102+
103+
# use kwarg VJP_input = x1
104+
@test L(v, p, t; VJP_input = x1) ≈ num_vecjac(f, copy(x1), v)
105+
y=zeros(N); @test L(y, v, p, t; VJP_input = x1) ≈ num_vecjac(f, copy(x1), v)
79106

80107
# Test that x and v were not mutated
81-
@test x x0
82-
@test v v0
108+
@test x1 ≈ _x1
109+
@test x2 ≈ _x2
110+
# Compare against `_v`, the copy of `v` saved before any operator was applied;
# comparing `v` with itself can never detect a mutation.
@test v ≈ _v
83111

84112
@info "Base.resize!"
85113

86114
# Resize test
87115
f2(x) = 2x
88116
f2(y, x) = (copy!(y, x); lmul!(2, y); y)
89117

118+
x = rand(Float32, N)
90119
for M in (100, 400)
91120
local L = VecJac(f2, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
92121
resize!(L, M)
93122

94123
_x = resize!(copy(x), M)
95124
_u = rand(M)
96-
J2 = Zygote.jacobian(f2, _x)[1]
125+
local J2 = Zygote.jacobian(f2, _x)[1]
97126

98-
update_coefficients!(L, _x, 1.0f0, 1.0f0)
127+
update_coefficients!(L, _u, 1.0f0, 1.0f0; VJP_input = _x)
99128
@test L * _u ≈ J2' * _u rtol=1e-6
100-
_v = zeros(M); @test mul!(_v, L, _u) J2' * _u rtol=1e-6
129+
local _v = zeros(M); @test mul!(_v, L, _u) ≈ J2' * _u rtol=1e-6
101130
end
102-
103131
#

0 commit comments

Comments
 (0)