Skip to content

Commit 6bdfdda

Browse files
Merge pull request #245 from vpuri3/vj
Fix VecJac
2 parents f3fd6ec + 6f1e267 commit 6bdfdda

File tree

7 files changed

+266
-127
lines changed

7 files changed

+266
-127
lines changed

Project.toml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1717
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1818
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
19+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
2122
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2223
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2324
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2425

26+
[weakdeps]
27+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
28+
29+
[extensions]
30+
SparseDiffToolsZygote = "Zygote"
31+
2532
[compat]
2633
ADTypes = "0.1"
2734
Adapt = "3.0"
@@ -33,17 +40,14 @@ ForwardDiff = "0.10"
3340
Graphs = "1"
3441
Reexport = "1"
3542
Requires = "1"
36-
SciMLOperators = "0.1.19, 0.2"
43+
SciMLOperators = "0.2.10"
3744
StaticArrayInterface = "1.3"
3845
StaticArrays = "1"
3946
Tricks = "0.1.6"
4047
VertexSafeGraphs = "0.2"
4148
Zygote = "0.6"
4249
julia = "1.6"
4350

44-
[extensions]
45-
SparseDiffToolsZygote = "Zygote"
46-
4751
[extras]
4852
ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d"
4953
ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a"
@@ -60,6 +64,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6064

6165
[targets]
6266
test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
63-
64-
[weakdeps]
65-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/SparseDiffToolsZygote.jl

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
module SparseDiffToolsZygote
22

3-
if isdefined(Base, :get_extension)
4-
import Zygote
5-
using LinearAlgebra
6-
using SparseDiffTools: SparseDiffTools, DeivVecTag
7-
using ForwardDiff: ForwardDiff, Dual, partials
8-
else
9-
import ..Zygote
10-
using ..LinearAlgebra
11-
using ..SparseDiffTools: SparseDiffTools, DeivVecTag
12-
using ..ForwardDiff: ForwardDiff, Dual, partials
13-
end
3+
import Zygote
4+
using ADTypes
5+
using LinearAlgebra
6+
using SparseDiffTools: SparseDiffTools, DeivVecTag, AutoDiffVJP
7+
using ForwardDiff: ForwardDiff, Dual, partials
8+
import SciMLOperators: update_coefficients, update_coefficients!
9+
import Setfield: @set!
1410

1511
### Jac, Hes products
1612

@@ -75,14 +71,70 @@ end
7571

7672
## VecJac products
7773

78-
function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
79-
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false")
74+
# VJP methods
75+
function SparseDiffTools.auto_vecjac!(du, f, x, v)
76+
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = AutoFiniteDiff()")
8077
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
8178
end
8279

8380
function SparseDiffTools.auto_vecjac(f, x, v)
84-
vv, back = Zygote.pullback(f, x)
85-
return vec(back(reshape(v, size(vv)))[1])
81+
y, back = Zygote.pullback(f, x)
82+
return vec(back(reshape(v, size(y)))[1])
83+
end
84+
85+
# overload operator interface
86+
function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)
87+
88+
cache = ()
89+
pullback = Zygote.pullback(f, u)
90+
91+
AutoDiffVJP(f, u, cache, autodiff, pullback)
92+
end
93+
94+
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
95+
) where{AD <: AutoZygote}
96+
97+
if !isnothing(VJP_input)
98+
@set! L.u = VJP_input
99+
end
100+
101+
@set! L.f = update_coefficients(L.f, L.u, p, t)
102+
@set! L.pullback = Zygote.pullback(L.f, L.u)
103+
end
104+
105+
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
106+
) where{AD <: AutoZygote}
107+
108+
if !isnothing(VJP_input)
109+
copy!(L.u, VJP_input)
110+
end
111+
112+
update_coefficients!(L.f, L.u, p, t)
113+
L.pullback = Zygote.pullback(L.f, L.u)
114+
115+
L
116+
end
117+
118+
# Interpret the call as df/du' * v
119+
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where{AD <: AutoZygote}
120+
# ignore VJP_input as pullback was computed in update_coefficients(...)
121+
122+
y, back = L.pullback
123+
V = reshape(v, size(y))
124+
125+
back(V)[1] |> vec
126+
end
127+
128+
# prefer non in-place method
129+
function (L::AutoDiffVJP{AD, IIP, true})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote, IIP}
130+
# ignore VJP_input as pullback was computed in update_coefficients!(...)
131+
132+
_dv = L(v, p, t; VJP_input = VJP_input)
133+
copy!(dv, _dv)
134+
end
135+
136+
function (L::AutoDiffVJP{AD, true, false})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote}
137+
@error("Zygote requires an out of place method with signature f(u).")
86138
end
87139

