Skip to content

Commit 7c0aa99

Browse files
authored
Adapt to new Enzyme and DifferentiationInterface (#156)
1 parent 6195cd3 commit 7c0aa99

6 files changed

+81
-109
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -21,14 +21,14 @@ ImplicitDifferentiationEnzymeExt = "Enzyme"
2121
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
2222

2323
[compat]
24-
ADTypes = "1.7.1"
25-
ChainRulesCore = "1.23.0"
26-
DifferentiationInterface = "0.5.12"
27-
Enzyme = "0.11.20,0.12"
24+
ADTypes = "1.9.0"
25+
ChainRulesCore = "1.25.0"
26+
DifferentiationInterface = "0.6.1"
27+
Enzyme = "0.13.3"
2828
ForwardDiff = "0.10.36"
29-
Krylov = "0.9.5"
29+
Krylov = "0.9.6"
3030
LinearAlgebra = "1.10"
31-
LinearOperators = "2.7.0"
31+
LinearOperators = "2.8.0"
3232
julia = "1.10"
3333

3434
[extras]

ext/ImplicitDifferentiationEnzymeExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ module ImplicitDifferentiationEnzymeExt
22

33
using ADTypes
44
using Enzyme
5-
using Enzyme.EnzymeCore
5+
using Enzyme.EnzymeRules
66
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output
77

88
const FORWARD_BACKEND = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
99

1010
function EnzymeRules.forward(
11+
config::EnzymeRules.FwdConfig,
1112
func::Const{<:ImplicitFunction},
1213
RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}},
1314
func_x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}},

src/ImplicitDifferentiation.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ module ImplicitDifferentiation
99

1010
using ADTypes: AbstractADType
1111
using DifferentiationInterface:
12+
Constant,
1213
jacobian,
1314
prepare_pushforward_same_point,
1415
prepare_pullback_same_point,
1516
pullback!,
16-
pushforward!
17+
pushforward!,
18+
unwrap
1719
using Krylov: block_gmres, gmres
1820
using LinearOperators: LinearOperator
19-
using LinearAlgebra: factorize, lu
21+
using LinearAlgebra: axpby!, factorize, lu
2022

2123
include("implicit_function.jl")
2224
include("operators.jl")

src/implicit_function.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,11 @@ end
177177

178178
output(y::AbstractVector) = y
179179
byproduct(::AbstractVector) = error("No byproduct")
180+
rest(::AbstractVector) = ()
180181

181182
output(yz::Tuple{<:Any,<:Any}) = yz[1]
182183
byproduct(yz::Tuple{<:Any,<:Any}) = yz[2]
184+
rest(yz::Tuple) = (byproduct(yz),)
183185

184186
output((y, z)) = y
185187
byproduct((y, z)) = z

src/operators.jl

Lines changed: 65 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,75 @@
1-
## Partial conditions
2-
3-
struct ConditionsXNoByproduct{C,Y,A,K}
1+
struct ConditionsX{C,K}
42
conditions::C
5-
y::Y
6-
args::A
73
kwargs::K
84
end
95

10-
function (conditions_x_nobyproduct::ConditionsXNoByproduct)(x::AbstractVector)
11-
(; conditions, y, args, kwargs) = conditions_x_nobyproduct
12-
return conditions(x, y, args...; kwargs...)
13-
end
14-
15-
struct ConditionsYNoByproduct{C,X,A,K}
6+
struct ConditionsY{C,K}
167
conditions::C
17-
x::X
18-
args::A
198
kwargs::K
209
end
2110

22-
function (conditions_y_nobyproduct::ConditionsYNoByproduct)(y::AbstractVector)
23-
(; conditions, x, args, kwargs) = conditions_y_nobyproduct
24-
return conditions(x, y, args...; kwargs...)
11+
function (cx::ConditionsX)(x, y, args...)
12+
return cx.conditions(x, y, args...; cx.kwargs...)
2513
end
2614

27-
struct ConditionsXByproduct{C,Y,Z,A,K}
28-
conditions::C
29-
y::Y
30-
z::Z
31-
args::A
32-
kwargs::K
33-
end
34-
35-
function (conditions_x_byproduct::ConditionsXByproduct)(x::AbstractVector)
36-
(; conditions, y, z, args, kwargs) = conditions_x_byproduct
37-
return conditions(x, y, z, args...; kwargs...)
15+
function (cy::ConditionsY)(y, x, args...) # order switch
16+
return cy.conditions(x, y, args...; cy.kwargs...)
3817
end
3918

