Skip to content

Commit bb61c5f

Browse files
committed
Faster Nested AD
1 parent 29e971e commit bb61c5f

File tree

3 files changed

+81
-10
lines changed

3 files changed

+81
-10
lines changed

Manifest.toml

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
497497

498498
[[deps.Lux]]
499499
deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"]
500-
git-tree-sha1 = "d7f49df9abfbb372fcbde5f41e547aa3679e9793"
500+
git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43"
501501
repo-rev = "ap/nested_ad"
502502
repo-url = "https://github.com/LuxDL/Lux.jl.git"
503503
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -573,12 +573,13 @@ version = "0.1.20"
573573
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
574574

575575
[[deps.LuxLib]]
576-
deps = ["ChainRulesCore", "FastClosures", "KernelAbstractions", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics"]
577-
git-tree-sha1 = "b1f81a8aa8313c1f1b4cbfb18733db17c023427e"
576+
deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
577+
git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6"
578578
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
579-
version = "0.3.14"
579+
version = "0.3.15"
580580

581581
[deps.LuxLib.extensions]
582+
LuxLibAMDGPUExt = "AMDGPU"
582583
LuxLibForwardDiffExt = "ForwardDiff"
583584
LuxLibReverseDiffExt = "ReverseDiff"
584585
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
@@ -684,6 +685,12 @@ git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
684685
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
685686
version = "1.6.3"
686687

688+
[[deps.PackageExtensionCompat]]
689+
git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518"
690+
uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930"
691+
version = "1.0.2"
692+
weakdeps = ["Requires", "TOML"]
693+
687694
[[deps.Parameters]]
688695
deps = ["OrderedCollections", "UnPack"]
689696
git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe"
@@ -927,6 +934,24 @@ git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682"
927934
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
928935
version = "0.5.6"
929936

937+
[[deps.Strided]]
938+
deps = ["LinearAlgebra", "StridedViews", "TupleTools"]
939+
git-tree-sha1 = "40c69be0e1b72ee2f42923b7d1ff13e0b04e675c"
940+
uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
941+
version = "2.0.4"
942+
943+
[[deps.StridedViews]]
944+
deps = ["LinearAlgebra", "PackageExtensionCompat"]
945+
git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e"
946+
uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
947+
version = "0.2.2"
948+
949+
[deps.StridedViews.extensions]
950+
StridedViewsCUDAExt = "CUDA"
951+
952+
[deps.StridedViews.weakdeps]
953+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
954+
930955
[[deps.SuiteSparse]]
931956
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
932957
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
@@ -985,6 +1010,11 @@ git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1"
9851010
uuid = "781d530d-4396-4725-bb49-402e4bee1e77"
9861011
version = "1.4.0"
9871012

1013+
[[deps.TupleTools]]
1014+
git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd"
1015+
uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1016+
version = "1.5.0"
1017+
9881018
[[deps.UUIDs]]
9891019
deps = ["Random", "SHA"]
9901020
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2020
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
2121

2222
[weakdeps]
23+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2324
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2425
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2526
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2627

2728
[extensions]
2829
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
29-
DeepEquilibriumNetworksZygoteExt = "Zygote"
30+
DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]
3031

3132
[compat]
3233
ADTypes = "0.2.5, 1"
@@ -38,6 +39,7 @@ ConstructionBase = "1"
3839
DiffEqBase = "6.119"
3940
ExplicitImports = "1.4.1"
4041
FastClosures = "0.3"
42+
ForwardDiff = "0.10.36"
4143
Functors = "0.4.10"
4244
LinearSolve = "2.21.2"
4345
Lux = "0.5.37"
Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,58 @@
11
module DeepEquilibriumNetworksZygoteExt
22

33
using ADTypes: AutoZygote
4+
using ChainRulesCore: ChainRulesCore
5+
using DeepEquilibriumNetworks: DEQs
46
using FastClosures: @closure
7+
using ForwardDiff: ForwardDiff # This is a dependency of Zygote
8+
using Lux: Lux, StatefulLuxLayer
59
using Statistics: mean
610
using Zygote: Zygote
7-
using DeepEquilibriumNetworks: DEQs
811

9-
@inline __tupleify(u) = @closure x -> (u, x)
12+
const CRC = ChainRulesCore
13+
14+
@inline __tupleify(x) = @closure(u->(u, x))
15+
16+
## One day we will overload DI's APIs for Lux Layers and we can remove this
17+
## Main challenge with overloading Zygote.pullback is that we need to return the correct
18+
## tangent for the pullback to compute the correct gradient, which is quite hard. But
19+
## wrapping the overall vjp is not that hard.
20+
@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng)
21+
res, back = Zygote.pullback(model ∘ __tupleify(x), z)
22+
return only(back(DEQs.__gaussian_like(rng, res)))
23+
end
24+
25+
function CRC.rrule(
26+
::typeof(__compute_vector_jacobian_product), model::StatefulLuxLayer, ps, z, x, rng)
27+
res, back = Zygote.pullback(model ∘ __tupleify(x), z)
28+
ε = DEQs.__gaussian_like(rng, res)
29+
y = only(back(ε))
30+
∇internal_gradient_capture = Δ -> begin
31+
(Δ isa CRC.NoTangent || Δ isa CRC.ZeroTangent) &&
32+
return ntuple(Returns(CRC.NoTangent()), 6)
33+
34+
Δ_ = reshape(CRC.unthunk(Δ), size(z))
35+
36+
Tag = typeof(ForwardDiff.Tag(model, eltype(z)))
37+
partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_))
38+
z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials)
39+
40+
_, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps)
41+
∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε)
42+
43+
∂z = Lux.__partials(Tag, ∂z_duals, 1)
44+
∂x = Lux.__partials(Tag, ∂x_duals, 1)
45+
∂ps = Lux.__partials(Tag, ∂ps_duals, 1)
46+
47+
return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent()
48+
end
49+
return y, ∇internal_gradient_capture
50+
end
1051

1152
## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
1253
## FIXME: This will be broken in the new Lux release let's fix this
1354
function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
14-
res, back = Zygote.pullback(model ∘ __tupleify, z)
15-
vjp_z = only(back(DEQs.__gaussian_like(rng, res)))
16-
return mean(abs2, vjp_z)
55+
return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
1756
end
1857

1958
end

0 commit comments

Comments
 (0)