88140
end # module

src/SparseDiffTools.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Graphs
77
using Graphs: SimpleGraph
88
using VertexSafeGraphs
99
using Adapt
10+
1011
using Reexport
1112
@reexport using ADTypes
1213

@@ -23,6 +24,7 @@ using ArrayInterface: matrix_colors
2324
using SciMLOperators
2425
import SciMLOperators: update_coefficients, update_coefficients!
2526
using Tricks: Tricks, static_hasmethod
27+
using Setfield: @set!
2628

2729
abstract type AbstractAutoDiffVecProd end
2830

src/differentiation/jaches_products.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ struct FwdModeAutoDiffVecProd{F, U, C, V, V!} <: AbstractAutoDiffVecProd
210210
end
211211

212212
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
213-
f = update_coefficients(L.f, u, p, t)
214-
FwdModeAutoDiffVecProd(f, u, L.cache, L.vecprod, L.vecprod!)
213+
@set! L.f = update_coefficients(L.f, u, p, t)
214+
@set! L.u = u
215215
end
216216

217217
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
@@ -248,7 +248,7 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
248248
elseif autodiff isa AutoForwardDiff
249249
cache1 = Dual{
250250
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
251-
}.(u, ForwardDiff.Partials.(tuple.(u)))
251+
}.(u, ForwardDiff.Partials.(tuple.(u)))
252252

253253
cache2 = copy(cache1)
254254

src/differentiation/vecjac_products.jl

Lines changed: 89 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,67 @@ end
3737

3838
### Operator Forms
3939

40-
struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
40+
"""
41+
VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
42+
43+
Returns a `SciMLOperators.FunctionOperator` that computes the vector-Jacobian
44+
product `df/du' * v`.
45+
46+
```
47+
L = VecJac(f, u)
48+
49+
L * v # = df/du' * v
50+
mul!(w, L, v) # = df/du' * v
51+
52+
L(v, p, t; VJP_input = w) # = df/dw' * v
53+
L(x, v, p, t; VJP_input = w) # = df/dw' * v
54+
```
55+
"""
56+
function VecJac(f, u::AbstractArray, p = nothing, t = nothing;
57+
autodiff = AutoFiniteDiff(), kwargs...)
58+
59+
L = _vecjac(f, u, autodiff)
60+
IIP, OOP = get_iip_oop(L)
61+
62+
if isa(autodiff, AutoZygote) & !OOP
63+
msg = "Zygote requires an out of place method with signature f(u)."
64+
throw(ArgumentError(msg))
65+
end
66+
67+
FunctionOperator(L, u, u; isinplace = IIP, outofplace = OOP,
68+
p = p, t = t, islinear = true,
69+
accepted_kwargs = (:VJP_input,), kwargs...)
70+
end
71+
72+
function _vecjac(f, u, autodiff::AutoFiniteDiff)
73+
74+
cache = (similar(u), similar(u))
75+
pullback = nothing
76+
77+
AutoDiffVJP(f, u, cache, autodiff, pullback)
78+
end
79+
80+
mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd
81+
""" Compute VJP of `f` at `u`, applied to vector `v`: `df/du' * v` """
4182
f::F
83+
""" input to `f` """
4284
u::U
85+
""" Cache for num_vecjac! when autodiff isa AutoFiniteDiff """
4386
cache::C
44-
vecprod::V
45-
vecprod!::V!
87+
""" Type of automatic differentiation algorithm """
88+
autodiff::AD
89+
""" stores the result of Zygote.pullback for AutoZygote """
90+
pullback::PB
91+
92+
function AutoDiffVJP(f, u, cache, autodiff, pullback)
4693

47-
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!;
48-
autodiff = AutoFiniteDiff(),
49-
isinplace = false, outofplace = true)
50-
@assert isinplace || outofplace
94+
outofplace = static_hasmethod(f, typeof((u,)))
95+
isinplace = static_hasmethod(f, typeof((u, u)))
96+
97+
if !(isinplace) & !(outofplace)
98+
msg = "$f must have signature f(u), or f(du, u)"
99+
throw(ArgumentError(msg))
100+
end
51101

