In this case, the optimization algorithm is very simple (the identity function does the job), but still we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) to show that it doesn't change the result.
# Let's start by taking a look at the forward pass, which should be the identity function.
implicit(x)

-# ## Autodiff with Zygote.jl
+# ## Why bother?

-using Zygote
+# It is important to understand why implicit differentiation is necessary here. Indeed, our optimization solver alone doesn't support autodiff with ForwardDiff.jl (due to type constraints)
+
+try
+    ForwardDiff.jacobian(dumb_identity, x)
+catch e
+    e
+end
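The "type constraints" failure mode can be reproduced in isolation. The following is a minimal sketch, assuming a hypothetical function `typed_identity` (not part of the original file) that, like many solvers, writes its iterate into a `Float64` buffer. ForwardDiff.jl passes dual numbers through the function, and these cannot be converted to `Float64`.

```julia
using ForwardDiff

# Hypothetical stand-in for a solver that stores its state in a Float64 buffer.
function typed_identity(x::AbstractVector)
    buffer = zeros(Float64, length(x))
    buffer .= x  # fails when x holds ForwardDiff.Dual numbers
    return buffer
end

# Same try/catch pattern as in the example: keep the error instead of throwing it.
result = try
    ForwardDiff.jacobian(typed_identity, [1.0, 2.0])
catch e
    e
end
```

Here `result` holds the caught exception rather than a Jacobian, because no conversion from a dual number to `Float64` is defined.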

-# If we use an autodiff package compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), such as [Zygote.jl](https://github.com/FluxML/Zygote.jl), differentiation works out of the box.
+# ... nor is it compatible with Zygote.jl (due to unsupported `try/catch` statements).
+
+try
+    Zygote.jacobian(dumb_identity, x)[1]
+catch e
+    e
+end
+
+# ## Autodiff with Zygote.jl
+
+# If we use an autodiff package compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), such as [Zygote.jl](https://github.com/FluxML/Zygote.jl), implicit differentiation works out of the box.

Zygote.jacobian(implicit, x)[1]
# As expected, we recover the identity matrix as Jacobian. Strictly speaking, the Jacobian should be a 4D tensor, but it is flattened into a 2D matrix.
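
The flattening can be illustrated on a function that does not involve the solver at all. This is a sketch under the assumption of a hypothetical elementwise function `scale` and a 2×3 input. `Zygote.jacobian` returns a matrix with one row per output entry and one column per input entry, which can be reshaped back into a 4D tensor.

```julia
using LinearAlgebra
using Zygote

scale(x) = 2 .* x  # hypothetical matrix-to-matrix function
x = rand(2, 3)

# Flattened Jacobian: length(scale(x)) rows, length(x) columns.
J = Zygote.jacobian(scale, x)[1]

# 4D view: J4[i, j, k, l] is the derivative of output (i, j) w.r.t. input (k, l).
J4 = reshape(J, size(x)..., size(x)...)
```

Both arrays follow Julia's column-major ordering, so the reshape only changes the indexing, not the entries.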
# ## Autodiff with ForwardDiff.jl

-using ForwardDiff
-
# If we want to use [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) instead, we run into a problem: custom chain rules are not directly translated into dual number dispatch. Luckily, [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) provides us with a workaround. All we need to do is to apply the following macro:
# It is important to understand that implicit differentiation was necessary here. Indeed our solver alone doesn't support autodiff with ForwardDiff.jl (due to type constraints)
+h = rand(size(x)...);

-try
-    ForwardDiff.jacobian(dumb_identity, x)
-catch e
-    e
-end
+# Assuming we need second-order derivatives, nesting calls to Zygote.jl is generally a bad idea. We can, however, nest calls to ForwardDiff.jl.

-# ... nor was it compatible with Zygote.jl (due to unsupported `try/catch` statements).
+D(x, h) = ForwardDiff.derivative(t -> implicit(x .+ t .* h), 0)
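
The nesting pattern can be checked on a function with a known answer. The sketch below assumes a hypothetical elementwise square `f` in place of `implicit`, and a hypothetical second-order helper `D2` that is not in the original file. For `f(x) = x .^ 2`, the first directional derivative along `h` is `2 .* x .* h` and the second is `2 .* h .^ 2`.

```julia
using ForwardDiff

f(x) = x .^ 2  # hypothetical stand-in for `implicit`
x = [1.0 2.0; 3.0 4.0]
h = ones(size(x)...)

# First directional derivative of f at x along h.
D(x, h) = ForwardDiff.derivative(t -> f(x .+ t .* h), 0.0)

# Second directional derivative, obtained by nesting ForwardDiff.jl calls.
D2(x, h) = ForwardDiff.derivative(t -> D(x .+ t .* h, h), 0.0)

first_order = D(x, h)
second_order = D2(x, h)
```

The inner and outer calls each introduce their own dual numbers, which is why the function being differentiated must accept nested duals throughout.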
# The following tests are not included in the docs. #src
+# The only requirement is to switch to a linear solver that is compatible with dual numbers (which the default `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) is not).

-using ChainRulesTestUtils #src
-using LinearAlgebra #src
-using Test #src
+linear_solver2(A, b) = (Matrix(A) \ b, (solved=true,))
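
To see that a dense backslash solve tolerates dual numbers, one can differentiate through it directly. This is a small sketch with a hypothetical 2×2 system; only `linear_solver2` comes from the diff above. The Jacobian of the solution with respect to the right-hand side should equal the inverse of `A`.

```julia
using ForwardDiff

# Dense direct solve, returning the solution and a stats tuple as in the diff.
linear_solver2(A, b) = (Matrix(A) \ b, (solved=true,))

A = [2.0 1.0; 1.0 3.0]  # hypothetical system matrix
b = [1.0, 2.0]

y, stats = linear_solver2(A, b)

# ForwardDiff.jl feeds dual numbers through the right-hand side.
J = ForwardDiff.jacobian(rhs -> linear_solver2(A, rhs)[1], b)
```

Converting to a dense `Matrix` trades the memory savings of a matrix-free operator for generic element types, which is acceptable for small problems like this one.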