Skip to content

Commit c854b5d

Browse files
YingboMagdalle
andauthored
Improve type stability (#41)
* Improve type stability * Update ImplicitDifferentiationForwardDiffExt.jl * Add type inference tests and bump version to 0.4.1 --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent c6c8fc2 commit c854b5d

5 files changed

+57
-26
lines changed

CITATION.bib

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ @misc{ImplicitDifferentiation.jl
22
author = {Guillaume Dalle, Mohamed Tarek and contributors},
33
title = {ImplicitDifferentiation.jl},
44
url = {https://github.com/gdalle/ImplicitDifferentiation.jl},
5-
version = {v0.4.0},
5+
version = {v0.4.1},
66
year = {2023},
7-
month = {4}
7+
month = {5}
88
}

Project.toml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
@@ -10,6 +10,14 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1010
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
1111
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1212

13+
[weakdeps]
14+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
15+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
16+
17+
[extensions]
18+
ImplicitDifferentiationChainRulesExt = "ChainRulesCore"
19+
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
20+
1321
[compat]
1422
AbstractDifferentiation = "0.5"
1523
ChainRulesCore = "1.14"
@@ -19,10 +27,6 @@ LinearOperators = "2.2"
1927
Requires = "1.3"
2028
julia = "1.6"
2129

22-
[extensions]
23-
ImplicitDifferentiationChainRulesExt = "ChainRulesCore"
24-
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
25-
2630
[extras]
2731
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2832
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -43,7 +47,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4347

4448
[targets]
4549
test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "LinearAlgebra", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "Test", "Zygote"]
46-
47-
[weakdeps]
48-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
49-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

examples/0_basic.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
In this example, we demonstrate the basics of our package on a simple function that is not amenable to automatic differentiation.
55
=#
66

7+
using ChainRulesCore #src
8+
using ChainRulesTestUtils #src
79
using ForwardDiff
810
using ImplicitDifferentiation
11+
using JET #src
912
using LinearAlgebra
1013
using Random
1114
using Test #src
@@ -168,8 +171,16 @@ JJ = Diagonal(0.5 ./ sqrt.(vec(X))) #src
168171
@test ForwardDiff.jacobian(first implicit, X) JJ #src
169172
@test Zygote.jacobian(first implicit, X)[1] JJ #src
170173

171-
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 #src
174+
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities #src
172175
@testset verbose = true "ChainRulesTestUtils.jl" begin #src
173176
@test_skip test_rrule(implicit, x) #src
174177
@test_skip test_rrule(implicit, X) #src
175178
end #src
179+
180+
x_and_dx = [ForwardDiff.Dual(x[i], (0, 0)) for i in eachindex(x)] #src
181+
@inferred implicit(x_and_dx) #src
182+
183+
rc = Zygote.ZygoteRuleConfig() #src
184+
_, pullback = @inferred rrule(rc, implicit, x) #src
185+
dy, dz = zero(implicit(x)[1]), 0
186+
@inferred pullback((dy, dz))

ext/ImplicitDifferentiationChainRulesExt.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,35 @@ function ChainRulesCore.rrule(
2525
backend = ReverseRuleConfigBackend(rc)
2626
pbA = pullback_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
2727
pbB = pullback_function(backend, _x -> conditions(_x, y, z; kwargs...), x)
28-
Aᵀ_op = LinearOperator(R, m, m, false, false, PullbackMul!(pbA, size(y)))
29-
Bᵀ_op = LinearOperator(R, n, m, false, false, PullbackMul!(pbB, size(y)))
30-
31-
function implicit_pullback((dy, dz))
32-
dy_vec = convert(Vector{R}, vec(unthunk(dy)))
33-
dF_vec, stats = linear_solver(Aᵀ_op, dy_vec)
34-
check_solution(linear_solver, stats)
35-
dx_vec = -(Bᵀ_op * dF_vec)
36-
dx = reshape(dx_vec, size(x))
37-
return (NoTangent(), dx)
38-
end
28+
pbmA = PullbackMul!(pbA, size(y))
29+
pbmB = PullbackMul!(pbB, size(y))
30+
Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA)
31+
Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB)
32+
implicit_pullback = ImplicitPullback(Aᵀ_op, Bᵀ_op, linear_solver, x)
3933

4034
return (y, z), implicit_pullback
4135
end
4236

37+
struct ImplicitPullback{A,B,L,X}
38+
Aᵀ_op::A
39+
Bᵀ_op::B
40+
linear_solver::L
41+
x::X
42+
end
43+
44+
function (implicit_pullback::ImplicitPullback)((dy, dz))
45+
Aᵀ_op = implicit_pullback.Aᵀ_op
46+
Bᵀ_op = implicit_pullback.Bᵀ_op
47+
linear_solver = implicit_pullback.linear_solver
48+
x = implicit_pullback.x
49+
R = eltype(x)
50+
51+
dy_vec = convert(Vector{R}, vec(unthunk(dy)))
52+
dF_vec, stats = linear_solver(Aᵀ_op, dy_vec)
53+
check_solution(linear_solver, stats)
54+
dx_vec = -(Bᵀ_op * dF_vec)
55+
dx = reshape(dx_vec, size(x))
56+
return (NoTangent(), dx)
57+
end
58+
4359
end

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@ function (implicit::ImplicitFunction)(
3838
reshape(dₖy_vec, size(y))
3939
end
4040

41-
y_and_dy = map(eachindex(y)) do i
42-
Dual{T}(y[i], Partials(Tuple(dy[k][i] for k in 1:N)))
41+
y_and_dy = let y = y, dy = dy
42+
map(eachindex(y)) do i
43+
Dual{T}(y[i], Partials(ntuple(k -> dy[k][i], Val(N))))
44+
end
4345
end
4446

45-
z_and_dz = Dual{T}(z, Partials(Tuple(zero(z) for k in 1:N)))
47+
z_and_dz = let z = z
48+
Dual{T}(z, Partials(ntuple(_ -> zero(z), Val(N))))
49+
end
4650

4751
return y_and_dy, z_and_dz
4852
end

0 commit comments

Comments
 (0)