40-
struct ConditionsYByproduct{C,X,Z,A,K}
41-
conditions::C
19+
struct PushforwardOperator!{F,P,B,X,C,R}
20+
f::F
21+
prep::P
22+
backend::B
4223
x::X
43-
z::Z
44-
args::A
45-
kwargs::K
46-
end
47-
48-
function (conditions_y_byproduct::ConditionsYByproduct)(y::AbstractVector)
49-
(; conditions, x, z, args, kwargs) = conditions_y_byproduct
50-
return conditions(x, y, z, args...; kwargs...)
51-
end
52-
53-
function ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
54-
y = output(y_or_yz)
55-
if y_or_yz isa Tuple
56-
z = byproduct(y_or_yz)
57-
return ConditionsXByproduct(conditions, y, z, args, kwargs)
58-
else
59-
return ConditionsXNoByproduct(conditions, y, args, kwargs)
60-
end
61-
end
62-
63-
function ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
64-
if y_or_yz isa Tuple
65-
z = byproduct(y_or_yz)
66-
return ConditionsYByproduct(conditions, x, z, args, kwargs)
67-
else
68-
return ConditionsYNoByproduct(conditions, x, args, kwargs)
69-
end
24+
contexts::C
25+
res_backup::R
7026
end
7127

72-
## Lazy operators
73-
74-
struct PushforwardOperator!{F,B,X,E,R}
28+
struct PullbackOperator!{F,P,B,X,C,R}
7529
f::F
30+
prep::P
7631
backend::B
7732
x::X
78-
extras::E
33+
contexts::C
7934
res_backup::R
8035
end
8136

37+
function PushforwardOperator!(f, prep, backend, x, contexts)
38+
res_backup = similar(f(x, map(unwrap, contexts)...))
39+
return PushforwardOperator!(f, prep, backend, x, contexts, res_backup)
40+
end
41+
42+
function PullbackOperator!(f, prep, backend, x, contexts)
43+
res_backup = similar(x)
44+
return PullbackOperator!(f, prep, backend, x, contexts, res_backup)
45+
end
46+
8247
function (po::PushforwardOperator!)(res, v, α, β)
48+
(; f, backend, x, contexts, prep, res_backup) = po
8349
if iszero(β)
84-
pushforward!(po.f, res, po.backend, po.x, v, po.extras)
85-
res .= α .* res
50+
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
51+
if !isone(α)
52+
res .*= α
53+
end
8654
else
87-
po.res_backup .= res
88-
pushforward!(po.f, res, po.backend, po.x, v, po.extras)
89-
res .= α .* res .+ β .* po.res_backup
55+
copyto!(res_backup, res)
56+
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
57+
axpby!(β, res_backup, α, res)
9058
end
9159
return res
9260
end
9361

