
Commit 26666c9

Merge pull request #31 from gdalle/dev
Enable higher-order derivatives
2 parents 1581d2e + 005170d · commit 26666c9

4 files changed: +57 −36 lines

src/implicit_function.jl

Lines changed: 2 additions & 4 deletions
@@ -57,11 +57,10 @@ Keyword arguments are given to both `implicit.forward` and `implicit.conditions`
 function ChainRulesCore.frule(
     rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
 ) where {R<:Real}
-    forward = implicit.forward
     conditions = implicit.conditions
     linear_solver = implicit.linear_solver

-    y = forward(x; kwargs...)
+    y = implicit(x; kwargs...)

     conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
     conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)

@@ -98,11 +97,10 @@ Keyword arguments are given to both `implicit.forward` and `implicit.conditions`
 function ChainRulesCore.rrule(
     rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
 ) where {R<:Real}
-    forward = implicit.forward
     conditions = implicit.conditions
     linear_solver = implicit.linear_solver

-    y = forward(x; kwargs...)
+    y = implicit(x; kwargs...)

     conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
     conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)
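For context: routing the primal computation through `implicit(x; kwargs...)` instead of the raw `implicit.forward(x; kwargs...)` means that when a rule is differentiated again (e.g. by nested ForwardDiff.jl calls), the inner evaluation can dispatch back onto the wrapper's custom rule rather than onto a solver that may reject dual numbers. Below is a minimal sketch of that nesting pattern; it is illustrative only, and the names `forward_fun`, `conditions_fun`, `direct_solver`, `imp`, `deriv1`, `deriv2` are hypothetical, mirroring the recipe in `test/1_unconstrained_optimization.jl` further down.

```julia
# Illustrative sketch (hypothetical identifiers), following the same recipe as the test file below.
using ImplicitDifferentiation
using ForwardDiff
using ForwardDiffChainRules

forward_fun(x) = copy(x)                    # toy solver: returns y = x
conditions_fun(x, y) = y .- x               # optimality conditions, zero at y = x
direct_solver(A, b) = (Matrix(A) \ b, (solved=true,))  # dense solve, accepts dual numbers

imp = ImplicitFunction(forward_fun, conditions_fun, direct_solver)
@ForwardDiff_frule (f::typeof(imp))(x::AbstractArray{<:ForwardDiff.Dual}; kwargs...)

x, h = rand(3), rand(3)
deriv1(x, h) = ForwardDiff.derivative(t -> imp(x .+ t .* h), 0)        # first-order directional derivative
deriv2(x, h) = ForwardDiff.derivative(t -> deriv1(x .+ t .* h, h), 0)  # second order via nested ForwardDiff calls
deriv2(x, h)  # ≈ zero(x) because the map x ↦ x is linear
```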

test/1_unconstrained_optimization.jl

Lines changed: 53 additions & 28 deletions
@@ -9,9 +9,20 @@ The optimality conditions are given by gradient stationarity:
 ```math
 F(x, \hat{y}(x)) = 0 \quad \text{with} \quad F(x,y) = \nabla_2 f(x, y) = 0
 ```
-
 =#

+using ChainRulesTestUtils #src
+using ForwardDiff
+using ForwardDiffChainRules
+using ImplicitDifferentiation
+using LinearAlgebra #src
+using Optim
+using Random
+using Test #src
+using Zygote
+
+Random.seed!(63);
+
 # ## Implicit function wrapper

 #=

@@ -22,8 +33,6 @@ f(x, y) = \lVert y - x \rVert^2
 In this case, the optimization algorithm is very simple (the identity function does the job), but still we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) to show that it doesn't change the result.
 =#

-using Optim
-
 function dumb_identity(x::AbstractArray{Float64})
     f(y) = sum(abs2, y - x)
     y0 = zero(x)

@@ -40,72 +49,88 @@ zero_gradient(x, y) = 2(y - x);

 # We now have all the ingredients to construct our implicit function.

-using ImplicitDifferentiation
-
 implicit = ImplicitFunction(dumb_identity, zero_gradient);

 # Time to test!

-using Random
-Random.seed!(63)
-
 x = rand(3, 2)

 # Let's start by taking a look at the forward pass, which should be the identity function.

 implicit(x)

-# ## Autodiff with Zygote.jl
+# ## Why bother?

-using Zygote
+# It is important to understand why implicit differentiation is necessary here. Indeed, our optimization solver alone doesn't support autodiff with ForwardDiff.jl (due to type constraints)
+
+try
+    ForwardDiff.jacobian(dumb_identity, x)
+catch e
+    e
+end

-# If we use an autodiff package compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), such as [Zygote.jl](https://github.com/FluxML/Zygote.jl), differentiation works out of the box.
+# ... nor is it compatible with Zygote.jl (due to unsupported `try/catch` statements).
+
+try
+    Zygote.jacobian(dumb_identity, x)[1]
+catch e
+    e
+end
+
+# ## Autodiff with Zygote.jl
+
+# If we use an autodiff package compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), such as [Zygote.jl](https://github.com/FluxML/Zygote.jl), implicit differentiation works out of the box.

 Zygote.jacobian(implicit, x)[1]

 # As expected, we recover the identity matrix as Jacobian. Strictly speaking, the Jacobian should be a 4D tensor, but it is flattened into a 2D matrix.

 # ## Autodiff with ForwardDiff.jl

-using ForwardDiff
-
 # If we want to use [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) instead, we run into a problem: custom chain rules are not directly translated into dual number dispatch. Luckily, [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) provides us with a workaround. All we need to do is to apply the following macro:

-using ForwardDiffChainRules
-
 @ForwardDiff_frule (f::typeof(implicit))(x::AbstractArray{<:ForwardDiff.Dual}; kwargs...)

 # And then things work like a charm!

 ForwardDiff.jacobian(implicit, x)

-# ## Why did we bother?
+# ## Higher order differentiation

-# It is important to understand that implicit differentiation was necessary here. Indeed our solver alone doesn't support autodiff with ForwardDiff.jl (due to type constraints)
+h = rand(size(x));

-try
-    ForwardDiff.jacobian(dumb_identity, x)
-catch e
-    e
-end
+# Assuming we need second-order derivatives, nesting calls to Zygote.jl is generally a bad idea. We can, however, nest calls to ForwardDiff.jl.

-# ... nor was it compatible with Zygote.jl (due to unsupported `try/catch` statements).
+D(x, h) = ForwardDiff.derivative(t -> implicit(x .+ t .* h), 0)
+DD(x, h1, h2) = ForwardDiff.derivative(t -> D(x .+ t .* h2, h1), 0);
+
+#-

 try
-    Zygote.jacobian(dumb_identity, x)[1]
+    DD(x, h, h) # fails
 catch e
     e
 end

-# The following tests are not included in the docs. #src
+# The only requirement is to switch to a linear solver that is compatible with dual numbers (which the default `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) is not).

-using ChainRulesTestUtils #src
-using LinearAlgebra #src
-using Test #src
+linear_solver2(A, b) = (Matrix(A) \ b, (solved=true,))
+implicit2 = ImplicitFunction(dumb_identity, zero_gradient, linear_solver2);
+@ForwardDiff_frule (f::typeof(implicit2))(x::AbstractArray{<:ForwardDiff.Dual}; kwargs...)
+
+D2(x, h) = ForwardDiff.derivative(t -> implicit2(x .+ t .* h), 0)
+DD2(x, h1, h2) = ForwardDiff.derivative(t -> D2(x .+ t .* h2, h1), 0);
+
+#-
+
+DD2(x, h, h)
+
+# The following tests are not included in the docs. #src

 @testset verbose = true "ForwardDiff.jl" begin #src
     @test_throws MethodError ForwardDiff.jacobian(dumb_identity, x) #src
     @test ForwardDiff.jacobian(implicit, x) == I #src
+    @test all(DD2(x, h, h) .≈ 0) #src
 end #src

 @testset verbose = true "Zygote.jl" begin #src
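A note on the solver swap above: judging by the `linear_solver2` definition, the linear solver is called as `linear_solver(A, b)` and must return a `(solution, stats)` pair, like the default `gmres` from Krylov.jl. The sketch below is a hypothetical variant (not part of this commit; `hybrid_solver` and the element-type check are assumptions) that keeps the Krylov path for plain floating-point problems and only densifies when dual numbers are involved.

```julia
using ForwardDiff: Dual
using Krylov: gmres

# Hypothetical solver obeying the same (A, b) -> (solution, stats) convention as linear_solver2.
function hybrid_solver(A, b)
    if eltype(b) <: Dual
        return Matrix(A) \ b, (solved=true,)  # dense solve is generic over dual numbers
    else
        return gmres(A, b)                    # Krylov.jl's gmres returns a (solution, stats) tuple
    end
end

# implicit3 = ImplicitFunction(dumb_identity, zero_gradient, hybrid_solver)
```

For the small test problem above, the plain dense solve used by `linear_solver2` is the simpler choice.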

test/2_sparse_linear_regression.jl

Lines changed: 1 addition & 2 deletions
@@ -20,6 +20,7 @@ using MathOptInterface
 using MathOptSetDistances
 using Random
 using SCS
+using Test #src
 using Zygote

 Random.seed!(63);

@@ -111,8 +112,6 @@ sum(abs, J - J_ref) / prod(size(J))

 # The following tests are not included in the docs. #src

-using Test #src
-
 @testset verbose = true "FiniteDifferences.jl" begin #src
     @test sum(abs, J - J_ref) / prod(size(J)) <= 1e-2 #src
 end #src

test/3_optimal_transport.jl

Lines changed: 1 addition & 2 deletions
@@ -9,6 +9,7 @@ using FiniteDifferences
 using ImplicitDifferentiation
 using LinearAlgebra
 using Random
+using Test #src
 using Zygote

 Random.seed!(63);

@@ -160,8 +161,6 @@ sum(abs, J2 - J_ref) / prod(size(J_ref))

 # The following tests are not included in the docs. #src

-using Test #src
-
 @testset verbose = true "FiniteDifferences.jl" begin #src
     @test u1 == u2 #src
     @test all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) #src
