Skip to content

Commit 6195cd3

Browse files
authored
Fix Enzyme extension and add new broken test (#151)
* Fix Enzyme extension and add new test * Adapt to latest version * No function annotation * Test broken * Fix tests * Mode * Const * Bump version and move constructor doc
1 parent e0f156c commit 6195cd3

File tree

5 files changed

+35
-26
lines changed

5 files changed

+35
-26
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -21,9 +21,9 @@ ImplicitDifferentiationEnzymeExt = "Enzyme"
2121
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
2222

2323
[compat]
24-
ADTypes = "1.0"
24+
ADTypes = "1.7.1"
2525
ChainRulesCore = "1.23.0"
26-
DifferentiationInterface = "0.5"
26+
DifferentiationInterface = "0.5.12"
2727
Enzyme = "0.11.20,0.12"
2828
ForwardDiff = "0.10.36"
2929
Krylov = "0.9.5"

examples/3_tricks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ We demonstrate several features that may come in handy for some users.
55
=#
66

77
using ComponentArrays
8+
using Enzyme #src
89
using ForwardDiff
910
using ImplicitDifferentiation
1011
using Krylov
@@ -67,6 +68,8 @@ J = ForwardDiff.jacobian(forward_components, x) #src
6768
Zygote.jacobian(implicit_components, x)[1]
6869
@test Zygote.jacobian(implicit_components, x)[1] J #src
6970

71+
@test_broken Enzyme.jacobian(Enzyme.Forward, implicit_components, x) J #src
72+
7073
#- The full differentiable pipeline looks like this
7174

7275
function full_pipeline(a, b, m)

ext/ImplicitDifferentiationEnzymeExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using Enzyme
55
using Enzyme.EnzymeCore
66
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output
77

8+
const FORWARD_BACKEND = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
9+
810
function EnzymeRules.forward(
911
func::Const{<:ImplicitFunction},
1012
RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}},
@@ -20,12 +22,11 @@ function EnzymeRules.forward(
2022
y = output(y_or_yz)
2123
Y = typeof(y)
2224

23-
suggested_backend = AutoEnzyme(Enzyme.Forward)
25+
suggested_backend = FORWARD_BACKEND
2426
A = build_A(implicit, x, y_or_yz, args...; suggested_backend)
2527
B = build_B(implicit, x, y_or_yz, args...; suggested_backend)
2628

27-
dx_batch = reduce(hcat, dx)
28-
dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx
29+
dc_batch = mapreduce(hcat, dx) do dₖx
2930
B * dₖx
3031
end
3132
dy_batch = implicit.linear_solver(A, -dc_batch)

src/implicit_function.jl

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ The value of `lazy` must be chosen together with the `linear_solver`, see below.
6060
- `conditions_x_backend`: how the conditions will be differentiated w.r.t. the first argument `x`
6161
- `conditions_y_backend`: how the conditions will be differentiated w.r.t. the second argument `y`
6262
63+
# Constructors
64+
65+
ImplicitFunction(
66+
forward, conditions;
67+
linear_solver=KrylovLinearSolver(),
68+
conditions_x_backend=nothing,
69+
conditions_x_backend=nothing,
70+
)
71+
72+
Picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic: `lazy = linear_solver != \\`.
73+
74+
ImplicitFunction{lazy}(
75+
forward, conditions;
76+
linear_solver=lazy ? KrylovLinearSolver() : \\,
77+
conditions_x_backend=nothing,
78+
conditions_y_backend=nothing,
79+
)
80+
81+
Picks the `linear_solver` automatically based on the `lazy` parameter.
82+
6383
# Function signatures
6484
6585
There are two possible signatures for `forward` and `conditions`, which must be consistent with one another:
@@ -87,8 +107,10 @@ Typically, direct solvers work best with dense Jacobians (`lazy = false`) while
87107
# Condition backends
88108
89109
The provided `conditions_x_backend` and `conditions_y_backend` can be either:
110+
- `nothing` (the default), in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used.
90111
- an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl);
91-
- `nothing`, in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used.
112+
113+
When differentiating with Enzyme as an outer backend, the default setting assumes that `conditions` does not contain writeable data involved in derivatives.
92114
"""
93115
struct ImplicitFunction{
94116
lazy,F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType}
@@ -101,14 +123,7 @@ struct ImplicitFunction{
101123
end
102124

103125
"""
104-
ImplicitFunction{lazy}(
105-
forward, conditions;
106-
linear_solver=lazy ? KrylovLinearSolver() : \\,
107-
conditions_x_backend=nothing,
108-
conditions_y_backend=nothing,
109-
)
110126
111-
Constructor for an [`ImplicitFunction`](@ref) which picks the `linear_solver` automatically based on the `lazy` parameter.
112127
"""
113128
function ImplicitFunction{lazy}(
114129
forward::F,
@@ -126,16 +141,6 @@ function ImplicitFunction{lazy}(
126141
)
127142
end
128143

129-
"""
130-
ImplicitFunction(
131-
forward, conditions;
132-
linear_solver=KrylovLinearSolver(),
133-
conditions_x_backend=nothing,
134-
conditions_x_backend=nothing,
135-
)
136-
137-
Constructor for an [`ImplicitFunction`](@ref) which picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic: `lazy = linear_solver != \\`.
138-
"""
139144
function ImplicitFunction(
140145
forward,
141146
conditions;

test/systematic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ include("utils.jl")
1212

1313
backends = [
1414
AutoForwardDiff(; chunksize=1), #
15-
AutoEnzyme(Enzyme.Forward),
15+
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
1616
AutoZygote(),
1717
]
1818

@@ -24,7 +24,7 @@ linear_solver_candidates = (
2424
conditions_backend_candidates = (
2525
nothing, #
2626
AutoForwardDiff(; chunksize=1),
27-
AutoEnzyme(Enzyme.Forward),
27+
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
2828
);
2929

3030
x_candidates = (

0 commit comments

Comments
 (0)