94-
struct PullbackOperator!{F,B,X,E,R}
95-
f::F
96-
backend::B
97-
x::X
98-
extras::E
99-
res_backup::R
100-
end
101-
10262
function (po::PullbackOperator!)(res, v, α, β)
63+
(; f, backend, x, contexts, prep, res_backup) = po
10364
if iszero(β)
104-
pullback!(po.f, res, po.backend, po.x, v, po.extras)
105-
res .= α .* res
65+
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
66+
if !isone(α)
67+
res .*= α
68+
end
10669
else
107-
po.res_backup .= res
108-
pullback!(po.f, res, po.backend, po.x, v, po.extras)
109-
res .= α .* res .+ β .+ po.res_backup
70+
copyto!(res_backup, res)
71+
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
72+
axpby!(β, res_backup, α, res)
11073
end
11174
return res
11275
end
@@ -119,24 +82,25 @@ function build_A(
11982
suggested_backend,
12083
kwargs...,
12184
) where {lazy}
122-
(; conditions, linear_solver, conditions_y_backend) = implicit
85+
(; conditions, conditions_y_backend) = implicit
12386
y = output(y_or_yz)
12487
n, m = length(x), length(y)
12588
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
126-
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
89+
cond_y = ConditionsY(conditions, kwargs)
90+
contexts = (Constant(x), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
12791
if lazy
128-
extras = prepare_pushforward_same_point(cond_y, back_y, y, zero(y))
92+
prep = prepare_pushforward_same_point(cond_y, back_y, y, (zero(y),), contexts...)
12993
A = LinearOperator(
13094
eltype(y),
13195
m,
13296
m,
13397
false,
13498
false,
135-
PushforwardOperator!(cond_y, back_y, y, extras, similar(y)),
99+
PushforwardOperator!(cond_y, prep, back_y, y, contexts),
136100
typeof(y),
137101
)
138102
else
139-
J = jacobian(cond_y, back_y, y)
103+
J = jacobian(cond_y, back_y, y, contexts...)
140104
A = factorize(J)
141105
end
142106
return A
@@ -150,24 +114,25 @@ function build_Aᵀ(
150114
suggested_backend,
151115
kwargs...,
152116
) where {lazy}
153-
(; conditions, linear_solver, conditions_y_backend) = implicit
117+
(; conditions, conditions_y_backend) = implicit
154118
y = output(y_or_yz)
155119
n, m = length(x), length(y)
156120
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
157-
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
121+
cond_y = ConditionsY(conditions, kwargs)
122+
contexts = (Constant(x), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
158123
if lazy
159-
extras = prepare_pullback_same_point(cond_y, back_y, y, zero(y))
124+
prep = prepare_pullback_same_point(cond_y, back_y, y, (zero(y),), contexts...)
160125
Aᵀ = LinearOperator(
161126
eltype(y),
162127
m,
163128
m,
164129
false,
165130
false,
166-
PullbackOperator!(cond_y, back_y, y, extras, similar(y)),
131+
PullbackOperator!(cond_y, prep, back_y, y, contexts),
167132
typeof(y),
168133
)
169134
else
170-
Jᵀ = transpose(jacobian(cond_y, back_y, y))
135+
Jᵀ = transpose(jacobian(cond_y, back_y, y, contexts...))
171136
Aᵀ = factorize(Jᵀ)
172137
end
173138
return Aᵀ
@@ -181,24 +146,25 @@ function build_B(
181146
suggested_backend,
182147
kwargs...,
183148
) where {lazy}
184-
(; conditions, linear_solver, conditions_x_backend) = implicit
149+
(; conditions, conditions_x_backend) = implicit
185150
y = output(y_or_yz)
186151
n, m = length(x), length(y)
187152
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
188-
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
153+
cond_x = ConditionsX(conditions, kwargs)
154+
contexts = (Constant(y), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
189155
if lazy
190-
extras = prepare_pushforward_same_point(cond_x, back_x, x, zero(x))
156+
prep = prepare_pushforward_same_point(cond_x, back_x, x, (zero(x),), contexts...)
191157
B = LinearOperator(
192158
eltype(y),
193159
m,
194160
n,
195161
false,
196162
false,
197-
PushforwardOperator!(cond_x, back_x, x, extras, similar(y)),
163+
PushforwardOperator!(cond_x, prep, back_x, x, contexts),
198164
typeof(x),
199165
)
200166
else
201-
B = transpose(jacobian(cond_x, back_x, x))
167+
B = transpose(jacobian(cond_x, back_x, x, contexts...))
202168
end
203169
return B
204170
end
@@ -211,24 +177,25 @@ function build_Bᵀ(
211177
suggested_backend,
212178
kwargs...,
213179
) where {lazy}
214-
(; conditions, linear_solver, conditions_x_backend) = implicit
180+
(; conditions, conditions_x_backend) = implicit
215181
y = output(y_or_yz)
216182
n, m = length(x), length(y)
217183
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
218-
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
184+
cond_x = ConditionsX(conditions, kwargs)
185+
contexts = (Constant(y), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
219186
if lazy
220-
extras = prepare_pullback_same_point(cond_x, back_x, x, zero(y))
187+
prep = prepare_pullback_same_point(cond_x, back_x, x, (zero(y),), contexts...)
221188
Bᵀ = LinearOperator(
222189
eltype(y),
223190
n,
224191
m,
225192
false,
226193
false,
227-
PullbackOperator!(cond_x, back_x, x, extras, similar(x)),
194+
PullbackOperator!(cond_x, prep, back_x, x, contexts),
228195
typeof(x),
229196
)
230197
else
231-
Bᵀ = transpose(jacobian(cond_x, back_x, x))
198+
Bᵀ = transpose(jacobian(cond_x, back_x, x, contexts...))
232199
end
233200
return Bᵀ
234201
end

test/systematic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ conditions_backend_candidates = (
2929

3030
x_candidates = (
3131
Float32[3, 4], #
32-
MVector{2}(Float32[3, 4]), #
32+
# MVector{2}(Float32[3, 4]), #
3333
);
3434

3535
## Test loop

0 commit comments

Comments
 (0)