52102
new{
53103
typeof(autodiff),
@@ -56,72 +106,58 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
56106
typeof(f),
57107
typeof(u),
58108
typeof(cache),
59-
typeof(vecprod),
60-
typeof(vecprod!)
61-
}(f, u, cache, vecprod, vecprod!)
109+
typeof(pullback),
110+
}(
111+
f, u, cache, autodiff, pullback,
112+
)
62113
end
63114
end
64115

65-
function update_coefficients(L::RevModeAutoDiffVecProd, u, p, t)
66-
f = update_coefficients(L.f, u, p, t)
67-
RevModeAutoDiffVecProd(f, u, L.vecprod, L.vecprod!, L.cache)
116+
function get_iip_oop(::AutoDiffVJP{AD, IIP, OOP}) where{AD, IIP, OOP}
117+
IIP, OOP
68118
end
69119

70-
function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t)
71-
update_coefficients!(L.f, u, p, t)
72-
copy!(L.u, u)
73-
L
120+
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
121+
) where{AD <: AutoFiniteDiff}
122+
123+
if !isnothing(VJP_input)
124+
@set! L.u = VJP_input
125+
end
126+
127+
@set! L.f = update_coefficients(L.f, L.u, p, t)
74128
end
75129

76-
# Interpret the call as df/du' * u
77-
function (L::RevModeAutoDiffVecProd)(v, p, t)
78-
L.vecprod(L.f, L.u, v)
130+
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
131+
) where{AD <: AutoFiniteDiff}
132+
133+
if !isnothing(VJP_input)
134+
copy!(L.u, VJP_input)
135+
end
136+
137+
update_coefficients!(L.f, L.u, p, t)
138+
139+
L
79140
end
80141

81-
# prefer non in-place method
82-
function (L::RevModeAutoDiffVecProd{ad, iip, true})(dv, v, p, t) where {ad, iip}
83-
L.vecprod!(dv, L.f, L.u, v, L.cache...)
142+
# Interpret the call as df/du' * v
143+
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing,) where{AD <: AutoFiniteDiff}
144+
# ignore VJP_input as L.u was set in update_coefficients(...)
145+
num_vecjac(L.f, L.u, v)
84146
end
85147

86-
function (L::RevModeAutoDiffVecProd{ad, true, false})(dv, v, p, t) where {ad}
87-
L.vecprod!(dv, L.f, L.u, v, L.cache...)
148+
function (L::AutoDiffVJP{AD})(dv, v, p, t; VJP_input = nothing,) where{AD <: AutoFiniteDiff}
149+
# ignore VJP_input as L.u was set in update_coefficients!(...)
150+
num_vecjac!(dv, L.f, L.u, v, L.cache...)
88151
end
89152

90-
function Base.resize!(L::RevModeAutoDiffVecProd, n::Integer)
153+
function Base.resize!(L::AutoDiffVJP, n::Integer)
91154

92155
static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
93156
resize!(L.u, n)
94157

95158
for v in L.cache
96159
resize!(v, n)
97160
end
98-
end
99-
100-
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
101-
kwargs...)
102-
vecprod, vecprod! = if autodiff isa AutoFiniteDiff
103-
num_vecjac, num_vecjac!
104-
elseif autodiff isa AutoZygote
105-
@assert static_hasmethod(auto_vecjac, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
106161

107-
auto_vecjac, auto_vecjac!
108-
end
109-
110-
cache = (similar(u), similar(u))
111-
112-
outofplace = static_hasmethod(f, typeof((u,)))
113-
isinplace = static_hasmethod(f, typeof((u, u)))
114-
115-
if !(isinplace) & !(outofplace)
116-
error("$f must have signature f(u), or f(du, u)")
117-
end
118-
119-
L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff,
120-
isinplace = isinplace, outofplace = outofplace)
121-
122-
FunctionOperator(L, u, u;
123-
isinplace = isinplace, outofplace = outofplace,
124-
p = p, t = t, islinear = true,
125-
kwargs...)
126162
end
127163
#

test/test_jaches_products.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ L = JacVec(f, copy(x), 1.0, 1.0; tag = MyTag())
143143

144144
# Resize test
145145
for M in (100, 400)
146-
L = JacVec(f2, copy(x), 1.0, 1.0)
146+
local L = JacVec(f2, copy(x), 1.0, 1.0)
147147
resize!(L, M)
148148
_x = resize!(copy(x), M)
149149
_u = rand(M)

0 commit comments

Comments
 (0)