Skip to content

Commit 04919ad

Browse files
committed
Simplify interface and add more tests for first release
1 parent ce5834a commit 04919ad

13 files changed

+133
-36
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ version = "0.1.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
89
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
910
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1011

1112
[compat]
1213
ChainRulesCore = "1.14"
14+
Krylov = "0.8.1"
1315
LinearOperators = "2.2.3"
1416
julia = "1.7"

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
88
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
99

10-
Automatic differentiation of implicit functions. See the [documentation](https://gdalle.github.io/ImplicitDifferentiation.jl/dev) for details.
10+
Automatic differentiation of implicit functions.
11+
See the [documentation](https://gdalle.github.io/ImplicitDifferentiation.jl/dev) for details.
1112

1213
> This package is in a very early development stage, so use it at your own risk!

docs/src/index.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ CurrentModule = ImplicitDifferentiation
1111
To install it, open a Julia Pkg REPL and run:
1212
```julia
1313
pkg> add "https://github.com/gdalle/ImplicitDifferentiation.jl"
14-
```
14+
```
15+
16+
## Related packages
17+
18+
- [DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl): differentiation of convex optimization problems

src/ImplicitDifferentiation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ImplicitDifferentiation
22

33
using ChainRulesCore
4+
using Krylov
45
using LinearOperators
56
using SparseArrays
67

src/implicit_function.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,24 @@ If `x ∈ ℝⁿ`, `y ∈ ℝᵐ` and `F(x,y) ∈ ℝᶜ`, this amounts to solvi
1414
- `conditions::C`: callable of the form `(x,y) -> F(x,y)`
1515
- `linear_solver::L`: callable of the form `(A,b) -> u` such that `A * u = b`
1616
"""
17-
Base.@kwdef struct ImplicitFunction{F,C,L}
17+
struct ImplicitFunction{F,C,L}
1818
forward::F
1919
conditions::C
2020
linear_solver::L
2121
end
2222

23+
"""
24+
ImplicitFunction(forward, conditions)
25+
26+
Construct an `ImplicitFunction` with `Krylov.gmres` as the default linear solver.
27+
28+
# See also
29+
- [`ImplicitFunction{F,C,L}`](@ref)
30+
"""
31+
function ImplicitFunction(forward::F, conditions::C) where {F,C}
32+
return ImplicitFunction(forward, conditions, Krylov.gmres)
33+
end
34+
2335
struct SolverFailureException <: Exception
2436
msg::String
2537
end
@@ -29,7 +41,7 @@ end
2941
3042
Make [`ImplicitFunction`](@ref) callable by applying `implicit.forward`.
3143
"""
32-
(implicit::ImplicitFunction)(x) = first(implicit.forward(x))
44+
(implicit::ImplicitFunction)(x) = implicit.forward(x)
3345

3446
"""
3547
frule(rc, (_, dx), implicit, x)
@@ -43,11 +55,11 @@ function ChainRulesCore.frule(
4355
)
4456
(; forward, conditions, linear_solver) = implicit
4557

46-
y, useful_info = forward(x)
58+
y = forward(x)
4759
n, m = length(x), length(y)
4860

49-
conditions_x(x̃) = conditions(x̃, y, useful_info)
50-
conditions_y(ỹ) = -conditions(x, ỹ, useful_info)
61+
conditions_x(x̃) = conditions(x̃, y)
62+
conditions_y(ỹ) = -conditions(x, ỹ)
5163

5264
pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2]
5365
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2]
@@ -77,11 +89,11 @@ We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and settin
7789
function ChainRulesCore.rrule(rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector)
7890
(; forward, conditions, linear_solver) = implicit
7991

80-
y, useful_info = forward(x)
92+
y = forward(x)
8193
n, m = length(x), length(y)
8294

83-
conditions_x(x̃) = conditions(x̃, y, useful_info)
84-
conditions_y(ỹ) = -conditions(x, ỹ, useful_info)
95+
conditions_x(x̃) = conditions(x̃, y)
96+
conditions_y(ỹ) = -conditions(x, ỹ)
8597

8698
pullback_Aᵀ = last rrule_via_ad(rc, conditions_y, y)[2]
8799
pullback_Bᵀ = last rrule_via_ad(rc, conditions_x, x)[2]

src/simplex.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@ Compute the Euclidean projection onto the probability simplex and the set of ind
55
66
See <https://arxiv.org/abs/1602.02068> for details.
77
"""
8-
function simplex_projection_and_support(z::AbstractVector{<:Real})
8+
function simplex_projection_and_support(z::AbstractVector{R}) where {R<:Real}
99
d = length(z)
1010
z_sorted = sort(z; rev=true)
1111
z_sorted_cumsum = cumsum(z_sorted)
1212
k = maximum(j for j in 1:d if (1 + j * z_sorted[j]) > z_sorted_cumsum[j])
1313
τ = (z_sorted_cumsum[k] - 1) / k
14-
p = max.(z .- τ, 0)
15-
s = [Int(p[i] > eps()) for i in 1:d]
14+
p = Vector{R}(undef, d)
15+
s = Vector{Int}(undef, d)
16+
for i in 1:d
17+
p[i] = max(z[i] - τ, zero(R))
18+
s[i] = Int(!iszero(p[i]))
19+
end
1620
return p, s
1721
end;
1822

test/1_unconstrained_optimization.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ F(x, \hat{y}(x)) = 0 \quad \text{with} \quad F(x,y) = \nabla_2 f(x, y) = 0
1313
=#
1414

1515
using ImplicitDifferentiation
16-
using Krylov: gmres
1716
using Optim: optimize, minimizer, LBFGS
1817
using Zygote
1918

@@ -35,18 +34,18 @@ function forward(x)
3534
y0 = zero(x)
3635
res = optimize(f, y0, LBFGS(); autodiff=:forward)
3736
y = minimizer(res)
38-
return y, nothing
37+
return y
3938
end;
4039

4140
#=
4241
On the other hand, optimality conditions should be provided explicitly whenever possible, so as to avoid nesting automatic differentiation calls.
4342
=#
4443

45-
conditions(x, y, useful_info=nothing) = 2(y - x);
44+
conditions(x, y) = 2(y - x);
4645

4746
# We now have all the ingredients to construct our implicit function.
4847

49-
implicit = ImplicitFunction(; forward=forward, conditions=conditions, linear_solver=gmres);
48+
implicit = ImplicitFunction(forward, conditions);
5049

5150
# ## Testing
5251

test/2_constrained_optimization.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ F(x, \hat{y}(x)) = 0 \quad \text{with} \quad F(x,y) = \mathrm{proj}_{\mathcal{C}
1515
using ImplicitDifferentiation
1616
using Ipopt
1717
using JuMP
18-
using Krylov: gmres
1918
using Zygote
2019

2120
using ChainRulesTestUtils #src
@@ -44,12 +43,12 @@ function forward(x)
4443
@constraint(model, sum(y) == 1)
4544
@objective(model, Min, sum((y .- x) .^ 2))
4645
optimize!(model)
47-
return value.(y), nothing
46+
return value.(y)
4847
end;
4948

50-
conditions(x, y, useful_info=nothing) = simplex_projection(y - 0.1 * 2(y - x)) - y;
49+
conditions(x, y) = simplex_projection(y - 0.1 * 2(y - x)) - y;
5150

52-
implicit = ImplicitFunction(; forward=forward, conditions=conditions, linear_solver=gmres);
51+
implicit = ImplicitFunction(forward, conditions);
5352

5453
# ## Testing
5554

test/3_optimal_transport.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
In this example, we show how to differentiate through the solution of the entropy-regularized optimal transport problem.
55
=#
66

7+
using Distances
8+
using FiniteDiff
79
using ImplicitDifferentiation
810
using OptimalTransport
9-
using Distances
10-
using Krylov: gmres
1111
using Zygote
12-
using FiniteDiff
1312
using Test, LinearAlgebra #src
1413

1514
#=
@@ -84,17 +83,17 @@ function forward(C_vec)
8483
solver = OptimalTransport.build_solver(μ, ν, C, ε, SinkhornGibbs())
8584
OptimalTransport.solve!(solver)
8685
= solver.cache.u
87-
return, nothing
86+
return
8887
end
8988

90-
function conditions(C_vec, û, useful_info=nothing)
89+
function conditions(C_vec, û)
9190
C = reshape(C_vec, n, m)
9291
K = exp.(.-C ./ ε)
9392
= ν ./ (K' * û)
9493
return.- μ ./ (K * v̂)
9594
end
9695

97-
implicit = ImplicitFunction(; forward=forward, conditions=conditions, linear_solver=gmres);
96+
implicit = ImplicitFunction(forward, conditions);
9897

9998
# ## Testing
10099

test/4_struct.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ In this example, we demonstrate implicit differentiation through functions that
66

77
using ComponentArrays
88
using ImplicitDifferentiation
9-
using Krylov: gmres
109
using Zygote
1110

1211
using ChainRulesCore #src
@@ -19,13 +18,13 @@ using Test #src
1918
We replicate a componentwise square function with `NamedTuple`s, taking `a=(x,y)` as input and returning `b=(u,v)`.
2019
=#
2120

22-
forward(a::ComponentVector) = ComponentVector(u=a.x .^ 2, v=a.y .^ 2), nothing;
21+
forward(a::ComponentVector) = ComponentVector(u=a.x .^ 2, v=a.y .^ 2)
2322

24-
function conditions(a::ComponentVector, b::ComponentVector, useful_info=nothing)
23+
function conditions(a::ComponentVector, b::ComponentVector)
2524
return vcat(b.u .- a.x .^ 2, b.v .- a.y .^ 2)
2625
end
2726

28-
implicit = ImplicitFunction(; forward=forward, conditions=conditions, linear_solver=gmres);
27+
implicit = ImplicitFunction(forward, conditions);
2928

3029
#=
3130
In order to be able to call `Zygote.gradient`, we use `implicit` to define a convoluted version of the squared Euclidean norm, which takes a `ComponentVector` as input and returns a real number.

test/Manifest.toml

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.7.2"
3+
julia_version = "1.7.3"
44
manifest_format = "2.0"
55

6+
[[deps.AMD]]
7+
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
8+
git-tree-sha1 = "fc66ffc5cff568936649445f58a55b81eaf9592c"
9+
uuid = "14f7f29c-3bd6-536c-9a0b-7339e30b5a3e"
10+
version = "0.4.0"
11+
612
[[deps.ASL_jll]]
713
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
814
git-tree-sha1 = "6252039f98492252f9e47c312c8ffda0e3b9e78d"
@@ -90,6 +96,12 @@ git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1"
9096
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
9197
version = "0.1.2"
9298

99+
[[deps.CodeTracking]]
100+
deps = ["InteractiveUtils", "UUIDs"]
101+
git-tree-sha1 = "6d4fa04343a7fc9f9cb9cff9558929f3d2752717"
102+
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
103+
version = "1.0.9"
104+
93105
[[deps.CodecBzip2]]
94106
deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
95107
git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
@@ -184,9 +196,22 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
184196
version = "0.8.6"
185197

186198
[[deps.Downloads]]
187-
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
199+
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
188200
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
189201

202+
[[deps.ExprTools]]
203+
git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d"
204+
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
205+
version = "0.1.8"
206+
207+
[[deps.FastClosures]]
208+
git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef"
209+
uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
210+
version = "0.3.2"
211+
212+
[[deps.FileWatching]]
213+
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
214+
190215
[[deps.FillArrays]]
191216
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
192217
git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286"
@@ -266,6 +291,12 @@ git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c"
266291
uuid = "42fd0dbc-a981-5370-80f2-aaf504508153"
267292
version = "0.9.2"
268293

294+
[[deps.JET]]
295+
deps = ["InteractiveUtils", "JuliaInterpreter", "LoweredCodeUtils", "MacroTools", "Pkg", "Revise", "Test"]
296+
git-tree-sha1 = "8e78b0c297cfa6cefd579f87232c89bd6ed7a081"
297+
uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
298+
version = "0.5.16"
299+
269300
[[deps.JLLWrappers]]
270301
deps = ["Preferences"]
271302
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
@@ -290,12 +321,24 @@ git-tree-sha1 = "57c17a221a55f81890aabf00f478886859e25eaf"
290321
uuid = "4076af6c-e467-56ae-b986-b466b2749572"
291322
version = "0.21.5"
292323

324+
[[deps.JuliaInterpreter]]
325+
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
326+
git-tree-sha1 = "52617c41d2761cc05ed81fe779804d3b7f14fff7"
327+
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
328+
version = "0.9.13"
329+
293330
[[deps.Krylov]]
294331
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
295332
git-tree-sha1 = "82f5afb342a5624dc4651981584a841f6088166b"
296333
uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
297334
version = "0.8.0"
298335

336+
[[deps.LDLFactorizations]]
337+
deps = ["AMD", "LinearAlgebra", "SparseArrays", "Test"]
338+
git-tree-sha1 = "736e01b9b2d443c4e3351aebe551b8a374ab9c05"
339+
uuid = "40e66cde-538c-5869-a4ad-c39174c6795b"
340+
version = "0.8.2"
341+
299342
[[deps.LibCURL]]
300343
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
301344
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
@@ -325,6 +368,12 @@ version = "7.1.1"
325368
deps = ["Libdl", "libblastrampoline_jll"]
326369
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
327370

371+
[[deps.LinearOperators]]
372+
deps = ["FastClosures", "LDLFactorizations", "LinearAlgebra", "Printf", "SparseArrays", "TimerOutputs"]
373+
git-tree-sha1 = "b404faa9b85e62c0eeec7a600d5b4316c58215ed"
374+
uuid = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
375+
version = "2.3.2"
376+
328377
[[deps.LogExpFunctions]]
329378
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
330379
git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571"
@@ -334,6 +383,12 @@ version = "0.3.12"
334383
[[deps.Logging]]
335384
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
336385

386+
[[deps.LoweredCodeUtils]]
387+
deps = ["JuliaInterpreter"]
388+
git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7"
389+
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
390+
version = "2.2.2"
391+
337392
[[deps.METIS_jll]]
338393
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
339394
git-tree-sha1 = "1d31872bb9c5e7ec1f618e8c4a56c8b0d9bddc7e"
@@ -531,6 +586,12 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
531586
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
532587
version = "1.3.0"
533588

589+
[[deps.Revise]]
590+
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
591+
git-tree-sha1 = "4d4239e93531ac3e7ca7e339f15978d0b5149d03"
592+
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
593+
version = "3.3.3"
594+
534595
[[deps.Richardson]]
535596
deps = ["LinearAlgebra"]
536597
git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949"
@@ -628,6 +689,12 @@ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
628689
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
629690
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
630691

692+
[[deps.TimerOutputs]]
693+
deps = ["ExprTools", "Printf"]
694+
git-tree-sha1 = "7638550aaea1c9a1e86817a231ef0faa9aca79bd"
695+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
696+
version = "0.5.19"
697+
631698
[[deps.TranscodingStreams]]
632699
deps = ["Random", "Test"]
633700
git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c"

0 commit comments

Comments
 (0)