Skip to content

Commit 19f68d9

Browse files
authored
Merge pull request #10 from gdalle/dev-giom
Prepare first release
2 parents 4208a61 + 04919ad commit 19f68d9

15 files changed

+310
-107
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ version = "0.1.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
89
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
911

1012
[compat]
1113
ChainRulesCore = "1.14"
14+
Krylov = "0.8.1"
1215
LinearOperators = "2.2.3"
1316
julia = "1.7"

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
[![Build Status](https://github.com/gdalle/ImplicitDifferentiation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/gdalle/ImplicitDifferentiation.jl/actions/workflows/CI.yml?query=branch%3Amain)
66
[![Coverage](https://codecov.io/gh/gdalle/ImplicitDifferentiation.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/gdalle/ImplicitDifferentiation.jl)
77
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
8+
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
89

9-
Automatic differentiation of implicit functions. See the [documentation](https://gdalle.github.io/ImplicitDifferentiation.jl/dev) for details.
10+
Automatic differentiation of implicit functions.
11+
See the [documentation](https://gdalle.github.io/ImplicitDifferentiation.jl/dev) for details.
1012

11-
> This package is in a very early development stage, so use it at your own risk!
13+
> This package is in a very early development stage, so use it at your own risk!

docs/Manifest.toml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ version = "3.43.0"
123123
deps = ["Artifacts", "Libdl"]
124124
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
125125

126+
[[deps.ComponentArrays]]
127+
deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "Requires"]
128+
git-tree-sha1 = "243d8b8afc829a6707bbb1cd00da868703c2ef42"
129+
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
130+
version = "0.11.15"
131+
126132
[[deps.DataAPI]]
127133
git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8"
128134
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
@@ -192,12 +198,6 @@ version = "0.27.15"
192198
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
193199
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
194200

195-
[[deps.ExactOptimalTransport]]
196-
deps = ["Distances", "Distributions", "LinearAlgebra", "MathOptInterface", "PDMats", "QuadGK", "SparseArrays", "StatsBase"]
197-
git-tree-sha1 = "fb24e2b311008117679935b4e3d43cf285359153"
198-
uuid = "24df6009-d856-477c-ac5c-91f668376b31"
199-
version = "0.1.1"
200-
201201
[[deps.ExprTools]]
202202
git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d"
203203
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -256,7 +256,7 @@ uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
256256
version = "0.1.1"
257257

258258
[[deps.ImplicitDifferentiation]]
259-
deps = ["ChainRulesCore", "LinearOperators"]
259+
deps = ["ChainRulesCore", "LinearOperators", "SparseArrays"]
260260
path = ".."
261261
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
262262
version = "0.1.0"
@@ -455,9 +455,9 @@ version = "7.8.2"
455455

456456
[[deps.NNlib]]
457457
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
458-
git-tree-sha1 = "a59a614b8b4ea6dc1dcec8c6514e251f13ccbe10"
458+
git-tree-sha1 = "3a8dfd0cfb5bb3b82d09949e14423409b9334acb"
459459
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
460-
version = "0.8.4"
460+
version = "0.7.34"
461461

462462
[[deps.NaNMath]]
463463
git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe"
@@ -494,10 +494,10 @@ uuid = "429524aa-4258-5aef-a3af-852621145aeb"
494494
version = "1.6.2"
495495

496496
[[deps.OptimalTransport]]
497-
deps = ["ExactOptimalTransport", "IterativeSolvers", "LinearAlgebra", "LogExpFunctions", "NNlib", "Reexport"]
498-
git-tree-sha1 = "79ba1dab46dfc7b677278ebe892a431788da86a9"
497+
deps = ["Distances", "Distributions", "IterativeSolvers", "LinearAlgebra", "LogExpFunctions", "MathOptInterface", "NNlib", "PDMats", "QuadGK", "SparseArrays", "StatsBase"]
498+
git-tree-sha1 = "d85a74ab73bb0f4ccec7f713a352ba4c11daf750"
499499
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
500-
version = "0.3.19"
500+
version = "0.3.15"
501501

502502
[[deps.OrderedCollections]]
503503
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
[deps]
22
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
33
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
45
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
56
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
7+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
68
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
79
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
810
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"

docs/src/index.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ CurrentModule = ImplicitDifferentiation
1111
To install it, open a Julia Pkg REPL and run:
1212
```julia
1313
pkg> add "https://github.com/gdalle/ImplicitDifferentiation.jl"
14-
```
14+
```
15+
16+
## Related packages
17+
18+
- [DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl): differentiation of convex optimization problems

src/ImplicitDifferentiation.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
module ImplicitDifferentiation
22

33
using ChainRulesCore
4+
using Krylov
45
using LinearOperators
6+
using SparseArrays
57

68
include("implicit_function.jl")
9+
include("simplex.jl")
710

811
export ImplicitFunction
12+
export simplex_projection
913

1014
end

src/implicit_function.jl

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,34 @@ If `x ∈ ℝⁿ`, `y ∈ ℝᵐ` and `F(x,y) ∈ ℝᶜ`, this amounts to solvi
1414
- `conditions::C`: callable of the form `(x,y) -> F(x,y)`
1515
- `linear_solver::L`: callable of the form `(A,b) -> u` such that `A * u = b`
1616
"""
17-
Base.@kwdef struct ImplicitFunction{F,C,L}
17+
struct ImplicitFunction{F,C,L}
1818
forward::F
1919
conditions::C
2020
linear_solver::L
2121
end
2222

23+
"""
24+
ImplicitFunction(forward, conditions)
25+
26+
Construct an `ImplicitFunction` with `Krylov.gmres` as the default linear solver.
27+
28+
# See also
29+
- [`ImplicitFunction{F,C,L}`](@ref)
30+
"""
31+
function ImplicitFunction(forward::F, conditions::C) where {F,C}
32+
return ImplicitFunction(forward, conditions, Krylov.gmres)
33+
end
34+
35+
struct SolverFailureException <: Exception
36+
msg::String
37+
end
38+
2339
"""
2440
implicit(x)
2541
2642
Make [`ImplicitFunction`](@ref) callable by applying `implicit.forward`.
2743
"""
28-
(implicit::ImplicitFunction)(x::AbstractVector{<:Real}) = implicit.forward(x)
44+
(implicit::ImplicitFunction)(x) = implicit.forward(x)
2945

3046
"""
3147
frule(rc, (_, dx), implicit, x)
@@ -34,30 +50,33 @@ Custom forward rule for [`ImplicitFunction`](@ref).
3450
3551
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
3652
"""
37-
function ChainRulesCore.frule(rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractVector{R}) where {R<:Real}
38-
forward = implicit.forward
39-
conditions = implicit.conditions
40-
linear_solver = implicit.linear_solver
53+
function ChainRulesCore.frule(
54+
rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractVector
55+
)
56+
(; forward, conditions, linear_solver) = implicit
4157

4258
y = forward(x)
59+
n, m = length(x), length(y)
4360

44-
F₁(x̃) = conditions(x̃, y)
45-
F₂(ỹ) = -conditions(x, ỹ)
61+
conditions_x(x̃) = conditions(x̃, y)
62+
conditions_y(ỹ) = -conditions(x, ỹ)
4663

47-
pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), F₂, y)[2]
48-
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), F₁, x)[2]
64+
pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2]
65+
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2]
4966

50-
n, m, c = length(x), length(y), length(y)
5167
mul_A!(res, v) = res .= pushforward_A(v)
5268
mul_B!(res, v) = res .= pushforward_B(v)
5369

54-
A = LinearOperator(R, c, m, false, false, mul_A!)
55-
B = LinearOperator(R, c, m, false, false, mul_B!)
70+
A = LinearOperator(Float64, m, m, false, false, mul_A!)
71+
B = LinearOperator(Float64, m, n, false, false, mul_B!)
5672

57-
b = B * unthunk(dx)
58-
dy, stats = linear_solver(A, b)
59-
60-
return y, dy
73+
dx_vec = Vector(unthunk(dx))
74+
b = B * dx_vec
75+
dy_vec, stats = linear_solver(A, b)
76+
if !stats.solved
77+
throw(SolverFailureException("The linear solver failed to converge"))
78+
end
79+
return y, dy_vec
6180
end
6281

6382
"""
@@ -67,30 +86,32 @@ Custom reverse rule for [`ImplicitFunction`](@ref).
6786
6887
We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`.
6988
"""
70-
function ChainRulesCore.rrule(rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector{R}) where {R<:Real}
71-
forward = implicit.forward
72-
conditions = implicit.conditions
73-
linear_solver = implicit.linear_solver
89+
function ChainRulesCore.rrule(rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector)
90+
(; forward, conditions, linear_solver) = implicit
7491

7592
y = forward(x)
93+
n, m = length(x), length(y)
7694

77-
F₁(x̃) = conditions(x̃, y)
78-
F₂(ỹ) = -conditions(x, ỹ)
95+
conditions_x(x̃) = conditions(x̃, y)
96+
conditions_y(ỹ) = -conditions(x, ỹ)
7997

80-
pullback_Aᵀ = last rrule_via_ad(rc, F₂, y)[2]
81-
pullback_Bᵀ = last rrule_via_ad(rc, F₁, x)[2]
98+
pullback_Aᵀ = last rrule_via_ad(rc, conditions_y, y)[2]
99+
pullback_Bᵀ = last rrule_via_ad(rc, conditions_x, x)[2]
82100

83-
n, m, c = length(x), length(y), length(y)
84101
mul_Aᵀ!(res, v) = res .= pullback_Aᵀ(v)
85102
mul_Bᵀ!(res, v) = res .= pullback_Bᵀ(v)
86103

87-
Aᵀ = LinearOperator(R, m, c, false, false, mul_Aᵀ!)
88-
Bᵀ = LinearOperator(R, n, c, false, false, mul_Bᵀ!)
104+
Aᵀ = LinearOperator(Float64, m, m, false, false, mul_Aᵀ!)
105+
Bᵀ = LinearOperator(Float64, n, m, false, false, mul_Bᵀ!)
89106

90107
function implicit_pullback(dy)
91-
u, stats = linear_solver(Aᵀ, unthunk(dy))
92-
dx = Bᵀ * u
93-
return (NoTangent(), dx)
108+
dy_vec = Vector(unthunk(dy))
109+
u, stats = linear_solver(Aᵀ, dy_vec)
110+
if !stats.solved
111+
throw(SolverFailureException("The linear solver failed to converge"))
112+
end
113+
dx_vec = Bᵀ * u
114+
return (NoTangent(), dx_vec)
94115
end
95116

96117
return y, implicit_pullback

src/simplex.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
simplex_projection_and_support(z)
3+
4+
Compute the Euclidean projection onto the probability simplex and the set of indices where it is nonzero.
5+
6+
See <https://arxiv.org/abs/1602.02068> for details.
7+
"""
8+
function simplex_projection_and_support(z::AbstractVector{R}) where {R<:Real}
9+
d = length(z)
10+
z_sorted = sort(z; rev=true)
11+
z_sorted_cumsum = cumsum(z_sorted)
12+
k = maximum(j for j in 1:d if (1 + j * z_sorted[j]) > z_sorted_cumsum[j])
13+
τ = (z_sorted_cumsum[k] - 1) / k
14+
p = Vector{R}(undef, d)
15+
s = Vector{Int}(undef, d)
16+
for i in 1:d
17+
p[i] = max(z[i] - τ, zero(R))
18+
s[i] = Int(!iszero(p[i]))
19+
end
20+
return p, s
21+
end;
22+
23+
"""
24+
simplex_projection(z)
25+
26+
Compute the Euclidean projection onto the probability simplex.
27+
"""
28+
function simplex_projection(z::AbstractVector{<:Real})
29+
p, _ = simplex_projection_and_support(z)
30+
return p
31+
end;
32+
33+
"""
34+
rrule(::typeof(simplex_projection), z)
35+
36+
Custom reverse rule for [`simplex_projection`](@ref) which bypasses the sorting step.
37+
38+
See <https://arxiv.org/abs/1602.02068> for details.
39+
"""
40+
function ChainRulesCore.rrule(::typeof(simplex_projection), z::AbstractVector{<:Real})
41+
p, s = simplex_projection_and_support(z)
42+
S = sum(s)
43+
function simplex_projection_pullback(dp)
44+
vjp = s .* (dp .- (dp's) / S)
45+
return (NoTangent(), vjp)
46+
end
47+
return p, simplex_projection_pullback
48+
end;

test/1_unconstrained_optimization.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ F(x, \hat{y}(x)) = 0 \quad \text{with} \quad F(x,y) = \nabla_2 f(x, y) = 0
1313
=#
1414

1515
using ImplicitDifferentiation
16-
using Krylov: gmres
1716
using Optim: optimize, minimizer, LBFGS
1817
using Zygote
1918

@@ -46,7 +45,7 @@ conditions(x, y) = 2(y - x);
4645

4746
# We now have all the ingredients to construct our implicit function.
4847

49-
implicit = ImplicitFunction(; forward=forward, conditions=conditions, linear_solver=gmres);
48+
implicit = ImplicitFunction(forward, conditions);
5049

5150
# ## Testing
5251

@@ -65,6 +64,6 @@ Zygote.jacobian(implicit, x)[1]
6564
# The following tests are not included in the docs. #src
6665

6766
@testset verbose = true "ChainRules" begin #src
68-
test_frule(implicit, x) #src
69-
test_rrule(implicit, x) #src
67+
test_frule(implicit, x; check_inferred=false) #src
68+
test_rrule(implicit, x; check_inferred=false) #src
7069
end #src

0 commit comments

Comments
 (0)