From 2d709661818af4d1750fa56f02facdb53bd673db Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 22 Aug 2021 10:15:14 +0100 Subject: [PATCH 01/36] Upgrade FiniteDifferences --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 245761cef..267c60845 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] BenchmarkTools = "0.5" Compat = "2, 3" -FiniteDifferences = "0.10" +FiniteDifferences = "0.12" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" julia = "1" From f5308689d0d6263d0b2b84238de8d7ad9bc0a387 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 22 Aug 2021 13:47:51 +0100 Subject: [PATCH 02/36] Initial implementation --- src/ChainRulesCore.jl | 3 +- src/destructure.jl | 158 ++++++++++++++++++++++++++++++++++++++++++ test/destructure.jl | 37 ++++++++++ test/runtests.jl | 26 +++---- 4 files changed, 211 insertions(+), 13 deletions(-) create mode 100644 src/destructure.jl create mode 100644 test/destructure.jl diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 08ac3847c..a603f22cd 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -27,12 +27,13 @@ include("differentials/notimplemented.jl") include("differential_arithmetic.jl") include("accumulation.jl") -include("projection.jl") include("config.jl") include("rules.jl") include("rule_definition_tools.jl") +include("destructure.jl") + include("deprecated.jl") end # module diff --git a/src/destructure.jl b/src/destructure.jl new file mode 100644 index 000000000..52413d0db --- /dev/null +++ b/src/destructure.jl @@ -0,0 +1,158 @@ +# Fallbacks for destructure +destructure(X::AbstractArray) = collect(X) + +pushforward_of_destructure(X) = dX -> frule((NoTangent(), dX), destructure, X)[2] + +pullback_of_destructure(X) = dY -> rrule(destructure, X)[2](dY)[2] + +# Restructure machinery. +struct Restructure{P, D} + data::D +end + +pullback_of_restructure(X) = dY -> rrule(Restructure(X), destructure(X))[2](dY)[2] + + + + + +# Array +destructure(X::Array) = X + +frule((_, dX)::Tuple{Any, AbstractArray}, ::typeof(destructure), X::Array) = X, dX + +function rrule(::typeof(destructure), X::Array) + destructure_pullback(dXm::AbstractArray) = NoTangent(), dXm + return X, destructure_pullback +end + +Restructure(X::P) where {P<:Array} = Restructure{P, Nothing}(nothing) + +(r::Restructure{P})(X::Array) where {P<:Array} = X + +function frule( + (_, dX)::Tuple{Any, AbstractArray}, ::Restructure{P}, X::Array, +) where {P<:Array} + return X, dX +end + +function rrule(::Restructure{P}, X::Array) where {P<:Array} + restructure_pullback(dY::AbstractArray) = NoTangent(), dY + return X, restructure_pullback +end + + + + + +# Diagonal +destructure(X::Diagonal) = collect(X) + +function frule((_, dX)::Tuple{Any, Tangent}, ::typeof(destructure), X::Diagonal) + des_diag, d_des_diag = frule((NoTangent(), dX.diag), destructure, X.diag) + return collect(X), Diagonal(d_des_diag) +end + +function rrule(::typeof(destructure), X::P) where {P<:Diagonal} + _, des_diag_pb = rrule(destructure, X.diag) + function destructure_pullback(dY::AbstractMatrix) + d_des_diag = diag(dY) + _, d_diag = des_diag_pb(d_des_diag) + return NoTangent(), Tangent{P}(diag=d_diag) + end + return destructure(X), destructure_pullback +end + +Restructure(X::P) where {P<:Diagonal} = Restructure{P, Nothing}(nothing) + +function (r::Restructure{P})(X::Array) where {P<:Diagonal} + @assert isdiag(X) # for illustration. Remove in actual because numerics. 
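+    # (An exact `isdiag` check is too strict once floating-point error creeps in, which is
+    # why the comment above says to remove it; a real implementation would drop it or use a
+    # tolerance.)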
+ return Diagonal(diag(X)) +end + +function frule( + (_, dX)::Tuple{Any, AbstractArray}, ::Restructure{P}, X::Array, +) where {P<:Diagonal} + return Diagoonal(diag(X)), Tangent{P}(diag=diag(dX)) +end + +function rrule(::Restructure{P}, X::Array) where {P<:Diagonal} + restructure_pullback(dY::Tangent) = NoTangent(), Diagonal(dY.diag) + return X, restructure_pullback +end + + + + + +# Symmetric +function destructure(X::Symmetric) + des_data = destructure(X.data) + if X.uplo == 'U' + U = UpperTriangular(des_data) + return U + U' - Diagonal(des_data) + else + L = LowerTriangular(des_data) + return L' + L - Diagonal(X.data) + end +end + +# This gives you the natural tangent! +function frule((_, dx)::Tuple{Any, Tangent}, ::typeof(destructure), x::Symmetric) + des_data, d_des_data = frule((NoTangent(), dx.data), destructure, x.data) + + if x.uplo == 'U' + dU = UpperTriangular(d_des_data) + return destructure(x), dU + dU' - Diagonal(d_des_data) + else + dL = LowerTriangular(d_des_data) + return destructure(x), dL + dL' - Diagonal(d_des_data) + end +end + +function rrule(::typeof(destructure), X::P) where {P<:Symmetric} + function destructure_pullback(dXm::AbstractMatrix) + U = UpperTriangular(dXm) + L = LowerTriangular(dXm) + if X.uplo == 'U' + return NoTangent(), Tangent{P}(data=U + L' - Diagonal(dXm)) + else + return NoTangent(), Tangent{P}(data=U' + L - Diagonal(dXm)) + end + end + return destructure(X), destructure_pullback +end + +Restructure(X::P) where {P<:Symmetric} = Restructure{P, P}(X) + +# In generic-abstractarray-rrule land, assume getindex was used, so the +# strict-lower-triangle was never touched. +function (r::Restructure{P})(X::Array) where {P<:Symmetric} + strict_lower_triangle_of_data = LowerTriangular(r.data.data) - Diagonal(r.data.data) + return Symmetric(UpperTriangular(X) + strict_lower_triangle_of_data) +end + +function frule( + (_, dX)::Tuple{Any, AbstractArray}, r::Restructure{P}, X::Array, +) where {P<:Symmetric} + return r(X), Tangent{P}(data=UpperTriangular(X)) +end + +function rrule(r::Restructure{P}, X::Array) where {P<:Symmetric} + function restructure_pullback(dY::Tangent) + d_restructure = Tangent{Restructure{P}}(data=Tangent{P}(data=tril(dY.data))) + return d_restructure, UpperTriangular(dY.data) + end + return r(X), restructure_pullback +end + + + +# Cholesky -- you get to choose whatever destructuring operation is helpful for a given +# type. This one is helpful for writing generic pullbacks for `cholesky`, the output of +# which is a Cholesky. +# I've not completed the implementation, but it would just require a pushforward and a +# pullback. +destructure(C::Cholesky) = Cholesky(destructure(C.factors), C.uplo, C.info) + +# Restructure(C::P) where {P<:Cholesky} = Restructure{P, Nothing}() diff --git a/test/destructure.jl b/test/destructure.jl new file mode 100644 index 000000000..ae524fb07 --- /dev/null +++ b/test/destructure.jl @@ -0,0 +1,37 @@ +using ChainRulesCore: + destructure, + Restructure, + pushforward_of_destructure, + pullback_of_destructure, + pullback_of_restructure + +function check_destructure(X::AbstractArray, dX) + + # Check that the round-trip in tangent / cotangent space is the identity function. + pf_des = pushforward_of_destructure(X) + pb_des = pullback_of_destructure(X) + + # dX_dense = pf_des(dX) + # @test dX_dense ≈ pf_des(pb_des(dX_dense)) + + # Check that the round-trip is the identity function. + @test X ≈ Restructure(X)(destructure(X)) + + # Check the rrule for destructure. 
+ pb_des = pullback_of_destructure(X) + pb_res = pullback_of_restructure(X) + dX_des = pb_res(dX) + + # I don't have access to test_approx here. + # Since I need to test these, maybe chunks of this PR belong in ChainRules.jl. + @test dX_des ≈ pb_res(pb_des(dX_des)) + + # # Check the rrule for restructure. + +end + +@testset "destructure" begin + check_destructure(randn(3, 3), randn(3, 3)) + check_destructure(Diagonal(randn(3)), Tangent{Diagonal}(diag=randn(3))) + check_destructure(Symmetric(randn(3, 3)), Tangent{Symmetric}(data=randn(3, 3))) +end diff --git a/test/runtests.jl b/test/runtests.jl index e4499b0ff..199a0e078 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Base.Broadcast: broadcastable using BenchmarkTools using ChainRulesCore +using FiniteDifferences using LinearAlgebra using LinearAlgebra.BLAS: ger!, gemv!, gemv, scal! using StaticArrays @@ -8,19 +9,20 @@ using SparseArrays using Test @testset "ChainRulesCore" begin - @testset "differentials" begin - include("differentials/abstract_zero.jl") - include("differentials/thunks.jl") - include("differentials/composite.jl") - include("differentials/notimplemented.jl") - end + # @testset "differentials" begin + # include("differentials/abstract_zero.jl") + # include("differentials/thunks.jl") + # include("differentials/composite.jl") + # include("differentials/notimplemented.jl") + # end - include("accumulation.jl") - include("projection.jl") + # include("accumulation.jl") - include("rules.jl") - include("rule_definition_tools.jl") - include("config.jl") + # include("rules.jl") + # include("rule_definition_tools.jl") + # include("config.jl") - include("deprecated.jl") + include("destructure.jl") + + # include("deprecated.jl") end From 31f67d36b0a3ed5c3a592f50ba2c897741ab4159 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 22 Aug 2021 22:43:16 +0100 Subject: [PATCH 03/36] Some more work --- src/ChainRulesCore.jl | 1 + src/destructure.jl | 10 +++--- test/destructure.jl | 84 ++++++++++++++++++++++++++++++++++--------- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index a603f22cd..e3191795f 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -27,6 +27,7 @@ include("differentials/notimplemented.jl") include("differential_arithmetic.jl") include("accumulation.jl") +include("projection.jl") include("config.jl") include("rules.jl") diff --git a/src/destructure.jl b/src/destructure.jl index 52413d0db..a975f4dcc 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -50,7 +50,7 @@ destructure(X::Diagonal) = collect(X) function frule((_, dX)::Tuple{Any, Tangent}, ::typeof(destructure), X::Diagonal) des_diag, d_des_diag = frule((NoTangent(), dX.diag), destructure, X.diag) - return collect(X), Diagonal(d_des_diag) + return collect(X), collect(Diagonal(d_des_diag)) end function rrule(::typeof(destructure), X::P) where {P<:Diagonal} @@ -73,7 +73,7 @@ end function frule( (_, dX)::Tuple{Any, AbstractArray}, ::Restructure{P}, X::Array, ) where {P<:Diagonal} - return Diagoonal(diag(X)), Tangent{P}(diag=diag(dX)) + return Diagonal(diag(X)), Tangent{P}(diag=diag(dX)) end function rrule(::Restructure{P}, X::Array) where {P<:Diagonal} @@ -135,19 +135,21 @@ end function frule( (_, dX)::Tuple{Any, AbstractArray}, r::Restructure{P}, X::Array, ) where {P<:Symmetric} - return r(X), Tangent{P}(data=UpperTriangular(X)) + return r(X), Tangent{P}(data=UpperTriangular(dX)) end function rrule(r::Restructure{P}, X::Array) where {P<:Symmetric} function 
restructure_pullback(dY::Tangent) d_restructure = Tangent{Restructure{P}}(data=Tangent{P}(data=tril(dY.data))) - return d_restructure, UpperTriangular(dY.data) + return d_restructure, collect(UpperTriangular(dY.data)) end return r(X), restructure_pullback end + + # Cholesky -- you get to choose whatever destructuring operation is helpful for a given # type. This one is helpful for writing generic pullbacks for `cholesky`, the output of # which is a Cholesky. diff --git a/test/destructure.jl b/test/destructure.jl index ae524fb07..2cae67cb8 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -5,33 +5,85 @@ using ChainRulesCore: pullback_of_destructure, pullback_of_restructure -function check_destructure(X::AbstractArray, dX) +# Need structural versions for tests, rather than the thing we currently have in +# FiniteDifferences. +function FiniteDifferences.to_vec(X::Diagonal) + diag_vec, diag_from_vec = to_vec(X.diag) + Diagonal_from_vec(diag_vec) = Diagonal(diag_from_vec(diag_vec)) + return diag_vec, Diagonal_from_vec +end + +function FiniteDifferences.to_vec(X::Symmetric) + data_vec, data_from_vec = to_vec(X.data) + Symmetric_from_vec(data_vec) = Symmetric(data_from_vec(data_vec)) + return data_vec, Symmetric_from_vec +end + +interpret_as_Tangent(x::Array) = x + +interpret_as_Tangent(d::Diagonal) = Tangent{Diagonal}(diag=d.diag) + +interpret_as_Tangent(s::Symmetric) = Tangent{Symmetric}(data=s.data) + +Base.isapprox(t::Tangent{<:Diagonal}, d::Diagonal) = isapprox(t.diag, d.diag) + +Base.isapprox(t::Tangent{<:Symmetric}, s::Symmetric) = isapprox(t.data, s.data) + +function check_destructure(x::AbstractArray, ȳ, ẋ) + + # Verify correctness of frule. + yf, ẏ = frule((NoTangent(), ẋ), destructure, x) + @test yf ≈ destructure(x) + + ẏ_fd = jvp(central_fdm(5, 1), destructure, (x, ẋ)) + @test ẏ ≈ ẏ_fd + + yr, pb = rrule(destructure, x) + _, x̄_r = pb(ȳ) + @test yr ≈ destructure(x) + + # Use inner product relationship to avoid needing CRTU. + @test dot(ȳ, ẏ_fd) ≈ dot(x̄_r, ẋ) # Check that the round-trip in tangent / cotangent space is the identity function. - pf_des = pushforward_of_destructure(X) - pb_des = pullback_of_destructure(X) + pf_des = pushforward_of_destructure(x) + pb_des = pullback_of_destructure(x) - # dX_dense = pf_des(dX) - # @test dX_dense ≈ pf_des(pb_des(dX_dense)) + ẋ_dense = pf_des(ẋ) + # @test ẋ_dense ≈ pf_des(pb_des(ẋ_dense)) # Check that the round-trip is the identity function. - @test X ≈ Restructure(X)(destructure(X)) + @test x ≈ Restructure(x)(destructure(x)) + + # Verify frule of restructure. + x_dense = destructure(x) + x_re, ẋ_re = frule((NoTangent(), ẋ_dense), Restructure(x), x_dense) + @test x_re ≈ Restructure(x)(x_dense) + + ẋ_re_fd = FiniteDifferences.jvp(central_fdm(5, 1), Restructure(x), (x_dense, ẋ_dense)) + @test ẋ_re ≈ ẋ_re_fd + + # Verify rrule of restructure. + x_re_r, pb_r = rrule(Restructure(x), x_dense) + _, x̄_dense = pb_r(ẋ) + + # ẋ serves as the cotangent for the reconstructed x + @test dot(ẋ, interpret_as_Tangent(ẋ_re_fd)) ≈ dot(x̄_dense, ẋ_dense) # Check the rrule for destructure. - pb_des = pullback_of_destructure(X) - pb_res = pullback_of_restructure(X) - dX_des = pb_res(dX) + pb_des = pullback_of_destructure(x) + pb_res = pullback_of_restructure(x) + x̄_des = pb_res(ẋ) # I don't have access to test_approx here. # Since I need to test these, maybe chunks of this PR belong in ChainRules.jl. - @test dX_des ≈ pb_res(pb_des(dX_des)) - - # # Check the rrule for restructure. 
- + @test x̄_des ≈ pb_res(pb_des(x̄_des)) end @testset "destructure" begin - check_destructure(randn(3, 3), randn(3, 3)) - check_destructure(Diagonal(randn(3)), Tangent{Diagonal}(diag=randn(3))) - check_destructure(Symmetric(randn(3, 3)), Tangent{Symmetric}(data=randn(3, 3))) + check_destructure(randn(3, 3), randn(3, 3), randn(3, 3)) + check_destructure(Diagonal(randn(3)), randn(3, 3), Tangent{Diagonal}(diag=randn(3))) + check_destructure( + Symmetric(randn(3, 3)), randn(3, 3), Tangent{Symmetric}(data=randn(3, 3)), + ) end From 4b554c5370f97db180f4addeed358c60f9515d0f Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 14:37:36 +0100 Subject: [PATCH 04/36] Add testset comment --- test/destructure.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/destructure.jl b/test/destructure.jl index 2cae67cb8..5c37d2a13 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -49,6 +49,9 @@ function check_destructure(x::AbstractArray, ȳ, ẋ) pf_des = pushforward_of_destructure(x) pb_des = pullback_of_destructure(x) + # I thought that maybe the pushforward of destructure would be equivalent to the + # pullback of restructure, but that doesn't seem to hold. Not sure why / whether I + # should have thought it might be a thing in the first place. ẋ_dense = pf_des(ẋ) # @test ẋ_dense ≈ pf_des(pb_des(ẋ_dense)) From a51abe9d43634939f5317cc64e22184299dc436a Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 17:24:54 +0100 Subject: [PATCH 05/36] Examples and notes --- examples.jl | 462 ++++++++++++++++++++++++++++++++++++++++++++++++++++ notes.md | 68 ++++++++ 2 files changed, 530 insertions(+) create mode 100644 examples.jl create mode 100644 notes.md diff --git a/examples.jl b/examples.jl new file mode 100644 index 000000000..a32484fcc --- /dev/null +++ b/examples.jl @@ -0,0 +1,462 @@ +using ChainRulesCore +using ChainRulesTestUtils +using FiniteDifferences +using LinearAlgebra +using Zygote + +function ChainRulesCore.rrule(::typeof(getindex), x::Symmetric, p::Int, q::Int) + function structural_getindex_pullback(dy) + ddata = zeros(size(x.data)) + if p > q + ddata[q, p] = dy + else + ddata[p, q] = dy + end + return NoTangent(), Tangent{Symmetric}(data=ddata), NoTangent(), NoTangent() + end + return getindex(x, p, q), structural_getindex_pullback +end + +function my_mul(X::AbstractMatrix{Float64}, Y::AbstractMatrix{Float64}) + y1 = [Y[1, 1], Y[2, 1]] + y2 = [Y[1, 2], Y[2, 2]] + return reshape([X[1, :]'y1, X[2, :]'y1, X[1, :]'y2, X[2, :]'y2], 2, 2) +end + +X = randn(2, 2) +Y = Symmetric(randn(2, 2)) +Z, pb = Zygote.pullback(my_mul, X, Y) + +Z̄ = randn(4) +X̄, Ȳ_zygote = pb(Z̄) + +# Convert Ȳ to Tangent. +Ȳ = Tangent{typeof(Y)}(data=Ȳ_zygote.data) + +# Essentially produces a structural tangent. +function FiniteDifferences.to_vec(X::Symmetric) + x_vec, parent_from_vec = to_vec(X.data) + function Symmetric_from_vec(x) + return Symmetric(parent_from_vec(x)) + end + return x_vec, Symmetric_from_vec +end + +X̄_fd, Ȳ_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, Z̄, X, Y) + +# to_vec doesn't know how to make `Tangent`s, so instead I map it to a `Tangent` manually. +Ȳ_fd = Tangent{typeof(Y)}(data=Ȳ_fd_sym.data) + +Z_m, pb_m = Zygote.pullback(*, X, Y) +X̄_m, Ȳ_m = pb_m(reshape(Z̄, 2, 2)) + +# This is fine. +test_approx(X̄, X̄_fd) +@assert X + X̄ ≈ X + X̄_fd + +# This is fine. +test_approx(X̄, X̄_m) +@assert X + X̄ ≈ X + X̄_m + +# This is fine. +test_approx(Ȳ, Ȳ_fd) +@assert Y + Ȳ ≈ Y + Ȳ_fd + +# This doesn't pass. To be expected, because Ȳ_m is a natural, and Ȳ a structural. 
+test_approx(Ȳ, Ȳ_m) +@assert Y + Ȳ_m ≈ Y + Ȳ_fd + + +A = randn(3, 2); +B = randn(2, 4); +C, pb = Zygote.pullback(*, A, B); +C̄ = randn(3, 4); +Ā, B̄ = pb(C̄); + +Cm, pbm = Zygote.pullback(my_mul, A, B); +Ām, B̄m = pbm(C̄) + +@assert C ≈ Cm +@assert Ā ≈ Ām +@assert B̄ ≈ B̄m + +# Essentially the same as `collect`, but we get to choose semantics. +# I would imagine that `collect` will give us what we need most of the time, but sometimes +# it might not do what we want if e.g. the array in question lives on another device, or +# collect isn't implemented in a differentiable manner, or Zygote already implements the +# rrule for collect in a manner that confuses structurals and naturals. +destructure = collect + +# I've had to implement a number of new functions here to ensure that they do things +# structurally, because Zygote currently has a number of non-structural implementations +# of these things. + +# THIS GUARANTEES ROUND-TRIP CONSISTENCY! + +# IMPLEMENTATION 1: Very literal implementation. Not optimal, but hopefully the clearest +# about what is going on. + +using ChainRulesCore +using ChainRulesTestUtils +using FiniteDifferences +using LinearAlgebra +using Zygote + +# destructure is probably usually similar to collect, but we get to pick whatever semantics +# turn out to be useful. + +# A real win of this approach is that we can test the correctness of people's +# destructure and restructure pullbacks using CRTU as per usual. +# We also just have a single simple requirement on the nature of destructure and +# restructure: restructure(destructure(X)) must be identical to X. Stronger than `==`. +function destructure(X::Symmetric) + @assert X.uplo == 'U' + return UpperTriangular(X.data) + UpperTriangular(X.data)' - Diagonal(X.data) +end + +# Shouldn't need an rrule for this, since the above operations should all be fine, but +# Zygote currently has implementations of these that aren't structural, which is a problem. +function ChainRulesCore.rrule(::typeof(destructure), X::Symmetric) + # As the type author in this context, I get to assert back type comes back. + # I might also have chosen e.g. a GPUArray + function destructure_pullback(dXm::Matrix) + return NoTangent(), Tangent{Symmetric}(data=UpperTriangular(dXm) + LowerTriangular(dXm)' - Diagonal(dXm)) + end + return destructure(X), destructure_pullback +end + +destructure(X::Matrix) = X +function ChainRulesCore.rrule(::typeof(destructure), X::Matrix) + destructure_pullback(dXm::Matrix) = NoTangent(), dXm + return X, destructure_pullback +end + +struct Restructure{P, D} + required_primal_info::D +end + +Restructure(X::P) where {P<:Matrix} = Restructure{P, Nothing}(nothing) + +# Since the operation in question will return a `Matrix`, I don't need restructure for +# Symmetric matrices in this instance. 
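# (Had the primal returned a `Symmetric`, the `Restructure{<:Symmetric}` machinery defined a
# little further down would be needed to map the dense result back.)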
+restructure(::Restructure{<:Matrix}, X::Matrix) = X + +function ChainRulesCore.rrule(::typeof(restructure), ::Restructure{<:Matrix}, X::Matrix) + restructure_matrix_pullback(dXm::Matrix) = NoTangent(), NoTangent(), dXm + return X, restructure_matrix_pullback +end + +Restructure(X::P) where {P<:Symmetric} = Restructure{P, Nothing}(nothing) + +function restructure(r::Restructure{<:Symmetric}, X::Matrix) + @assert issymmetric(X) + return Symmetric(X) +end + +function ChainRulesCore.rrule(::typeof(restructure), r::Restructure{<:Symmetric}, X::Matrix) + function restructure_Symmetric_pullback(dX::Tangent) + return NoTangent(), NoTangent(), dX.data + end + return restructure(r, X), restructure_Symmetric_pullback +end + + +my_mul(A::AbstractMatrix, B::AbstractMatrix) = A * B + +function ChainRulesCore.rrule(::typeof(my_mul), A::Matrix, B::Matrix) + function my_mul_pullback(C::Matrix) + return NoTangent(), C * B', A' * C + end + return A * B, my_mul_pullback +end + +# Could also use AD inside this definition. +function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) + + # Produce dense versions of A and B, and the pullbacks of this operation. + Am, destructure_A_pb = ChainRulesCore.rrule(destructure, A) + Bm, destructure_B_pb = ChainRulesCore.rrule(destructure, B) + + # Compute the rrule in dense-land. This works by assumption. + Cm, my_mul_strict_pullback = ChainRulesCore.rrule(my_mul, Am, Bm) + + # We need the output from the usual forwards pass in order to guarantee that we can + # recover the correct structured type on the output side. + C = my_mul(A, B) + + # Get the structured version back. + _, restructure_C_pb = Zygote._pullback(Zygote.Context(), Restructure(C), Cm) + + # Note that I'm insisting on a `Tangent` here. Would also need to cover Thunks. + function my_mul_generic_pullback(dC) + _, dCm = restructure_C_pb(dC) + _, dAm, dBm = my_mul_strict_pullback(dCm) + _, dA = destructure_A_pb(dAm) + _, dB = destructure_B_pb(dBm) + return NoTangent(), dA, dB + end + + return C, my_mul_generic_pullback +end + +A = randn(4, 3) +B = Symmetric(randn(3, 3)) +C, pb = Zygote.pullback(my_mul, A, B) + +@assert C ≈ my_mul(A, B) + +dC = randn(4, 3) +dA, dB_zg = pb(dC) +dB = Tangent{typeof(B)}(data=dB_zg.data) + + +# Test correctness. +dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B) + +# to_vec doesn't know how to make `Tangent`s, so instead I map it to a `Tangent` manually. +dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data) + +test_approx(dA, dA_fd) +test_approx(dB, dB_fd) + + + + + +# Example 2: something where the output isn't a matrix. 
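# Here the output is a scalar, so only the *input* needs a destructure pullback -- there is no
# restructure step on the output side.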
+my_sum(x::AbstractArray) = sum(x) + +function ChainRulesCore.rrule(::typeof(my_sum), x::Array) + my_sum_strict_pullback(dy::Real) = (NoTangent(), dy * ones(size(x))) + return sum(x), my_sum_strict_pullback +end + +function ChainRulesCore.rrule(::typeof(my_sum), x::AbstractArray) + x_dense, destructure_pb = ChainRulesCore.rrule(destructure, x) + y, my_sum_strict_pullback = ChainRulesCore.rrule(my_sum, x_dense) + + function my_sum_generic_pullback(dy::Real) + _, dx_dense = my_sum_strict_pullback(dy) + _, dx = destructure_pb(dx_dense) + return NoTangent(), dx + end + + return y, my_sum_generic_pullback +end + +A = Symmetric(randn(2, 2)) +y, pb = Zygote.pullback(my_sum, A) + +test_approx(y, my_sum(A)) + +dy = randn() +dA_zg, = pb(dy) +dA = Tangent{typeof(A)}(data=dA_zg.data) + +dA_fd_sym, = FiniteDifferences.j′vp(central_fdm(5, 1), my_sum, dy, A) +dA_fd = Tangent{typeof(A)}(data=dA_fd_sym.data) + +test_approx(dA, dA_fd) + + + + + +# Example 3: structured-input-structured-output + +my_scale(a::Real, x::AbstractMatrix) = a * x + +function ChainRulesCore.rrule(::typeof(my_inv), x::Matrix) + + y, pb = ChainRulesCore.rrule(inv, x) + + # We know that a * x isa Array. Any AbstractArray is an okay tangent for an Array. + function my_scale_pullback(ȳ::AbstractArray) + return NoTangent(), dot(ȳ, x), ȳ * a + end + return a * x, my_scale_pullback +end + +function ChainRulesCore.rrule(::typeof(my_scale), a::Real, x::AbstractMatrix) + x_dense, destructure_x_pb = ChainRulesCore.rrule(destructure, x) + y_dense, my_scale_strict_pb = ChainRulesCore.rrule(my_scale, a, x_dense) + y = my_scale(a, x) + y_reconstruct, restructure_pb = ChainRulesCore.rrule(Restructure(y), y_dense) + + function my_scale_generic_pullback(dy) + _, dy_dense = restructure_pb(dy) + _, da, dx_dense = my_scale_strict_pb(dy_dense) + _, dx = destructure_x_pb(dx_dense) + return NoTangent(), da, dx + end + + return y_reconstruct, my_scale_generic_pullback +end + +Zygote.refresh() + +# SYMMETRIC TEST + +a = randn() +x = Symmetric(randn(2, 2)) +y, pb = Zygote.pullback(my_scale, a, x) + +dy = Tangent{typeof(y)}(data=randn(2, 2)) +da, dx_zg = pb(dy) +dx = Tangent{typeof(x)}(data=dx_zg.data) + +da_fd, dx_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) +dx_fd = Tangent{typeof(x)}(data=dx_fd_sym.data) + +test_approx(y.data, my_scale(a, x).data) +test_approx(da, da_fd) +test_approx(dx, dx_fd) + +# DENSE TEST +x_dense = collect(x) +y, pb = Zygote.pullback(my_scale, a, x_dense) + +dy = randn(size(y)) +da, dx = pb(dy) + +da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x_dense) + +test_approx(y, my_scale(a, x_dense)) +test_approx(da, da_fd) +test_approx(dx, dx_fd) + + + + + +# Example 4: ScaledVector + +struct ScaledVector <: AbstractVector{Float64} + v::Vector{Float64} + α::Float64 +end + +Base.getindex(x::ScaledVector, n::Int) = x.α * x.v[n] + +Base.size(x::ScaledVector) = size(x.v) + +ChainRulesCore.destructure(x::ScaledVector) = x.α * x.v + +ChainRulesCore.Restructure(x::P) where {P<:ScaledVector} = Restructure{P}(x.α) + +(r::Restructure{<:ScaledVector})(x::AbstractVector) = ScaledVector(r.α, x ./ r.α) + + + + + +# ALTERNATIVELY: just do forwards-mode through destructure, and avoid the need to implement +# restructure entirely. I think... hopefully this is correct? + + +# Under the current implementation, we have to do lots of things twice. +# Is there an implementation in which we don't have this problem? 
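# One possibility, developed later in this series: ask each array type only for the *pullbacks*
# of destructure / restructure (`pullback_of_destructure` / `pullback_of_restructure` in
# src/destructure.jl), so nothing has to be densified purely to get hold of a pullback.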
+ +function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) + + # It's possible that our generic types have an optimised forwards pass involving + # mutation. We would like to exploit this. + C = A * B + + # Get the pullback for destructuring the arguments. + # Currently actually does the destructuring, but this will often be entirely + # unnecessary. Semantically, this is what we want though. + _, destructure_A_pb = ChainRulesCore.rrule(destructure, A) + _, destructure_B_pb = ChainRulesCore.rrule(destructure, B) + + # Somehow, we need to know how destructure would work, so that we can get its pullback. + # This remains a very literal implementation -- it'll generally cheaper in practice. + C_dense = destructure(C) + _, restructure_C_pb = ChainRulesCore.rrule(restructure, Restructure(C), C_dense) + + # Restricted to Tangent for illustration purposes. + # Does not permit a natural. + # Thunk also needs to be supported. + function my_mul_pullback(dC::Tangent) + + # Obtain natural tangent for output via restructure pullback. + dC_natural = restructure_C_pb(dC) + + # Code to implement pullback in a generic manner, using natural tangents. + dA_natural = dC_natural * B' + dB_natural = A' * dC_natural + + # Obtain structural tangents for inputs via destructure pullback. + _, dA = destructure_A_pb(dA_natural) + _, dB = destructure_B_pb(dB_natural) + + return NoTangent(), dA, dB + end + + return C, my_mul_pullback +end + +# FORWARD-MODE AD THROUGH DESTRUCTURE YIELDS THE NATURAL TANGENT! +# The important thing is that e.g. a `Diagonal` is a valid tangent for a `Matrix`, because +# you can always find a matrix to which it is `==`. + +# CLAIM: any AbstractArray is an acceptable tangent type for an Array. + +# CLAIM: every AbstractArray has a natural tangent, induced by running forwards-mode on +# destructure. + + + + +# PR format: +1. basic claim -- find an equivalent programme, and work with that. +2. + + + + + + +In my on-going mission to figure out what these natural tangent things are really about, I've arrived at a scheme which gives us the following: + +1. a generic construction for deriving generic rrules in terms of an equivalent primal programme, +2. a candidate method for formalising natural tangents as the result of doing AD on pieces of this equivalent primal programme. + +The explanation of this PR will come in two chunks: +1. an explanation of the equivalent programme and its implications for natural tangents, and +2. approaches to optimising AD in the equivalent programme without changing its semantics. + +I'm explaining it in this order because in my explanation of the programme, I'll have to run AD twice. This is useful for explaining what's going on, but isn't needed in practice. + +# A Sketch of the Equivalent Programme + +To begin with, consider a function +```julia +foo(x::AbstractArray)::AbstractArray +``` +for which we want to write a generic `rrule`. To achieve this, assume that we have access to the following two functions: +```julia +destructure(::AbstractArray)::Array +``` +defined such that +```julia +foo_equiv(x) = restructure(foo(x), foo(destructure(x))) +struct_isapprox(foo_equiv(x), foo(x)) +``` +for all `x`, where `struct_isapprox(a, b)` is defined to mean that all of the fields of `a` and `b` must be `struct_isapprox` with each other (i.e. the default version of `==` that people often ask for), with appropriate base-cases defined for non-composite types. 
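A minimal sketch of what `struct_isapprox` could look like -- the method set below is purely illustrative (only the behaviour described above is part of the proposal):
```julia
# Composites agree iff all of their fields agree; numbers and plain Arrays of numbers are the
# base-cases and fall back on `isapprox`.
struct_isapprox(a::Number, b::Number) = isapprox(a, b)
struct_isapprox(a::Array{<:Number}, b::Array{<:Number}) = isapprox(a, b)
function struct_isapprox(a, b)
    typeof(a) === typeof(b) || return false
    T = typeof(a)
    fieldcount(T) == 0 && return a === b   # e.g. Char, Symbol
    return all(struct_isapprox(getfield(a, f), getfield(b, f)) for f in fieldnames(T))
end
```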
+ +The first argument of `restructure` tells it what bit of data to aim for, and the second is the output of running `foo` on the `Array` which `==` `x`. Furthermore, assume that the pullback of `restructure` with respect to its first argument is always `ZeroTangent` or `NoTangent`, meaning that there's no need to AD back through it. This is the case for all of the types I've encountered so far, but I'm sure that there are types for which it will not be the case. + +Hopefully it's clear that running AD on `foo_equiv` will yield the same answer as running AD on `foo`, up to known generic AD limitations (things like `x == 0 ? 0 : x` giving the wrong answer at `0`). Moreover, it is hopefully clear that `destructure` is trivial to implement -- `collect` will do. `restructure` is often simple -- for example, +```julia +restructure(D::Diagonal) +``` + + + +# Relaxing the Formulation a Bit + + +# Assumptions + +1. Methods of `foo` specialised to specific subtypes of `AbstractArray` access via `getindex`. I believe we assume this implicitly generic rrules currently, and I believe that we need this assumption to guarantee correctness (I can construct an example that gives the wrong answer if this assumption is violated). diff --git a/notes.md b/notes.md new file mode 100644 index 000000000..6227e3429 --- /dev/null +++ b/notes.md @@ -0,0 +1,68 @@ +# A General Mechanism for Generic Rules for AbstractArrays + +That we don't have a general formalism for deriving natural derivatives has been discussed quite a bit recently. As has our lack of understanding of the precise relationship between the generic rrules we're writing, and what AD would do. This PR proposes a recipe for deriving generic rules, which leads to a possible formalism for natural derivatives. This formalism can be applied to any AbstractArray, and AD can be in principle be used to obtain default values for the natural tangent. + +I want reviewers to determine whether they agree that the proposed recipe +1. is sufficient for implementing generic rrules on AbstractArrays, +2. is correct, in the sense that it produces the same answer as AD, and +3. the definition of natural tangents proposed indeed applies to any AbstractArray, and broadly agrees with our intuitions about what a natural tangent should be. + +This is a long read. I've tried to condense where possible. + + + +## Starting Point + +Imagine a clean-slate version of Zygote / Diffractor, in which any rules writen are either +1. necessary to define what AD should do (e.g. `+` or `*` on `Float64`s, `getindex` on `Array`s, `getfield` on composite types, etc), or +2. produce _exactly_ the same answer that AD would produce -- crucially they always return a structural tangent for a non-primitive composite type. + +Moreover, assume that users provide structural tangents -- we'll show how to remove this particular assumption later. + +The consequence of the above is that we can safely assume that the input and output of any rule will be either +1. a structural tangent (if the primal is a non-primitive composite type) +2. whatever we have defined its tangent type to be (if it's a primitive type like `Float64` or `Array`). + +I'm intentionally not trying to define precisely what a primitive is, but will assume that everyone agrees that `Array`s are an example of a primitive, and that we are happy with representing the tangent of an `Array` by another `Array`. 
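To make the two cases concrete (the types are ones used elsewhere in this PR; the variable names are purely illustrative):
```julia
# 1. Non-primitive composite: the tangent of a Diagonal is the structural Tangent.
d  = Diagonal(randn(3))
dd = Tangent{Diagonal}(diag=randn(3))

# 2. Primitive: the tangent of an Array is whatever we have chosen -- here, just another Array.
x  = randn(3, 3)
dx = randn(3, 3)
```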
+ +Assume that structural tangents are a valid representation of a tangent of any non-primitive composite type, convenience for rule-writing aside. + +More generally, assume that if Zygote / Diffractor successfully run on a given function under the above assumptions, they give the answer desired (the "correct" answer). Consequently, the core goals of the proposed recipe are to make it possible to both +1. never write a rule which prevents Zygote / Diffractor from differentiating a programme that they already know how to differentiate, +2. make it easy to write rules using intuitive representations of tangents. + + + + + +## The Formalism + +First consider a specific case -- we'll both generalise and optimise the implementation later. + +Consider a function `f(x::AbstractArray) -> y::AbstractArray`. Lets assume that there's just one method, so we can be sure that a generic fallback will be hit, regardless the concrete type of the argument. + +The intuition behind the recipe is to find a function which is equivalent to `f`, whose rules we know how to write safely. If we can find such a function, AD-ing it will clearly give the correct answer -- the following lays out an approach to doing this. + +The recipe is: +1. Map `x` to an `Array`, `x_dense`, using `getindex`. Call this operation `destructure`. +2. Apply `f` to `x_dense` to obtain `y_dense`. +3. Map `y_dense` onto `y`. Call this operation `restructure`. + +I'm going to define equivalence of output structurally -- two `AbstractArray`s are equal if +1. they are primitives, and whatever notion of equality we have for them holds, or +2. they are composite types, and each of their fields are equal under this definition. + +The reason for this notion of equality is that AD (as proposed above) treats concrete subtypes of AbstractArray no differently from any other composite type. + +Step 2 of the recipe is possible to easily write an rrule for, because we know that its arguments are `Array`s. Step 1 is clearly defined for any `AbstractArray`, so we can implement e.g. a pullback for `destructure` which accepts an `Array` and returns a `Tangent`. + +Step 3 is the trickier step. We'll get on to it later. + +The PR shows how to implement steps 1 and 3 for `Array`s (trivial), `Diagonal`s, and `Symmetric`s. + + + +### The Internals of `f` matter for consistency with AD + +Must only use the AbstractArrays API (no access to internal fields, just uses `getindex` and +`size`). 
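As a hypothetical illustration of the kind of method this assumption rules out (the function names are made up for this note):
```julia
# `bad_sum` reaches into the parent array, so for a `Symmetric` it also reads the triangle of
# `data` that the matrix interface never exposes. AD through `getfield` and the
# destructure-based recipe (which only sees the matrix through `getindex`) can then disagree
# about the cotangent of that untouched triangle.
bad_sum(X::Symmetric) = sum(X.data)

# `good_sum` only uses the AbstractArray API, so the recipe applies as-is.
good_sum(X::AbstractMatrix) = sum(X)
```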
From 8205ce6293638a0448acd429e2240162d374d0b0 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 21:06:14 +0100 Subject: [PATCH 06/36] Some work --- examples.jl | 416 ++++++++++++----------------------- notes.md | 172 ++++++++++++++- src/rule_definition_tools.jl | 42 ++++ 3 files changed, 342 insertions(+), 288 deletions(-) diff --git a/examples.jl b/examples.jl index a32484fcc..c1c0f6317 100644 --- a/examples.jl +++ b/examples.jl @@ -4,225 +4,48 @@ using FiniteDifferences using LinearAlgebra using Zygote -function ChainRulesCore.rrule(::typeof(getindex), x::Symmetric, p::Int, q::Int) - function structural_getindex_pullback(dy) - ddata = zeros(size(x.data)) - if p > q - ddata[q, p] = dy - else - ddata[p, q] = dy - end - return NoTangent(), Tangent{Symmetric}(data=ddata), NoTangent(), NoTangent() - end - return getindex(x, p, q), structural_getindex_pullback -end - -function my_mul(X::AbstractMatrix{Float64}, Y::AbstractMatrix{Float64}) - y1 = [Y[1, 1], Y[2, 1]] - y2 = [Y[1, 2], Y[2, 2]] - return reshape([X[1, :]'y1, X[2, :]'y1, X[1, :]'y2, X[2, :]'y2], 2, 2) -end - -X = randn(2, 2) -Y = Symmetric(randn(2, 2)) -Z, pb = Zygote.pullback(my_mul, X, Y) - -Z̄ = randn(4) -X̄, Ȳ_zygote = pb(Z̄) - -# Convert Ȳ to Tangent. -Ȳ = Tangent{typeof(Y)}(data=Ȳ_zygote.data) - -# Essentially produces a structural tangent. -function FiniteDifferences.to_vec(X::Symmetric) - x_vec, parent_from_vec = to_vec(X.data) - function Symmetric_from_vec(x) - return Symmetric(parent_from_vec(x)) - end - return x_vec, Symmetric_from_vec -end - -X̄_fd, Ȳ_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, Z̄, X, Y) - -# to_vec doesn't know how to make `Tangent`s, so instead I map it to a `Tangent` manually. -Ȳ_fd = Tangent{typeof(Y)}(data=Ȳ_fd_sym.data) - -Z_m, pb_m = Zygote.pullback(*, X, Y) -X̄_m, Ȳ_m = pb_m(reshape(Z̄, 2, 2)) - -# This is fine. -test_approx(X̄, X̄_fd) -@assert X + X̄ ≈ X + X̄_fd - -# This is fine. -test_approx(X̄, X̄_m) -@assert X + X̄ ≈ X + X̄_m - -# This is fine. -test_approx(Ȳ, Ȳ_fd) -@assert Y + Ȳ ≈ Y + Ȳ_fd - -# This doesn't pass. To be expected, because Ȳ_m is a natural, and Ȳ a structural. -test_approx(Ȳ, Ȳ_m) -@assert Y + Ȳ_m ≈ Y + Ȳ_fd - - -A = randn(3, 2); -B = randn(2, 4); -C, pb = Zygote.pullback(*, A, B); -C̄ = randn(3, 4); -Ā, B̄ = pb(C̄); - -Cm, pbm = Zygote.pullback(my_mul, A, B); -Ām, B̄m = pbm(C̄) - -@assert C ≈ Cm -@assert Ā ≈ Ām -@assert B̄ ≈ B̄m +import ChainRulesCore: rrule -# Essentially the same as `collect`, but we get to choose semantics. -# I would imagine that `collect` will give us what we need most of the time, but sometimes -# it might not do what we want if e.g. the array in question lives on another device, or -# collect isn't implemented in a differentiable manner, or Zygote already implements the -# rrule for collect in a manner that confuses structurals and naturals. -destructure = collect +using ChainRulesCore: + pullback_of_destructure, + pullback_of_restructure, + RuleConfig, + wrap_natural_pullback -# I've had to implement a number of new functions here to ensure that they do things -# structurally, because Zygote currently has a number of non-structural implementations -# of these things. - -# THIS GUARANTEES ROUND-TRIP CONSISTENCY! - -# IMPLEMENTATION 1: Very literal implementation. Not optimal, but hopefully the clearest -# about what is going on. 
- -using ChainRulesCore -using ChainRulesTestUtils -using FiniteDifferences -using LinearAlgebra -using Zygote - -# destructure is probably usually similar to collect, but we get to pick whatever semantics -# turn out to be useful. - -# A real win of this approach is that we can test the correctness of people's -# destructure and restructure pullbacks using CRTU as per usual. -# We also just have a single simple requirement on the nature of destructure and -# restructure: restructure(destructure(X)) must be identical to X. Stronger than `==`. -function destructure(X::Symmetric) - @assert X.uplo == 'U' - return UpperTriangular(X.data) + UpperTriangular(X.data)' - Diagonal(X.data) -end - -# Shouldn't need an rrule for this, since the above operations should all be fine, but -# Zygote currently has implementations of these that aren't structural, which is a problem. -function ChainRulesCore.rrule(::typeof(destructure), X::Symmetric) - # As the type author in this context, I get to assert back type comes back. - # I might also have chosen e.g. a GPUArray - function destructure_pullback(dXm::Matrix) - return NoTangent(), Tangent{Symmetric}(data=UpperTriangular(dXm) + LowerTriangular(dXm)' - Diagonal(dXm)) - end - return destructure(X), destructure_pullback -end - -destructure(X::Matrix) = X -function ChainRulesCore.rrule(::typeof(destructure), X::Matrix) - destructure_pullback(dXm::Matrix) = NoTangent(), dXm - return X, destructure_pullback -end - -struct Restructure{P, D} - required_primal_info::D -end - -Restructure(X::P) where {P<:Matrix} = Restructure{P, Nothing}(nothing) - -# Since the operation in question will return a `Matrix`, I don't need restructure for -# Symmetric matrices in this instance. -restructure(::Restructure{<:Matrix}, X::Matrix) = X - -function ChainRulesCore.rrule(::typeof(restructure), ::Restructure{<:Matrix}, X::Matrix) - restructure_matrix_pullback(dXm::Matrix) = NoTangent(), NoTangent(), dXm - return X, restructure_matrix_pullback -end - -Restructure(X::P) where {P<:Symmetric} = Restructure{P, Nothing}(nothing) - -function restructure(r::Restructure{<:Symmetric}, X::Matrix) - @assert issymmetric(X) - return Symmetric(X) -end - -function ChainRulesCore.rrule(::typeof(restructure), r::Restructure{<:Symmetric}, X::Matrix) - function restructure_Symmetric_pullback(dX::Tangent) - return NoTangent(), NoTangent(), dX.data - end - return restructure(r, X), restructure_Symmetric_pullback -end +# All of the examples here involve new functions (`my_mul` etc) so that it's possible to +# ensure that Zygote's existing adjoints don't get in the way. +# Example 1: matrix-matrix multiplication. my_mul(A::AbstractMatrix, B::AbstractMatrix) = A * B -function ChainRulesCore.rrule(::typeof(my_mul), A::Matrix, B::Matrix) - function my_mul_pullback(C::Matrix) - return NoTangent(), C * B', A' * C - end - return A * B, my_mul_pullback -end - -# Could also use AD inside this definition. -function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) - - # Produce dense versions of A and B, and the pullbacks of this operation. - Am, destructure_A_pb = ChainRulesCore.rrule(destructure, A) - Bm, destructure_B_pb = ChainRulesCore.rrule(destructure, B) - - # Compute the rrule in dense-land. This works by assumption. - Cm, my_mul_strict_pullback = ChainRulesCore.rrule(my_mul, Am, Bm) - - # We need the output from the usual forwards pass in order to guarantee that we can - # recover the correct structured type on the output side. 
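# (For over-parametrised types -- see the ScaledVector example further down -- the dense result
# alone does not pin down the structured result, which is why the primal `C` itself is needed.)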
- C = my_mul(A, B) - - # Get the structured version back. - _, restructure_C_pb = Zygote._pullback(Zygote.Context(), Restructure(C), Cm) - - # Note that I'm insisting on a `Tangent` here. Would also need to cover Thunks. - function my_mul_generic_pullback(dC) - _, dCm = restructure_C_pb(dC) - _, dAm, dBm = my_mul_strict_pullback(dCm) - _, dA = destructure_A_pb(dAm) - _, dB = destructure_B_pb(dBm) - return NoTangent(), dA, dB - end - - return C, my_mul_generic_pullback +function rrule(config::RuleConfig, ::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) + C = A * B + natural_pullback_for_mul(C̄) = NoTangent(), C̄ * B', A' * C̄ + return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B) end -A = randn(4, 3) -B = Symmetric(randn(3, 3)) -C, pb = Zygote.pullback(my_mul, A, B) +A = randn(4, 3); +B = Symmetric(randn(3, 3)); +C, pb = Zygote.pullback(my_mul, A, B); @assert C ≈ my_mul(A, B) -dC = randn(4, 3) -dA, dB_zg = pb(dC) -dB = Tangent{typeof(B)}(data=dB_zg.data) - +dC = randn(4, 3); +dA, dB_zg = pb(dC); +dB = Tangent{typeof(B)}(data=dB_zg.data); # Test correctness. -dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B) +dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B); # to_vec doesn't know how to make `Tangent`s, so instead I map it to a `Tangent` manually. -dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data) +dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data); test_approx(dA, dA_fd) test_approx(dB, dB_fd) - - # Example 2: something where the output isn't a matrix. my_sum(x::AbstractArray) = sum(x) @@ -331,132 +154,165 @@ test_approx(dx, dx_fd) # Example 4: ScaledVector -struct ScaledVector <: AbstractVector{Float64} - v::Vector{Float64} +using ChainRulesCore +using ChainRulesCore: Restructure, destructure, Restructure +using ChainRulesTestUtils +using FiniteDifferences +using LinearAlgebra +using Zygote + +# Implement AbstractArray interface. +struct ScaledMatrix <: AbstractMatrix{Float64} + v::Matrix{Float64} α::Float64 end -Base.getindex(x::ScaledVector, n::Int) = x.α * x.v[n] - -Base.size(x::ScaledVector) = size(x.v) +Base.getindex(x::ScaledMatrix, p::Int, q::Int) = x.α * x.v[p, q] -ChainRulesCore.destructure(x::ScaledVector) = x.α * x.v +Base.size(x::ScaledMatrix) = size(x.v) -ChainRulesCore.Restructure(x::P) where {P<:ScaledVector} = Restructure{P}(x.α) -(r::Restructure{<:ScaledVector})(x::AbstractVector) = ScaledVector(r.α, x ./ r.α) +# Implement destructure and restructure. +ChainRulesCore.destructure(x::ScaledMatrix) = x.α * x.v +ChainRulesCore.Restructure(x::P) where {P<:ScaledMatrix} = Restructure{P, Float64}(x.α) +(r::Restructure{<:ScaledMatrix})(x::AbstractArray) = ScaledMatrix(x ./ r.data, r.data) -# ALTERNATIVELY: just do forwards-mode through destructure, and avoid the need to implement -# restructure entirely. I think... hopefully this is correct? -# Under the current implementation, we have to do lots of things twice. -# Is there an implementation in which we don't have this problem? +# Define a function on the type. -function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) +my_dot(x::AbstractArray, y::AbstractArray) = dot(x, y) - # It's possible that our generic types have an optimised forwards pass involving - # mutation. We would like to exploit this. 
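# (That is: run the specialised structured primal directly, and use destructure / restructure
# only for the cotangent book-keeping, rather than also computing a dense primal.)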
- C = A * B +function ChainRulesCore.rrule( + config::RuleConfig, ::typeof(my_dot), x::AbstractArray, y::AbstractArray, +) + _, destructure_x_pb = rrule_via_ad(config, destructure, x) + _, destructure_y_pb = rrule_via_ad(config, destructure, y) - # Get the pullback for destructuring the arguments. - # Currently actually does the destructuring, but this will often be entirely - # unnecessary. Semantically, this is what we want though. - _, destructure_A_pb = ChainRulesCore.rrule(destructure, A) - _, destructure_B_pb = ChainRulesCore.rrule(destructure, B) + function pullback_my_dot(z̄::Real) + x̄_dense = z̄ * y + ȳ_dense = z̄ * x + _, x̄ = destructure_x_pb(x̄_dense) + _, ȳ = destructure_y_pb(ȳ_dense) + return NoTangent(), x̄, ȳ + end + return my_dot(x, y), pullback_my_dot +end - # Somehow, we need to know how destructure would work, so that we can get its pullback. - # This remains a very literal implementation -- it'll generally cheaper in practice. - C_dense = destructure(C) - _, restructure_C_pb = ChainRulesCore.rrule(restructure, Restructure(C), C_dense) - # Restricted to Tangent for illustration purposes. - # Does not permit a natural. - # Thunk also needs to be supported. - function my_mul_pullback(dC::Tangent) +# Check correctness of `my_dot` rrule. Build `ScaledMatrix` internally to avoid technical +# issues with FiniteDifferences. +V = randn(2, 2) +α = randn() +z̄ = randn() - # Obtain natural tangent for output via restructure pullback. - dC_natural = restructure_C_pb(dC) +foo_scal(V, α) = my_dot(ScaledMatrix(V, α), V) - # Code to implement pullback in a generic manner, using natural tangents. - dA_natural = dC_natural * B' - dB_natural = A' * dC_natural +z, pb = Zygote.pullback(foo_scal, V, α) +dx_ad = pb(z̄) - # Obtain structural tangents for inputs via destructure pullback. - _, dA = destructure_A_pb(dA_natural) - _, dB = destructure_B_pb(dB_natural) +dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α) - return NoTangent(), dA, dB - end - - return C, my_mul_pullback -end +test_approx(dx_ad, dx_fd) -# FORWARD-MODE AD THROUGH DESTRUCTURE YIELDS THE NATURAL TANGENT! -# The important thing is that e.g. a `Diagonal` is a valid tangent for a `Matrix`, because -# you can always find a matrix to which it is `==`. -# CLAIM: any AbstractArray is an acceptable tangent type for an Array. - -# CLAIM: every AbstractArray has a natural tangent, induced by running forwards-mode on -# destructure. +# A function with a specialised rule for ScaledMatrix. +my_scale(a::Real, X::AbstractArray) = a * X +my_scale(a::Real, X::ScaledMatrix) = ScaledMatrix(X.v, X.α * a) +# Generic rrule. +function ChainRulesCore.rrule( + config::RuleConfig, ::typeof(my_scale), a::Real, X::AbstractArray, +) + _, destructure_X_pb = rrule_via_ad(config, destructure, X) + Y = my_scale(a, X) + _, restructure_Y_pb = rrule_via_ad(config, Restructure(Y), collect(Y)) + function pullback_my_scale(Ȳ) + _, Ȳ_dense = restructure_Y_pb(Ȳ) + ā = dot(Ȳ_dense, X) + X̄_dense = Ȳ_dense * a + _, X̄ = destructure_X_pb(X̄_dense) + return NoTangent(), ā, X̄ + end + return Y, pullback_my_scale +end -# PR format: -1. basic claim -- find an equivalent programme, and work with that. -2. +# Verify correctness. +a = randn() +V = randn(2, 2) +α = randn() +z̄ = randn() +# A more complicated programme involving `my_scale`. 
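# It composes the generic `my_dot` rule with `my_scale`, which has a method specialised to
# `ScaledMatrix`, so both the generic and the specialised code paths get exercised.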
B = randn(2, 2)
foo_my_scale(a, V, α) = my_dot(B, my_scale(a, ScaledMatrix(V, α)))
z, pb = Zygote.pullback(foo_my_scale, a, V, α)
da, dV, dα = pb(z̄)

da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, z̄, a, V, α)

test_approx(da, da_fd)
test_approx(dV, dV_fd)
test_approx(dα, dα_fd)



# Utility functionality.

# This will often make life really easy. Just requires that pullback_of_restructure is
# defined for C, and pullback_of_destructure for A and B. Could be generalised to make
# different assumptions (e.g. some arguments don't require destructuring, output doesn't
# require restructuring, etc). Would need to be generalised to arbitrary numbers of
# arguments (clearly doable -- at worst requires a generated function).
function wrap_natural_pullback(natural_pullback, C, A, B)

    # Generate enclosing pullbacks. Notice that C / A / B only appear here, and aren't
    # part of the closure returned. This means that they don't need to be carried around,
    # which is good.
    destructure_A_pb = pullback_of_destructure(A)
    destructure_B_pb = pullback_of_destructure(B)
    restructure_C_pb = pullback_of_restructure(C)

    # Wrap natural_pullback to make it play nicely with AD.
+ function generic_pullback(C̄) + _, C̄_natural = restructure_C_pb(C̄) + f̄, Ā_natural, B̄_natural = natural_pullback(C̄_natural) + _, Ā = destructure_A_pb(Ā_natural) + _, B̄ = destructure_B_pb(B̄_natural) + return f̄, Ā, B̄ + end + return generic_pullback +end -Hopefully it's clear that running AD on `foo_equiv` will yield the same answer as running AD on `foo`, up to known generic AD limitations (things like `x == 0 ? 0 : x` giving the wrong answer at `0`). Moreover, it is hopefully clear that `destructure` is trivial to implement -- `collect` will do. `restructure` is often simple -- for example, -```julia -restructure(D::Diagonal) -``` +# Sketch of rrule for my_mul making use of utility functionality. +function rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) + # Do the primal computation. + C = A * B + # "natural pullback" + function my_mul_natural_pullback(C̄_natural) + Ā_natural = C̄_natural * B' + B̄_natural = A' * C̄_natural + return NoTangent(), Ā_natural, B̄_natural + end -# Relaxing the Formulation a Bit + return C, wrap_natural_pullback(my_mul_natural_pullback, C, A, B) +end -# Assumptions -1. Methods of `foo` specialised to specific subtypes of `AbstractArray` access via `getindex`. I believe we assume this implicitly generic rrules currently, and I believe that we need this assumption to guarantee correctness (I can construct an example that gives the wrong answer if this assumption is violated). +# Order in which to present stuff. +# 1. Fully worked-through example (matrix-matrix) multiplication: +# a. Most stupid implementation. +# b. Optimal manual implementation. +# c. Optimal implementation using utility functionality. diff --git a/notes.md b/notes.md index 6227e3429..c41edb135 100644 --- a/notes.md +++ b/notes.md @@ -1,16 +1,63 @@ # A General Mechanism for Generic Rules for AbstractArrays -That we don't have a general formalism for deriving natural derivatives has been discussed quite a bit recently. As has our lack of understanding of the precise relationship between the generic rrules we're writing, and what AD would do. This PR proposes a recipe for deriving generic rules, which leads to a possible formalism for natural derivatives. This formalism can be applied to any AbstractArray, and AD can be in principle be used to obtain default values for the natural tangent. +That we don't have a general formalism for deriving natural derivatives has been discussed quite a bit recently. As has our lack of understanding of the precise relationship between the generic rrules we're writing, and what AD would do. This PR proposes a recipe for deriving generic rules, which leads to a possible formalism for natural derivatives. This formalism can be applied to any AbstractArray, and AD can be in principle be used to obtain default values for the natural tangent. Moreover, there's some utility functionality proposed to make working with this formalism straightforward for rule-writers. + I want reviewers to determine whether they agree that the proposed recipe 1. is sufficient for implementing generic rrules on AbstractArrays, 2. is correct, in the sense that it produces the same answer as AD, and 3. the definition of natural tangents proposed indeed applies to any AbstractArray, and broadly agrees with our intuitions about what a natural tangent should be. 
+I think it should be doable without making breaking changes since it just involves a changing the output types of some rules, which isn't something that we consider breaking provided that they represent the same thing. I'd prefer we worry about this if we think this is a good idea though. + This is a long read. I've tried to condense where possible. + + +## Cutting to the Chase + +Rule-implementers would write rules that look like this: +```julia +function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + + # Do the primal computation. + C = A * B + + # "natural pullback": write intuitive pullback, closing over stuff in the usual manner. + function natural_pullback_for_mul(C̄_natural) + Ā_natural = C̄_natural * B' + B̄_natural = A' * C̄_natural + return NoTangent(), Ā_natural, B̄_natural + end + + # Make a call to utility functionality which transforms cotangents of C, A, and B. + # Rule-writing without this utility has similar requirements to `ProjectTo`. + return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B) +end +``` +I'm proposing to coin the term "natural pullback" for pullbacks written within this system, as they're rules written involving natural (co)tangents. + +Authors will have to implement two functions for their `AbstractArray` type `P`: +```julia +pullback_of_destructure(::P) +pullback_of_restructure(::P) +``` +which are the pullbacks of two functions, `destructure` and `(::Restructure)`, that we'll define later. + +The proposed candidates for natural (co)tangents are obtained as follows: +1. natural tangents are obtained from structural tangents via the pushforward of `destructure`, +2. natural cotangents are obtained from structural cotangents via the pullback of `(::Restructure)`. + +This comes with some wrinkles for some types, including `Symmetric`. More on this later. + +In the proposed system, natural (co)tangents remain confined to `rrule`s, and rule authors can choose to work with either natural, structural, or a mixture of (co)tangents. + + + + + ## Starting Point Imagine a clean-slate version of Zygote / Diffractor, in which any rules writen are either @@ -46,7 +93,7 @@ The intuition behind the recipe is to find a function which is equivalent to `f` The recipe is: 1. Map `x` to an `Array`, `x_dense`, using `getindex`. Call this operation `destructure`. 2. Apply `f` to `x_dense` to obtain `y_dense`. -3. Map `y_dense` onto `y`. Call this operation `restructure`. +3. Map `y_dense` onto `y`. Call this operation `(::Restructure)`. I'm going to define equivalence of output structurally -- two `AbstractArray`s are equal if 1. they are primitives, and whatever notion of equality we have for them holds, or @@ -54,15 +101,124 @@ I'm going to define equivalence of output structurally -- two `AbstractArray`s a The reason for this notion of equality is that AD (as proposed above) treats concrete subtypes of AbstractArray no differently from any other composite type. -Step 2 of the recipe is possible to easily write an rrule for, because we know that its arguments are `Array`s. Step 1 is clearly defined for any `AbstractArray`, so we can implement e.g. a pullback for `destructure` which accepts an `Array` and returns a `Tangent`. +The most literal implementation of this for a function like `*` is therefore something like the following: +```julia +function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + + # Produce dense versions of A and B, and the pullbacks of this operation. 
+ A_dense, destructure_A_pb = rrule(destructure, A) + B_dense, destructure_B_pb = rrule(destructure, B) + + # Compute dense primal. + C_dense = A_dense * B_dense + + # Compute structured primal without densifying to ensure that we get structured `C` back + # if that's what the primal would do. + C = A * B + + # Construct pullback of Restructure. We generally need to extract some information from + # C in order to find the structured version. + _, restructure_C_pb = rrule_via_ad(config, Restructure(C), C_dense) + + function my_mul_generic_pullback(C̄) + + # Recover natural cotangent. + _, C̄_nat = restructure_C_pb(C̄) + + # Compute pullback using natural cotangent of C. + Ā_nat = C̄_nat * B_dense' + B̄_nat = A_dense' * C̄_nat + + # Transform natural cotangents w.r.t. A and B into structural (if non-primitive). + _, Ā = destructure_A_pb(Ā_nat) + _, B̄ = destructure_B_pb(B̄_nat) + return NoTangent(), Ā, B̄ + end + + # The output that we want is `C`, not `C_dense`, so return `C`. + return C, my_mul_generic_pullback +end +``` +I've just written out by hand the rrule for differentiating through the equivalent function. +We'll optimise this implementation shortly to avoid e.g. having to densify primals, and computing the same function twice. +`my_mul` in `examples.jl` verifies the correctness of the above implementation. + + +`destructure` is quite straightforward to define -- essentially equivalent to `collect`. I'm confident that this is always going to be simple to define, because `collect` is always easy to define. + +`Restructure(C)(C_dense)` is a bit trickier. It's the function which takes an `Array` `C_dense` and transforms it into `C`. This feels like a slightly odd thing to do, since we already have `C`, but it's necessary to already know what `C` is in order to construct this function in general -- for example, over-parametrised matrices require this (see the `ScaledMatrix` example in the tests / examples). I'm _reasonably_ confident that this is always going to be possible to define, but I might have missed something. + +The PR shows how to implement steps 1 and 3 for `Array`s, `Diagonal`s, `Symmetric`s, and a custom `AbstractArray` `ScaledMatrix`. + + + + + +## Acceptable (Co)Tangents for `Array` + +Any `AbstractArray` is an acceptable (co)tangent for an `Array` (provided it's the right size, and its elements are appropriate (co)tangents for the elements of the primal `Array`). +I'm going to assume this is true, because I can't see any obvious reason why it wouldn't be. +If anyone feels otherwise, please say. + +For example, this means that a `Diagonal{Float64}` is a valid (co)tangent for an `Array{Float64}`. + + + + + +## Optimising rrules using Natural Pullbacks + +The basic example layed out above was very sub-optimal. Consider the following (equivalent) re-write +```julia +function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + + # Produce pullbacks of destructure. + destructure_A_pb = pullback_of_destructure(config, A) + destructure_B_pb = pullback_of_destructure(config, B) + + # Primal computation. + C = A * B + + # Find pullback of restructure. + restructure_C_pb = pullback_of_restructure(config, C) + + function my_mul_generic_pullback(C̄) + + # Recover natural cotangent. + _, C̄_nat = restructure_C_pb(C̄) + + # Compute pullback using natural cotangent of C. + Ā_nat = C̄_nat * B' + B̄_nat = A' * C̄_nat + + # Transform natural cotangents w.r.t. A and B into structural (if non-primitive). 
+ _, Ā = destructure_A_pb(Ā_nat) + _, B̄ = destructure_B_pb(B̄_nat) + return NoTangent(), Ā, B̄ + end + + return C, my_mul_generic_pullback +end +``` +A few observations: +1. All dense primals are gone. In the pullback, they only appeared in places where they can be safely replaced with the primals themselves because they're doing array-like things. `C_dense` appeared in the construction of `restructure_C_pb`, however, we were using a sub-optimal implementation of that function. Much of the time, `restructure_of_pb` doesn't require `C_dense` in order to know what the pullback would look like and, if it does, it can be obtained from `C`. +2. All direct calls to `rrule_via_ad` have been replaced with calls to functions which are defined to returns the things we actually need (the pullbacks). These typically have efficient (and easy to write) implementations. + +Roughly speaking, the above implementation has only one additional operation than our existing rrules involving `ProjectTo`, which is a call to `restructure_C_pb`, which handles converting a structural tangent for `C̄` into the corresponding natural. Currently we require users to do this by hand, and no clear guidance is provided regarding the correct way to handle this conversion, in contrast to the clarity provided here. + +Almost all of the boilerplate in the above example can be removed by utilising the `wrap_natural_pullback` utility function defined in the PR. + + + -Step 3 is the trickier step. We'll get on to it later. -The PR shows how to implement steps 1 and 3 for `Array`s (trivial), `Diagonal`s, and `Symmetric`s. +## Summary +The above lays out a mechanism or writing generic rrules for AbstractArrays, out of which drops what I believe to be a good candidate for a precise definition of the natural (co)tangent of any particular AbstractArray. +There are a lot more examples in `examples.jl` that I would encourage people to work through. Moreover, the `Symmetric` results are a little odd, but I think make sense. +Additionally, `pullback_of_destructure` and `pullback_of_restructure` are implemented in `src`, while `destructure` and `Restructure` themselves are typically defined in the tests so that it's possible to verify consistency. -### The Internals of `f` matter for consistency with AD +I've presented this work specifically in the context of `AbstractArray`s, but the general scheme could probably be extended to other types by finding other canonical types (like `Array`) on which people's intuition about what ought to happen holds. -Must only use the AbstractArrays API (no access to internal fields, just uses `getindex` and -`size`). +I'm sure there's stuff above which is unclear -- please let me know if so. There's more to say about a lot of this stuff, but I'll discuss as they come up in the interest of keeping this is as brief as possible. diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 10912ce61..fff6a8e57 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -547,3 +547,45 @@ function _constrain_and_name(arg::Expr, _) error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type + + + +# This comment is for reviewing purposes, and will need to be replaced later. +# This will often make life really easy. Just requires that pullback_of_restructure is +# defined for C, and pullback_of_destructure for A and B. Could be generalised to make +# different assumptions (e.g. 
some arguments don't require destructuring, output doesn't +# require restructuring, etc). Would need to be generalised to arbitrary numbers of +# arguments (clearly doable -- at worst requires a generated function). +# I'm assuming that functions don't need to have the destructure pullback applied to them, +# but this probably won't always be true. +function wrap_natural_pullback(config, natural_pullback, C, A, B) + + # Generate enclosing pullbacks. Notice that C / A / B only appear here, and aren't + # part of the closure returned. This means that they don't need to be carried around, + # which is good. + destructure_A_pb = pullback_of_destructure(config, A) + destructure_B_pb = pullback_of_destructure(config, B) + restructure_C_pb = pullback_of_restructure(config, C) + + # Wrap natural_pullback to make it play nicely with AD. + function generic_pullback(C̄) + _, C̄_natural = restructure_C_pb(C̄) + f̄, Ā_natural, B̄_natural = natural_pullback(C̄_natural) + _, Ā = destructure_A_pb(Ā_natural) + _, B̄ = destructure_B_pb(B̄_natural) + return f̄, Ā, B̄ + end + return generic_pullback +end + +function wrap_natural_pullback(config, natural_pullback, B, A) + destructure_input_pb = pullback_of_destructure(config, A) + restructure_output_pb = pullback_of_restructure(config, B) + function generic_pullback(B̄) + _, B̄_natural = restructure_output_pb(B̄) + f̄, Ā_natural = natural_pullback(B̄_natural) + _, Ā = destructure_input_pb(Ā_natural) + return f̄, Ā + end + return generic_pullback +end From bccece20c6e1c20130910c6a6c18d98b2f71b9c4 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:16:48 +0100 Subject: [PATCH 07/36] Tidy up examples --- examples.jl | 241 ++++++++++++++++++---------------------------------- 1 file changed, 85 insertions(+), 156 deletions(-) diff --git a/examples.jl b/examples.jl index c1c0f6317..13c2f2b59 100644 --- a/examples.jl +++ b/examples.jl @@ -4,16 +4,12 @@ using FiniteDifferences using LinearAlgebra using Zygote -import ChainRulesCore: rrule +import ChainRulesCore: rrule, pullback_of_destructure, pullback_of_restructure -using ChainRulesCore: - pullback_of_destructure, - pullback_of_restructure, - RuleConfig, - wrap_natural_pullback +using ChainRulesCore: RuleConfig, wrap_natural_pullback # All of the examples here involve new functions (`my_mul` etc) so that it's possible to -# ensure that Zygote's existing adjoints don't get in the way. +# ensure that Zygote's / ChainRules' existing adjoints don't get in the way. # Example 1: matrix-matrix multiplication. @@ -46,25 +42,21 @@ test_approx(dB, dB_fd) -# Example 2: something where the output isn't a matrix. -my_sum(x::AbstractArray) = sum(x) +# pullbacks for `Real`s so that they play nicely with the utility functionality. -function ChainRulesCore.rrule(::typeof(my_sum), x::Array) - my_sum_strict_pullback(dy::Real) = (NoTangent(), dy * ones(size(x))) - return sum(x), my_sum_strict_pullback -end +ChainRulesCore.pullback_of_destructure(config::RuleConfig, x::Real) = identity -function ChainRulesCore.rrule(::typeof(my_sum), x::AbstractArray) - x_dense, destructure_pb = ChainRulesCore.rrule(destructure, x) - y, my_sum_strict_pullback = ChainRulesCore.rrule(my_sum, x_dense) +ChainRulesCore.pullback_of_restructure(config::RuleConfig, x::Real) = identity - function my_sum_generic_pullback(dy::Real) - _, dx_dense = my_sum_strict_pullback(dy) - _, dx = destructure_pb(dx_dense) - return NoTangent(), dx - end - return y, my_sum_generic_pullback +# Example 2: something where the output isn't a matrix. 
+ +my_sum(x::AbstractArray) = sum(x) + +function ChainRulesCore.rrule(config::RuleConfig, ::typeof(my_sum), x::AbstractArray) + y = my_sum(x) + natural_pullback_my_sum(ȳ::Real) = NoTangent(), fill(ȳ, size(x)) + return y, wrap_natural_pullback(config, natural_pullback_my_sum, y, x) end A = Symmetric(randn(2, 2)) @@ -89,36 +81,59 @@ test_approx(dA, dA_fd) my_scale(a::Real, x::AbstractMatrix) = a * x -function ChainRulesCore.rrule(::typeof(my_inv), x::Matrix) +function ChainRulesCore.rrule( + config::RuleConfig, ::typeof(my_scale), a::Real, x::AbstractMatrix, +) + y = my_scale(a, x) + natural_pullback_my_scale(ȳ::AbstractMatrix) = NoTangent(), dot(ȳ, x), a * ȳ + return y, wrap_natural_pullback(config, natural_pullback_my_scale, y, a, x) +end - y, pb = ChainRulesCore.rrule(inv, x) +# DENSE TEST +a = randn() +x = randn(2, 2) +y, pb = Zygote.pullback(my_scale, a, x) - # We know that a * x isa Array. Any AbstractArray is an okay tangent for an Array. - function my_scale_pullback(ȳ::AbstractArray) - return NoTangent(), dot(ȳ, x), ȳ * a - end - return a * x, my_scale_pullback -end +dy = randn(size(y)) +da, dx = pb(dy) + +da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) + +test_approx(y, my_scale(a, x)) +test_approx(da, da_fd) +test_approx(dx, dx_fd) -function ChainRulesCore.rrule(::typeof(my_scale), a::Real, x::AbstractMatrix) - x_dense, destructure_x_pb = ChainRulesCore.rrule(destructure, x) - y_dense, my_scale_strict_pb = ChainRulesCore.rrule(my_scale, a, x_dense) - y = my_scale(a, x) - y_reconstruct, restructure_pb = ChainRulesCore.rrule(Restructure(y), y_dense) - function my_scale_generic_pullback(dy) - _, dy_dense = restructure_pb(dy) - _, da, dx_dense = my_scale_strict_pb(dy_dense) - _, dx = destructure_x_pb(dx_dense) - return NoTangent(), da, dx - end - return y_reconstruct, my_scale_generic_pullback +# DIAGONAL TEST + +# `diag` now returns a `Diagonal` as a tangnet, so have to define `my_diag` to make this +# work with Diagonal`s. +my_diag(x) = diag(x) +function ChainRulesCore.rrule(::typeof(my_diag), D::P) where {P<:Diagonal} + my_diag_pullback(d) = NoTangent(), Tangent{P}(diag=d) + return diag(D), my_diag_pullback end -Zygote.refresh() +a = randn() +x = Diagonal(randn(2)) +y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x) + +ȳ = randn(2) +ā, x̄_zg = pb(ȳ) +x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag) + +ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x) +x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag) + +test_approx(y, (my_diag ∘ my_scale)(a, x)) +test_approx(ā, ā_fd) +test_approx(x̄, x̄_fd) -# SYMMETRIC TEST + + +# SYMMETRIC TEST - FAILS BECAUSE HIDDEN ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL! +# I would be surprised if we're doing this consistently at the minute though. a = randn() x = Symmetric(randn(2, 2)) @@ -135,31 +150,12 @@ test_approx(y.data, my_scale(a, x).data) test_approx(da, da_fd) test_approx(dx, dx_fd) -# DENSE TEST -x_dense = collect(x) -y, pb = Zygote.pullback(my_scale, a, x_dense) - -dy = randn(size(y)) -da, dx = pb(dy) - -da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x_dense) - -test_approx(y, my_scale(a, x_dense)) -test_approx(da, da_fd) -test_approx(dx, dx_fd) - - -# Example 4: ScaledVector -using ChainRulesCore -using ChainRulesCore: Restructure, destructure, Restructure -using ChainRulesTestUtils -using FiniteDifferences -using LinearAlgebra -using Zygote +# Example 4: ScaledVector. 
This is an interesting example because I truly had no idea how to +# specify a natural tangent for this before. # Implement AbstractArray interface. struct ScaledMatrix <: AbstractMatrix{Float64} @@ -172,13 +168,29 @@ Base.getindex(x::ScaledMatrix, p::Int, q::Int) = x.α * x.v[p, q] Base.size(x::ScaledMatrix) = size(x.v) -# Implement destructure and restructure. +# Implement destructure and restructure pullbacks. + +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:ScaledMatrix} + function pullback_destructure_ScaledMatrix(X̄::AbstractArray) + return Tangent{P}(v = X̄ * x.α, α = dot(X̄, x.v)) + end + return pullback_destructure_ScaledMatrix +end + +function pullback_of_restructure(config::RuleConfig, x::ScaledMatrix) + function pullback_restructure_ScaledMatrix(x̄::Tangent) + return x̄.v / x.α + end + return pullback_restructure_ScaledMatrix +end -ChainRulesCore.destructure(x::ScaledMatrix) = x.α * x.v +# What destructure and restructure would look like if implemented. pullbacks were derived +# based on these. +# ChainRulesCore.destructure(x::ScaledMatrix) = x.α * x.v -ChainRulesCore.Restructure(x::P) where {P<:ScaledMatrix} = Restructure{P, Float64}(x.α) +# ChainRulesCore.Restructure(x::P) where {P<:ScaledMatrix} = Restructure{P, Float64}(x.α) -(r::Restructure{<:ScaledMatrix})(x::AbstractArray) = ScaledMatrix(x ./ r.data, r.data) +# (r::Restructure{<:ScaledMatrix})(x::AbstractArray) = ScaledMatrix(x ./ r.data, r.data) @@ -190,17 +202,9 @@ my_dot(x::AbstractArray, y::AbstractArray) = dot(x, y) function ChainRulesCore.rrule( config::RuleConfig, ::typeof(my_dot), x::AbstractArray, y::AbstractArray, ) - _, destructure_x_pb = rrule_via_ad(config, destructure, x) - _, destructure_y_pb = rrule_via_ad(config, destructure, y) - - function pullback_my_dot(z̄::Real) - x̄_dense = z̄ * y - ȳ_dense = z̄ * x - _, x̄ = destructure_x_pb(x̄_dense) - _, ȳ = destructure_y_pb(ȳ_dense) - return NoTangent(), x̄, ȳ - end - return my_dot(x, y), pullback_my_dot + z = my_dot(x, y) + natural_pullback_my_dot(z̄::Real) = NoTangent(), z̄ * y, z̄ * x + return z, wrap_natural_pullback(config, natural_pullback_my_dot, z, x, y) end @@ -220,29 +224,9 @@ dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α) test_approx(dx_ad, dx_fd) -# A function with a specialised rule for ScaledMatrix. -my_scale(a::Real, X::AbstractArray) = a * X +# A function with a specialised method for ScaledMatrix. my_scale(a::Real, X::ScaledMatrix) = ScaledMatrix(X.v, X.α * a) -# Generic rrule. -function ChainRulesCore.rrule( - config::RuleConfig, ::typeof(my_scale), a::Real, X::AbstractArray, -) - _, destructure_X_pb = rrule_via_ad(config, destructure, X) - Y = my_scale(a, X) - _, restructure_Y_pb = rrule_via_ad(config, Restructure(Y), collect(Y)) - - function pullback_my_scale(Ȳ) - _, Ȳ_dense = restructure_Y_pb(Ȳ) - ā = dot(Ȳ_dense, X) - X̄_dense = Ȳ_dense * a - _, X̄ = destructure_X_pb(X̄_dense) - return NoTangent(), ā, X̄ - end - - return Y, pullback_my_scale -end - # Verify correctness. a = randn() V = randn(2, 2) @@ -261,58 +245,3 @@ da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, test_approx(da, da_fd) test_approx(dV, dV_fd) test_approx(dα, dα_fd) - - - - - -# Utility functionality. - -# This will often make life really easy. Just requires that pullback_of_restructure is -# defined for C, and pullback_of_destructure for A and B. Could be generalised to make -# different assumptions (e.g. some arguments don't require destructuring, output doesn't -# require restructuring, etc). 
Would need to be generalised to arbitrary numbers of -# arguments (clearly doable -- at worst requires a generated function). -function wrap_natural_pullback(natural_pullback, C, A, B) - - # Generate enclosing pullbacks. Notice that C / A / B only appear here, and aren't - # part of the closure returned. This means that they don't need to be carried around, - # which is good. - destructure_A_pb = pullback_of_destructure(A) - destructure_B_pb = pullback_of_destructure(B) - restructure_C_pb = pullback_of_restructure(C) - - # Wrap natural_pullback to make it play nicely with AD. - function generic_pullback(C̄) - _, C̄_natural = restructure_C_pb(C̄) - f̄, Ā_natural, B̄_natural = natural_pullback(C̄_natural) - _, Ā = destructure_A_pb(Ā_natural) - _, B̄ = destructure_B_pb(B̄_natural) - return f̄, Ā, B̄ - end - return generic_pullback -end - -# Sketch of rrule for my_mul making use of utility functionality. -function rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) - - # Do the primal computation. - C = A * B - - # "natural pullback" - function my_mul_natural_pullback(C̄_natural) - Ā_natural = C̄_natural * B' - B̄_natural = A' * C̄_natural - return NoTangent(), Ā_natural, B̄_natural - end - - return C, wrap_natural_pullback(my_mul_natural_pullback, C, A, B) -end - - - -# Order in which to present stuff. -# 1. Fully worked-through example (matrix-matrix) multiplication: -# a. Most stupid implementation. -# b. Optimal manual implementation. -# c. Optimal implementation using utility functionality. From 569f064b82ac33e0c998a19a7553d2fc86e7d83e Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:16:54 +0100 Subject: [PATCH 08/36] Tidy up PR notes --- notes.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/notes.md b/notes.md index c41edb135..4b8a6f5a3 100644 --- a/notes.md +++ b/notes.md @@ -54,6 +54,16 @@ This comes with some wrinkles for some types, including `Symmetric`. More on thi In the proposed system, natural (co)tangents remain confined to `rrule`s, and rule authors can choose to work with either natural, structural, or a mixture of (co)tangents. +Other than the headlines at the top, additional benefits of (a correct implementation of) this approach include: +1. no chance of trying to sum tangents failing because one is natural and the other structural, +1. no risk of obstructing AD, +1. rand_tangent can just use structural tangents, simplifying its implementation and improving its reliability, +1. we can probably finally make `to_vec` treat things structurally (although we also need to extend it in other ways), which will also deal with reliability / simplicity of implementation problems, +1. generic constructor for composite types easy to implement, +1. due to the utility functionality, all of the examples that I've encountered so far are very concise. + +The potential downside is additional conversions between natural and structural tangents. Most of the time, these are free. When they're not, you ideally want to minimise them. I'm not sure how often this is going to be a problem, but it's something we're essentially ignoring at the minute (as far as I know), so we're probably going to have to incur some additional cost even if we don't go down the route proposed here. 
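+
+For intuition, here is roughly what that conversion looks like for `Diagonal` (a sketch only -- the helper names below are illustrative and not part of the PR; the real definitions are the `pullback_of_destructure` / `pullback_of_restructure` methods in `src/destructure.jl`):
+```julia
+using LinearAlgebra, ChainRulesCore
+
+# structural -> natural: just re-wrap the `diag` field. No densification, essentially free.
+to_natural(d̄::Tangent{P}) where {P<:Diagonal} = Diagonal(d̄.diag)
+
+# natural -> structural: read the diagonal back out.
+to_structural(::Type{P}, d̄::AbstractMatrix) where {P<:Diagonal} = Tangent{P}(diag=diag(d̄))
+```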
+ @@ -212,6 +222,20 @@ Almost all of the boilerplate in the above example can be removed by utilising t +## Gotchas + +There does seem to be something that goes wrong when primals access non-public fields of types (array authors are obviously allowed to do this), but the generic rrules assume that only the AbstractArray API is used. +I don't think this differs from what we're doing at the minute, so probably we're suffering from this already and just haven't hit it yet. +See the third example in `examples.jl`, involving a `Symmetric`. + +This is a particularly interesting case because `parent` is exported from `Base`, effectively making the field of a `Symmetric` part of its public API. + +I'm not really sure how to think about this but, as I say, I suspect we're already suffering from it, so I'm not going to worry about it for now. + + + + + ## Summary The above lays out a mechanism or writing generic rrules for AbstractArrays, out of which drops what I believe to be a good candidate for a precise definition of the natural (co)tangent of any particular AbstractArray. @@ -221,4 +245,6 @@ Additionally, `pullback_of_destructure` and `pullback_of_restructure` are implem I've presented this work specifically in the context of `AbstractArray`s, but the general scheme could probably be extended to other types by finding other canonical types (like `Array`) on which people's intuition about what ought to happen holds. +The implementation are also limited to arrays of real numbers to avoid the need to recursively apply `destructure` / `restructure`. This could be done in practice if it were thought helpful. + I'm sure there's stuff above which is unclear -- please let me know if so. There's more to say about a lot of this stuff, but I'll discuss as they come up in the interest of keeping this is as brief as possible. From f321a9721ea8cd1248e3bcfbc4ce5de455af39ee Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:17:19 +0100 Subject: [PATCH 09/36] Provide optimised pullback implementations --- src/destructure.jl | 51 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index a975f4dcc..c2c937dfd 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -1,22 +1,32 @@ # Fallbacks for destructure destructure(X::AbstractArray) = collect(X) -pushforward_of_destructure(X) = dX -> frule((NoTangent(), dX), destructure, X)[2] - -pullback_of_destructure(X) = dY -> rrule(destructure, X)[2](dY)[2] +function pullback_of_destructure(config::RuleConfig, X) + return dY -> rrule_via_ad(config, destructure, X)[2](dY)[2] +end # Restructure machinery. 
struct Restructure{P, D} data::D end -pullback_of_restructure(X) = dY -> rrule(Restructure(X), destructure(X))[2](dY)[2] - +function pullback_of_restructure(config::RuleConfig, X) + return dY -> rrule_via_ad(config, Restructure(X), destructure(X))[2](dY)[2] +end # Array +function pullback_of_destructure(config::RuleConfig, X::Array{<:Real}) + pullback_destructure_Array(X̄::AbstractArray{<:Real}) = X̄ + return pullback_destructure_Array +end + +function pullback_of_restructure(config::RuleConfig, X::Array{<:Real}) + pullback_restructure_Array(X̄::AbstractArray{<:Real}) = X̄ +end + destructure(X::Array) = X frule((_, dX)::Tuple{Any, AbstractArray}, ::typeof(destructure), X::Array) = X, dX @@ -46,6 +56,16 @@ end # Diagonal +function pullback_of_destructure(config::RuleConfig, D::P) where {P<:Diagonal} + pullback_destructure_Diagonal(D̄::AbstractArray) = Tangent{P}(diag=diag(D̄)) + return pullback_destructure_Diagonal +end + +function pullback_of_restructure(config::RuleConfig, D::P) where {P<:Diagonal} + pullback_restructure_Diagonal(D̄::Tangent) = Diagonal(D̄.diag) + return pullback_restructure_Diagonal +end + destructure(X::Diagonal) = collect(X) function frule((_, dX)::Tuple{Any, Tangent}, ::typeof(destructure), X::Diagonal) @@ -86,6 +106,27 @@ end # Symmetric +function pullback_of_destructure(config::RuleConfig, S::P) where {P<:Symmetric} + function destructure_pullback_Symmetric(dXm::AbstractMatrix) + U = UpperTriangular(dXm) + L = LowerTriangular(dXm) + if S.uplo == 'U' + return Tangent{P}(data=U + L' - Diagonal(dXm)) + else + return Tangent{P}(data=U' + L - Diagonal(dXm)) + end + end + return destructure_pullback_Symmetric +end + +# Assume upper-triangular for now. +function pullback_of_restructure(config::RuleConfig, S::P) where {P<:Symmetric} + function restructure_pullback_Symmetric(dY::Tangent) + return collect(UpperTriangular(dY.data)) + end + return restructure_pullback_Symmetric +end + function destructure(X::Symmetric) des_data = destructure(X.data) if X.uplo == 'U' From 8a06561fd9ac72ac6e802625edf808513f0085ec Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:17:41 +0100 Subject: [PATCH 10/36] Change pullback implementation specifications --- src/rule_definition_tools.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index fff6a8e57..9d66700b2 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -569,10 +569,10 @@ function wrap_natural_pullback(config, natural_pullback, C, A, B) # Wrap natural_pullback to make it play nicely with AD. 
function generic_pullback(C̄) - _, C̄_natural = restructure_C_pb(C̄) + C̄_natural = restructure_C_pb(C̄) f̄, Ā_natural, B̄_natural = natural_pullback(C̄_natural) - _, Ā = destructure_A_pb(Ā_natural) - _, B̄ = destructure_B_pb(B̄_natural) + Ā = destructure_A_pb(Ā_natural) + B̄ = destructure_B_pb(B̄_natural) return f̄, Ā, B̄ end return generic_pullback @@ -582,9 +582,9 @@ function wrap_natural_pullback(config, natural_pullback, B, A) destructure_input_pb = pullback_of_destructure(config, A) restructure_output_pb = pullback_of_restructure(config, B) function generic_pullback(B̄) - _, B̄_natural = restructure_output_pb(B̄) + B̄_natural = restructure_output_pb(B̄) f̄, Ā_natural = natural_pullback(B̄_natural) - _, Ā = destructure_input_pb(Ā_natural) + Ā = destructure_input_pb(Ā_natural) return f̄, Ā end return generic_pullback From 72fc725fccf90f5ee8292a9d83857183511e0d99 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:25:36 +0100 Subject: [PATCH 11/36] Add extra methods to avoid config --- src/destructure.jl | 49 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index c2c937dfd..dc735d52a 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -1,6 +1,8 @@ # Fallbacks for destructure destructure(X::AbstractArray) = collect(X) +pushforward_of_destructure(X) = dX -> frule((NoTangent(), dX), destructure, X)[2] + function pullback_of_destructure(config::RuleConfig, X) return dY -> rrule_via_ad(config, destructure, X)[2](dY)[2] end @@ -18,15 +20,28 @@ end # Array -function pullback_of_destructure(config::RuleConfig, X::Array{<:Real}) + +function pullback_of_destructure(X::Array{<:Real}) pullback_destructure_Array(X̄::AbstractArray{<:Real}) = X̄ return pullback_destructure_Array end -function pullback_of_restructure(config::RuleConfig, X::Array{<:Real}) + +function pullback_of_restructure(X::Array{<:Real}) pullback_restructure_Array(X̄::AbstractArray{<:Real}) = X̄ + return pullback_restructure_Array +end + +function pullback_of_destructure(config::RuleConfig, X::Array{<:Real}) + return pullback_of_destructure(X) +end + +function pullback_of_restructure(config::RuleConfig, X::Array{<:Real}) + return pullback_of_destructure(X) end + +# Stuff below here for Array to move to tests. destructure(X::Array) = X frule((_, dX)::Tuple{Any, AbstractArray}, ::typeof(destructure), X::Array) = X, dX @@ -56,16 +71,25 @@ end # Diagonal -function pullback_of_destructure(config::RuleConfig, D::P) where {P<:Diagonal} +function pullback_of_destructure(D::P) where {P<:Diagonal} pullback_destructure_Diagonal(D̄::AbstractArray) = Tangent{P}(diag=diag(D̄)) return pullback_destructure_Diagonal end -function pullback_of_restructure(config::RuleConfig, D::P) where {P<:Diagonal} +function pullback_of_restructure(D::P) where {P<:Diagonal} pullback_restructure_Diagonal(D̄::Tangent) = Diagonal(D̄.diag) return pullback_restructure_Diagonal end +function pullback_of_destructure(config::RuleConfig, X::Diagonal) + return pullback_of_destructure(X) +end + +function pullback_of_restructure(config::RuleConfig, X::Diagonal) + return pullback_of_destructure(X) +end + +# Stuff below here for Diagonal to move to tests. 
destructure(X::Diagonal) = collect(X) function frule((_, dX)::Tuple{Any, Tangent}, ::typeof(destructure), X::Diagonal) @@ -106,7 +130,7 @@ end # Symmetric -function pullback_of_destructure(config::RuleConfig, S::P) where {P<:Symmetric} +function pullback_of_destructure(S::P) where {P<:Symmetric} function destructure_pullback_Symmetric(dXm::AbstractMatrix) U = UpperTriangular(dXm) L = LowerTriangular(dXm) @@ -120,13 +144,22 @@ function pullback_of_destructure(config::RuleConfig, S::P) where {P<:Symmetric} end # Assume upper-triangular for now. -function pullback_of_restructure(config::RuleConfig, S::P) where {P<:Symmetric} +function pullback_of_restructure(S::P) where {P<:Symmetric} function restructure_pullback_Symmetric(dY::Tangent) return collect(UpperTriangular(dY.data)) end return restructure_pullback_Symmetric end +function pullback_of_destructure(config::RuleConfig, X::Symmetric) + return pullback_of_destructure(X) +end + +function pullback_of_restructure(config::RuleConfig, X::Symmetric) + return pullback_of_destructure(X) +end + +# Stuff below here for Symmetric to move to tests. function destructure(X::Symmetric) des_data = destructure(X.data) if X.uplo == 'U' @@ -193,9 +226,7 @@ end # Cholesky -- you get to choose whatever destructuring operation is helpful for a given # type. This one is helpful for writing generic pullbacks for `cholesky`, the output of -# which is a Cholesky. -# I've not completed the implementation, but it would just require a pushforward and a -# pullback. +# which is a Cholesky. Not completed. Probably won't be included in initial merge. destructure(C::Cholesky) = Cholesky(destructure(C.factors), C.uplo, C.info) # Restructure(C::P) where {P<:Cholesky} = Restructure{P, Nothing}() From fbaaf86db598782f6c00f4c7a499f1f223f8f596 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:29:31 +0100 Subject: [PATCH 12/36] Fix typo --- src/destructure.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index dc735d52a..fb6f7b988 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -26,7 +26,6 @@ function pullback_of_destructure(X::Array{<:Real}) return pullback_destructure_Array end - function pullback_of_restructure(X::Array{<:Real}) pullback_restructure_Array(X̄::AbstractArray{<:Real}) = X̄ return pullback_restructure_Array @@ -37,7 +36,7 @@ function pullback_of_destructure(config::RuleConfig, X::Array{<:Real}) end function pullback_of_restructure(config::RuleConfig, X::Array{<:Real}) - return pullback_of_destructure(X) + return pullback_of_restructure(X) end @@ -86,7 +85,7 @@ function pullback_of_destructure(config::RuleConfig, X::Diagonal) end function pullback_of_restructure(config::RuleConfig, X::Diagonal) - return pullback_of_destructure(X) + return pullback_of_restructure(X) end # Stuff below here for Diagonal to move to tests. @@ -156,7 +155,7 @@ function pullback_of_destructure(config::RuleConfig, X::Symmetric) end function pullback_of_restructure(config::RuleConfig, X::Symmetric) - return pullback_of_destructure(X) + return pullback_of_restructure(X) end # Stuff below here for Symmetric to move to tests. 
From 313590d2277a237543a3b0c62d88c11221c206a6 Mon Sep 17 00:00:00 2001
From: WT 
Date: Sat, 28 Aug 2021 22:30:22 +0100
Subject: [PATCH 13/36] Tweak comment in examples

---
 examples.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples.jl b/examples.jl
index 13c2f2b59..a88aeff54 100644
--- a/examples.jl
+++ b/examples.jl
@@ -132,7 +132,7 @@ test_approx(x̄, x̄_fd)
 
 
 
-# SYMMETRIC TEST - FAILS BECAUSE HIDDEN ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL!
+# SYMMETRIC TEST - FAILS BECAUSE PRIVATE ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL!
 # I would be surprised if we're doing this consistently at the minute though.
 
 a = randn()

From 321fceb2f0d52bca1d726674d7cb00876e40693a Mon Sep 17 00:00:00 2001
From: WT 
Date: Sat, 28 Aug 2021 22:54:48 +0100
Subject: [PATCH 14/36] Update notes

---
 notes.md | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/notes.md b/notes.md
index 4b8a6f5a3..6d7c53b79 100644
--- a/notes.md
+++ b/notes.md
@@ -8,9 +8,9 @@ I want reviewers to determine whether they agree that the proposed recipe
 2. is correct, in the sense that it produces the same answer as AD, and
 3. the definition of natural tangents proposed indeed applies to any AbstractArray, and broadly agrees with our intuitions about what a natural tangent should be.
 
-I think it should be doable without making breaking changes since it just involves a changing the output types of some rules, which isn't something that we consider breaking provided that they represent the same thing. I'd prefer we worry about this if we think this is a good idea though.
+I think what is proposed should be doable without making breaking changes since it just involves changing the output types of some rules, which isn't something that we consider breaking provided that they represent the same thing. I'd prefer we worry about this later if we think this is a good idea though.
 
-This is a long read. I've tried to condense where possible.
+Apologies in advance for the length. I've tried to condense where possible, but there's a decent amount of content to get through.
 
 
 
@@ -50,7 +50,7 @@ The proposed candidates for natural (co)tangents are obtained as follows:
 1. natural tangents are obtained from structural tangents via the pushforward of `destructure`,
 2. natural cotangents are obtained from structural cotangents via the pullback of `(::Restructure)`.
 
-This comes with some wrinkles for some types, including `Symmetric`. More on this later.
+Our current `ProjectTo` functionality is roughly the same as the pullback of `destructure`.
 
 In the proposed system, natural (co)tangents remain confined to `rrule`s, and rule authors can choose to work with either natural, structural, or a mixture of (co)tangents.
 
@@ -74,7 +74,7 @@ Imagine a clean-slate version of Zygote / Diffractor, in which any rules writen
 1. necessary to define what AD should do (e.g. `+` or `*` on `Float64`s, `getindex` on `Array`s, `getfield` on composite types, etc), or
 2. produce _exactly_ the same answer that AD would produce -- crucially they always return a structural tangent for a non-primitive composite type.
 
-Moreover, assume that users provide structural tangents -- we'll show how to remove this particular assumption later.
+Moreover, assume that users provide structural tangents -- this restriction could be removed by AD authors by shoving a call to `destructure` at the end of their AD if that's something they want to do.
 
The consequence of the above is that we can safely assume that the input and output of any rule will be either 1. a structural tangent (if the primal is a non-primitive composite type) @@ -84,8 +84,8 @@ I'm intentionally not trying to define precisely what a primitive is, but will a Assume that structural tangents are a valid representation of a tangent of any non-primitive composite type, convenience for rule-writing aside. -More generally, assume that if Zygote / Diffractor successfully run on a given function under the above assumptions, they give the answer desired (the "correct" answer). Consequently, the core goals of the proposed recipe are to make it possible to both -1. never write a rule which prevents Zygote / Diffractor from differentiating a programme that they already know how to differentiate, +More generally, assume that if an AD successfully runs on a given function under the above assumptions, they give the answer desired (the "correct" answer). Consequently, the core goals of the proposed recipe are to make it +1. easy to never write a rule which prevents an AD from differentiating a programme that they already know how to differentiate, 2. make it easy to write rules using intuitive representations of tangents. @@ -94,14 +94,14 @@ More generally, assume that if Zygote / Diffractor successfully run on a given f ## The Formalism -First consider a specific case -- we'll both generalise and optimise the implementation later. +First consider a specific case -- we'll optimise the implementation later, and provide more general examples in `examples.jl`. Consider a function `f(x::AbstractArray) -> y::AbstractArray`. Lets assume that there's just one method, so we can be sure that a generic fallback will be hit, regardless the concrete type of the argument. The intuition behind the recipe is to find a function which is equivalent to `f`, whose rules we know how to write safely. If we can find such a function, AD-ing it will clearly give the correct answer -- the following lays out an approach to doing this. The recipe is: -1. Map `x` to an `Array`, `x_dense`, using `getindex`. Call this operation `destructure`. +1. Map `x` to an `Array`, `x_dense`, using `getindex`. Call this operation `destructure` (it's essentially `collect`). 2. Apply `f` to `x_dense` to obtain `y_dense`. 3. Map `y_dense` onto `y`. Call this operation `(::Restructure)`. @@ -111,7 +111,7 @@ I'm going to define equivalence of output structurally -- two `AbstractArray`s a The reason for this notion of equality is that AD (as proposed above) treats concrete subtypes of AbstractArray no differently from any other composite type. -The most literal implementation of this for a function like `*` is therefore something like the following: +A very literal implementation of this for a function like `*` is something like the following: ```julia function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) @@ -213,10 +213,11 @@ end A few observations: 1. All dense primals are gone. In the pullback, they only appeared in places where they can be safely replaced with the primals themselves because they're doing array-like things. `C_dense` appeared in the construction of `restructure_C_pb`, however, we were using a sub-optimal implementation of that function. Much of the time, `restructure_of_pb` doesn't require `C_dense` in order to know what the pullback would look like and, if it does, it can be obtained from `C`. 2. 
All direct calls to `rrule_via_ad` have been replaced with calls to functions which are defined to returns the things we actually need (the pullbacks). These typically have efficient (and easy to write) implementations. +3. `C̄_nat` could be any old `AbstractArray`. For example, `pullback_of_restructure` for a `Diagonal` returns a `Diagonal`. This is good -- it means we might occassionally get faster computations in the pullback. -Roughly speaking, the above implementation has only one additional operation than our existing rrules involving `ProjectTo`, which is a call to `restructure_C_pb`, which handles converting a structural tangent for `C̄` into the corresponding natural. Currently we require users to do this by hand, and no clear guidance is provided regarding the correct way to handle this conversion, in contrast to the clarity provided here. +Roughly speaking, the above implementation has only one additional operation than our existing rrules involving `ProjectTo`, which is a call to `restructure_C_pb`, which handles converting a structural tangent for `C̄` into the corresponding natural. Currently we require users to do this by hand, and no clear guidance is provided regarding the correct way to handle this conversion, in contrast to the clarity provided here. In this sense, all that the above is doing is providing a well-defined mechanism by which users can obtain natural cotangents from structural cotangents, so it should ease the burden on rule-implementers. -Almost all of the boilerplate in the above example can be removed by utilising the `wrap_natural_pullback` utility function defined in the PR. +Almost all of the boilerplate in the above example can be removed by utilising the `wrap_natural_pullback` utility function defined in the PR, as in the example at the top of this note. @@ -238,13 +239,15 @@ I'm not really sure how to think about this but, as I say, I suspect we're alrea ## Summary -The above lays out a mechanism or writing generic rrules for AbstractArrays, out of which drops what I believe to be a good candidate for a precise definition of the natural (co)tangent of any particular AbstractArray. +The above lays out a mechanism for writing generic rrules for AbstractArrays, out of which drops what I believe to be a good candidate for a precise definition of the natural (co)tangent of any particular AbstractArray. There are a lot more examples in `examples.jl` that I would encourage people to work through. Moreover, the `Symmetric` results are a little odd, but I think make sense. -Additionally, `pullback_of_destructure` and `pullback_of_restructure` are implemented in `src`, while `destructure` and `Restructure` themselves are typically defined in the tests so that it's possible to verify consistency. +Additionally, implementations of `destructure` and `Restructure` can be moved to the tests because they're really just used to verify the correctness of manually implementations of their pullbacks. -I've presented this work specifically in the context of `AbstractArray`s, but the general scheme could probably be extended to other types by finding other canonical types (like `Array`) on which people's intuition about what ought to happen holds. +I've presented this work in the context of `AbstractArray`s, but the general scheme could probably be extended to other types by finding other canonical types (like `Array`) on which people's intuition about what ought to happen holds. 
-The implementation are also limited to arrays of real numbers to avoid the need to recursively apply `destructure` / `restructure`. This could be done in practice if it were thought helpful. +The implementation are also limited to arrays of real numbers to avoid the need to recursively apply `destructure` / `restructure`. This restriction could be dropped in practice, and recursive definitions applied. -I'm sure there's stuff above which is unclear -- please let me know if so. There's more to say about a lot of this stuff, but I'll discuss as they come up in the interest of keeping this is as brief as possible. +I'm sure there's stuff above which is unclear -- please let me know if so. There's more to say about a lot of this stuff, but I'll stop here in the interest of keeping this concise. + +Please now go and look at `examples.jl`. From a15f46918c46940fe373e098b47e5d6d6b8820bc Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:57:53 +0100 Subject: [PATCH 15/36] Tweak notes --- notes.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/notes.md b/notes.md index 6d7c53b79..020c07b27 100644 --- a/notes.md +++ b/notes.md @@ -195,15 +195,15 @@ function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMa function my_mul_generic_pullback(C̄) # Recover natural cotangent. - _, C̄_nat = restructure_C_pb(C̄) + C̄_nat = restructure_C_pb(C̄) # Compute pullback using natural cotangent of C. Ā_nat = C̄_nat * B' B̄_nat = A' * C̄_nat # Transform natural cotangents w.r.t. A and B into structural (if non-primitive). - _, Ā = destructure_A_pb(Ā_nat) - _, B̄ = destructure_B_pb(B̄_nat) + Ā = destructure_A_pb(Ā_nat) + B̄ = destructure_B_pb(B̄_nat) return NoTangent(), Ā, B̄ end From 22efec34d614988eff3f8e76d5f6be05fdcc6036 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 22:59:00 +0100 Subject: [PATCH 16/36] Fix typo: --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index 020c07b27..a6fc2e50b 100644 --- a/notes.md +++ b/notes.md @@ -18,7 +18,7 @@ Apologies in advance for the lenght. I've tried to condense where possible, but ## Cutting to the Chase -Rule-implementers would write rules that look like this: +Under the proposed system, rule-implementers would write rules that look like this: ```julia function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) From 2b0853eb74e179e2441cd9e8e0ed063185d48b4c Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:00:46 +0100 Subject: [PATCH 17/36] Tweak notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index a6fc2e50b..52852089b 100644 --- a/notes.md +++ b/notes.md @@ -1,6 +1,6 @@ # A General Mechanism for Generic Rules for AbstractArrays -That we don't have a general formalism for deriving natural derivatives has been discussed quite a bit recently. As has our lack of understanding of the precise relationship between the generic rrules we're writing, and what AD would do. This PR proposes a recipe for deriving generic rules, which leads to a possible formalism for natural derivatives. This formalism can be applied to any AbstractArray, and AD can be in principle be used to obtain default values for the natural tangent. Moreover, there's some utility functionality proposed to make working with this formalism straightforward for rule-writers. +That we don't have a general formalism for deriving natural (co)tangents has been discussed quite a bit recently. 
As has our lack of understanding of the precise relationship between the generic rrules we're writing, and what AD would do. This PR proposes a recipe for deriving generic rules, which leads to a possible formalism for natural derivatives. This formalism can be applied to any AbstractArray, and AD can be in principle be used to obtain default values for the natural tangent. Moreover, there's some utility functionality proposed to make working with this formalism straightforward for rule-writers. I want reviewers to determine whether they agree that the proposed recipe From 73a5fb7dadca7687f7b8d63c88e4bcf603221a72 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:03:25 +0100 Subject: [PATCH 18/36] Tweak notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index 52852089b..b6feeb863 100644 --- a/notes.md +++ b/notes.md @@ -62,7 +62,7 @@ Other than the headlines at the top, additional benefits of (a correct implement 1. generic constructor for composite types easy to implement, 1. due to the utility functionality, all of the examples that I've encountered so far are very concise. -The potential downside is additional conversions between natural and structural tangents. Most of the time, these are free. When they're not, you ideally want to minimise them. I'm not sure how often this is going to be a problem, but it's something we're essentially ignoring at the minute (as far as I know), so we're probably going to have to incur some additional cost even if we don't go down the route proposed here. +The potential downside is additional conversions between natural and structural tangents. Most of the time, these are free. When they're not, you ideally want to minimise them. I'm not sure how often this is going to be a problem, but it's something we're essentially ignoring at the minute (as far as I know), so we're probably going to have to incur some additional cost at some point even if we don't go down the route proposed here. From ddb32376b9b970add33b7f41c5c6290d06a1264d Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:04:04 +0100 Subject: [PATCH 19/36] Tweak notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index b6feeb863..2a1f56256 100644 --- a/notes.md +++ b/notes.md @@ -84,7 +84,7 @@ I'm intentionally not trying to define precisely what a primitive is, but will a Assume that structural tangents are a valid representation of a tangent of any non-primitive composite type, convenience for rule-writing aside. -More generally, assume that if an AD successfully runs on a given function under the above assumptions, they give the answer desired (the "correct" answer). Consequently, the core goals of the proposed recipe are to make it +More generally, assume that if an AD successfully runs on a given function under the above assumptions, they give the answer desired (the "correct" answer). Consequently, two core goals of the proposed recipe are to make it 1. easy to never write a rule which prevents an AD from differentiating a programme that they already know how to differentiate, 2. make it easy to write rules using intuitive representations of tangents. 
From 34e43d60e5279612b96fe8f88b7ce78798d9963d Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:05:01 +0100 Subject: [PATCH 20/36] Clarify notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index 2a1f56256..43a854d28 100644 --- a/notes.md +++ b/notes.md @@ -98,7 +98,7 @@ First consider a specific case -- we'll optimise the implementation later, and p Consider a function `f(x::AbstractArray) -> y::AbstractArray`. Lets assume that there's just one method, so we can be sure that a generic fallback will be hit, regardless the concrete type of the argument. -The intuition behind the recipe is to find a function which is equivalent to `f`, whose rules we know how to write safely. If we can find such a function, AD-ing it will clearly give the correct answer -- the following lays out an approach to doing this. +The intuition behind the recipe is to find a function which is equivalent to `f`, whose rules we know how to write safely. If we can find such a function, AD-ing it will clearly give the correct answer (the same as running AD on `f` itself) -- the following lays out an approach to doing this. The recipe is: 1. Map `x` to an `Array`, `x_dense`, using `getindex`. Call this operation `destructure` (it's essentially `collect`). From c7730456b773eb1641da6e37f44c244b8e7e47b0 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:06:09 +0100 Subject: [PATCH 21/36] Clarify notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index 43a854d28..7c398406b 100644 --- a/notes.md +++ b/notes.md @@ -111,7 +111,7 @@ I'm going to define equivalence of output structurally -- two `AbstractArray`s a The reason for this notion of equality is that AD (as proposed above) treats concrete subtypes of AbstractArray no differently from any other composite type. -A very literal implementation of this for a function like `*` is something like the following: +Applying this recipe, a very literal implementation for the `rrule` of the equivalent function for `*` is something like the following: ```julia function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) From eaaeba7b60a7b255c47435de8fa66280c4781fcf Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:06:46 +0100 Subject: [PATCH 22/36] Clarify notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index 7c398406b..4391b78ff 100644 --- a/notes.md +++ b/notes.md @@ -149,7 +149,7 @@ function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMa return C, my_mul_generic_pullback end ``` -I've just written out by hand the rrule for differentiating through the equivalent function. +All I've done is write out the rrule for differentiating through the equivalent function by hand. We'll optimise this implementation shortly to avoid e.g. having to densify primals, and computing the same function twice. `my_mul` in `examples.jl` verifies the correctness of the above implementation. 
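+
+For concreteness, the check in `examples.jl` is along these lines (a sketch; it assumes the `my_mul` rule from `examples.jl` is loaded, and the variable names are illustrative):
+```julia
+using FiniteDifferences, Zygote
+
+A, B, C̄ = randn(4, 3), randn(3, 2), randn(4, 2)
+
+# Pull C̄ back through the rule (via Zygote) and via finite differences, and compare.
+C, pb = Zygote.pullback(my_mul, A, B)
+Ā, B̄ = pb(C̄)
+Ā_fd, B̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, C̄, A, B)
+Ā ≈ Ā_fd && B̄ ≈ B̄_fd
+```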
From d94514d6055d5decaf08b7bb2d8de2c2c5a6752d Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:07:43 +0100 Subject: [PATCH 23/36] Tweak notes --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index 4391b78ff..5c4688907 100644 --- a/notes.md +++ b/notes.md @@ -176,7 +176,7 @@ For example, this means that a `Diagonal{Float64}` is a valid (co)tangent for an -## Optimising rrules using Natural Pullbacks +## Optimising rrules for the equivalent function The basic example layed out above was very sub-optimal. Consider the following (equivalent) re-write ```julia From d571ad175f89ac233e9084add9a07f941150e5db Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:09:44 +0100 Subject: [PATCH 24/36] Tweak notes --- notes.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notes.md b/notes.md index 5c4688907..5b5781eeb 100644 --- a/notes.md +++ b/notes.md @@ -212,8 +212,8 @@ end ``` A few observations: 1. All dense primals are gone. In the pullback, they only appeared in places where they can be safely replaced with the primals themselves because they're doing array-like things. `C_dense` appeared in the construction of `restructure_C_pb`, however, we were using a sub-optimal implementation of that function. Much of the time, `restructure_of_pb` doesn't require `C_dense` in order to know what the pullback would look like and, if it does, it can be obtained from `C`. -2. All direct calls to `rrule_via_ad` have been replaced with calls to functions which are defined to returns the things we actually need (the pullbacks). These typically have efficient (and easy to write) implementations. -3. `C̄_nat` could be any old `AbstractArray`. For example, `pullback_of_restructure` for a `Diagonal` returns a `Diagonal`. This is good -- it means we might occassionally get faster computations in the pullback. +2. All direct calls to `rrule_via_ad` have been replaced with calls to functions which are defined to return the things we actually need (the pullbacks). These typically have efficient (and easy to write) implementations. +3. `C̄_nat` could be any old `AbstractArray`, because it's conceptually the cotangent for an `Array`. For example, `pullback_of_restructure` for a `Diagonal` returns a `Diagonal`. This is good -- it means we might occassionally get faster computations in the pullback. Roughly speaking, the above implementation has only one additional operation than our existing rrules involving `ProjectTo`, which is a call to `restructure_C_pb`, which handles converting a structural tangent for `C̄` into the corresponding natural. Currently we require users to do this by hand, and no clear guidance is provided regarding the correct way to handle this conversion, in contrast to the clarity provided here. In this sense, all that the above is doing is providing a well-defined mechanism by which users can obtain natural cotangents from structural cotangents, so it should ease the burden on rule-implementers. 
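+
+As a concrete illustration of that one extra step: for a `Symmetric` output `C`, `restructure_C_pb` maps the structural `Tangent{<:Symmetric}` onto a plain matrix, roughly as the `Symmetric` method of `pullback_of_restructure` in this PR does (sketch, assuming `uplo == 'U'`; the helper name is illustrative):
+```julia
+using LinearAlgebra, ChainRulesCore
+
+# structural cotangent -> natural cotangent for a Symmetric primal with upper-triangular storage.
+natural_cotangent(C̄::Tangent{<:Symmetric}) = collect(UpperTriangular(C̄.data))
+```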
From d87d5c4f60d2c951c2f24579ceb796d83e5541b2 Mon Sep 17 00:00:00 2001 From: WT Date: Sat, 28 Aug 2021 23:16:46 +0100 Subject: [PATCH 25/36] Tweak examples comments --- examples.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples.jl b/examples.jl index a88aeff54..079e9cdd1 100644 --- a/examples.jl +++ b/examples.jl @@ -55,7 +55,7 @@ my_sum(x::AbstractArray) = sum(x) function ChainRulesCore.rrule(config::RuleConfig, ::typeof(my_sum), x::AbstractArray) y = my_sum(x) - natural_pullback_my_sum(ȳ::Real) = NoTangent(), fill(ȳ, size(x)) + natural_pullback_my_sum(ȳ::Real) = NoTangent(), fill(ȳ, size(x)) # Fill also fine here. return y, wrap_natural_pullback(config, natural_pullback_my_sum, y, x) end @@ -155,7 +155,7 @@ test_approx(dx, dx_fd) # Example 4: ScaledVector. This is an interesting example because I truly had no idea how to -# specify a natural tangent for this before. +# specify a natural tangent prior to this work. # Implement AbstractArray interface. struct ScaledMatrix <: AbstractMatrix{Float64} From a7eb01cf46e94199a2ed8a2f7a018d831725e0b2 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 29 Aug 2021 12:03:41 +0100 Subject: [PATCH 26/36] Tidy up + add Fill example --- examples.jl | 268 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 186 insertions(+), 82 deletions(-) diff --git a/examples.jl b/examples.jl index 079e9cdd1..6f3cb52ef 100644 --- a/examples.jl +++ b/examples.jl @@ -21,24 +21,26 @@ function rrule(config::RuleConfig, ::typeof(my_mul), A::AbstractMatrix, B::Abstr return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B) end -A = randn(4, 3); -B = Symmetric(randn(3, 3)); -C, pb = Zygote.pullback(my_mul, A, B); +let + A = randn(4, 3); + B = Symmetric(randn(3, 3)); + C, pb = Zygote.pullback(my_mul, A, B); -@assert C ≈ my_mul(A, B) + @assert C ≈ my_mul(A, B) -dC = randn(4, 3); -dA, dB_zg = pb(dC); -dB = Tangent{typeof(B)}(data=dB_zg.data); + dC = randn(4, 3); + dA, dB_zg = pb(dC); + dB = Tangent{typeof(B)}(data=dB_zg.data); -# Test correctness. -dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B); + # Test correctness. + dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B); -# to_vec doesn't know how to make `Tangent`s, so instead I map it to a `Tangent` manually. -dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data); + # to_vec doesn't know how to make `Tangent`s, so translate manually. 
+ dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data); -test_approx(dA, dA_fd) -test_approx(dB, dB_fd) + test_approx(dA, dA_fd) + test_approx(dB, dB_fd) +end @@ -59,19 +61,21 @@ function ChainRulesCore.rrule(config::RuleConfig, ::typeof(my_sum), x::AbstractA return y, wrap_natural_pullback(config, natural_pullback_my_sum, y, x) end -A = Symmetric(randn(2, 2)) -y, pb = Zygote.pullback(my_sum, A) +let + A = Symmetric(randn(2, 2)) + y, pb = Zygote.pullback(my_sum, A) -test_approx(y, my_sum(A)) + test_approx(y, my_sum(A)) -dy = randn() -dA_zg, = pb(dy) -dA = Tangent{typeof(A)}(data=dA_zg.data) + dy = randn() + dA_zg, = pb(dy) + dA = Tangent{typeof(A)}(data=dA_zg.data) -dA_fd_sym, = FiniteDifferences.j′vp(central_fdm(5, 1), my_sum, dy, A) -dA_fd = Tangent{typeof(A)}(data=dA_fd_sym.data) + dA_fd_sym, = FiniteDifferences.j′vp(central_fdm(5, 1), my_sum, dy, A) + dA_fd = Tangent{typeof(A)}(data=dA_fd_sym.data) -test_approx(dA, dA_fd) + test_approx(dA, dA_fd) +end @@ -90,66 +94,70 @@ function ChainRulesCore.rrule( end # DENSE TEST -a = randn() -x = randn(2, 2) -y, pb = Zygote.pullback(my_scale, a, x) - -dy = randn(size(y)) -da, dx = pb(dy) +let + a = randn() + x = randn(2, 2) + y, pb = Zygote.pullback(my_scale, a, x) -da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) + dy = randn(size(y)) + da, dx = pb(dy) -test_approx(y, my_scale(a, x)) -test_approx(da, da_fd) -test_approx(dx, dx_fd) + da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) + test_approx(y, my_scale(a, x)) + test_approx(da, da_fd) + test_approx(dx, dx_fd) +end # DIAGONAL TEST # `diag` now returns a `Diagonal` as a tangnet, so have to define `my_diag` to make this # work with Diagonal`s. + my_diag(x) = diag(x) function ChainRulesCore.rrule(::typeof(my_diag), D::P) where {P<:Diagonal} my_diag_pullback(d) = NoTangent(), Tangent{P}(diag=d) return diag(D), my_diag_pullback end -a = randn() -x = Diagonal(randn(2)) -y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x) +let + a = randn() + x = Diagonal(randn(2)) + y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x) -ȳ = randn(2) -ā, x̄_zg = pb(ȳ) -x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag) + ȳ = randn(2) + ā, x̄_zg = pb(ȳ) + x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag) -ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x) -x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag) + ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x) + x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag) -test_approx(y, (my_diag ∘ my_scale)(a, x)) -test_approx(ā, ā_fd) -test_approx(x̄, x̄_fd) + test_approx(y, (my_diag ∘ my_scale)(a, x)) + test_approx(ā, ā_fd) + test_approx(x̄, x̄_fd) +end # SYMMETRIC TEST - FAILS BECAUSE PRIVATE ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL! # I would be surprised if we're doing this consistently at the minute though. 
+let + a = randn() + x = Symmetric(randn(2, 2)) + y, pb = Zygote.pullback(my_scale, a, x) -a = randn() -x = Symmetric(randn(2, 2)) -y, pb = Zygote.pullback(my_scale, a, x) + dy = Tangent{typeof(y)}(data=randn(2, 2)) + da, dx_zg = pb(dy) + dx = Tangent{typeof(x)}(data=dx_zg.data) -dy = Tangent{typeof(y)}(data=randn(2, 2)) -da, dx_zg = pb(dy) -dx = Tangent{typeof(x)}(data=dx_zg.data) - -da_fd, dx_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) -dx_fd = Tangent{typeof(x)}(data=dx_fd_sym.data) - -test_approx(y.data, my_scale(a, x).data) -test_approx(da, da_fd) -test_approx(dx, dx_fd) + da_fd, dx_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) + dx_fd = Tangent{typeof(x)}(data=dx_fd_sym.data) + test_approx(y.data, my_scale(a, x).data) + test_approx(da, da_fd) + test_approx(dx, dx_fd) +end @@ -207,41 +215,137 @@ function ChainRulesCore.rrule( return z, wrap_natural_pullback(config, natural_pullback_my_dot, z, x, y) end +let + # Check correctness of `my_dot` rrule. Build `ScaledMatrix` internally to avoid + # technical issues with FiniteDifferences. + V = randn(2, 2) + α = randn() + z̄ = randn() -# Check correctness of `my_dot` rrule. Build `ScaledMatrix` internally to avoid technical -# issues with FiniteDifferences. -V = randn(2, 2) -α = randn() -z̄ = randn() + foo_scal(V, α) = my_dot(ScaledMatrix(V, α), V) -foo_scal(V, α) = my_dot(ScaledMatrix(V, α), V) + z, pb = Zygote.pullback(foo_scal, V, α) + dx_ad = pb(z̄) -z, pb = Zygote.pullback(foo_scal, V, α) -dx_ad = pb(z̄) + dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α) -dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α) - -test_approx(dx_ad, dx_fd) + test_approx(dx_ad, dx_fd) +end -# A function with a specialised method for ScaledMatrix. +# Specialised method of my_scale for ScaledMatrix my_scale(a::Real, X::ScaledMatrix) = ScaledMatrix(X.v, X.α * a) -# Verify correctness. -a = randn() -V = randn(2, 2) -α = randn() -z̄ = randn() +let + # Verify correctness. + a = randn() + V = randn(2, 2) + α = randn() + z̄ = randn() -# A more complicated programme involving `my_scale`. -B = randn(2, 2) -foo_my_scale(a, V, α) = my_dot(B, my_scale(a, ScaledMatrix(V, α))) + # A more complicated programme involving `my_scale`. + B = randn(2, 2) + foo_my_scale(a, V, α) = my_dot(B, my_scale(a, ScaledMatrix(V, α))) -z, pb = Zygote.pullback(foo_my_scale, a, V, α) -da, dV, dα = pb(z̄) + z, pb = Zygote.pullback(foo_my_scale, a, V, α) + da, dV, dα = pb(z̄) -da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, z̄, a, V, α) + da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, z̄, a, V, α) + + test_approx(da, da_fd) + test_approx(dV, dV_fd) + test_approx(dα, dα_fd) +end + + + + +# Example 5: Fill + +using FillArrays + +# What you would implement: +# destucture(x::Fill) = collect(x) + +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:Fill} + pullback_destructure_Fill(X̄::AbstractArray) = Tangent{P}(value=sum(X̄)) + return pullback_destructure_Fill +end -test_approx(da, da_fd) -test_approx(dV, dV_fd) -test_approx(dα, dα_fd) +# There are multiple equivalent choices for Restructure here. I present two options below, +# both yield the correct answer. +# To understand why there is a range of options here, recall that the input to +# (::Restructure)(x::AbstractArray) must be an array that is equal (in the `getindex` sense) +# to a `Fill`. 
Moreover, `(::Restructure)` simply has to promise that the `Fill` it outputs +# is equal to the `Fill` from which the `Restructure` is constructed (in the structural +# sense). Since the elements of `x` must all be equal, any affine combination will do ( +# weighted sum, whose weights sum to 1). +# While there is no difference in the primal for `(::Restructure)`, the pullback is quite +# different, depending upon your choice. Since we don't ever want to evaluate the primal for +# `(::Restructure)`, just the pullback, we are free to choose whatever definition of +# `(::Restructure)` makes its pullback pleasant. In particular, defining `(::Restructure)` +# to take the mean of its argument yields a pleasant pullback (see below). + +# Restucture option 1: + +# Restructure(x::P) where {P<:Fill} = Restructure{P, typeof(x.axes)}(x.axes) +# (r::Restructure{<:Fill})(x::AbstractArray) = Fill(x[1], r.data) + +function pullback_of_restructure(config::RuleConfig, x::Fill) + println("Option 1") + function pullback_restructure_Fill(x̄::Tangent) + X̄ = zeros(size(x)) + X̄[1] = x̄.value + return X̄ + end + return pullback_restructure_Fill +end + +# Restructure option 2: + +# Restructure(x::P) where {P<:Fill} = Restructure{P, typeof(x.axes)}(x.axes) +# (r::Restructure{<:Fill})(x::AbstractArray) = Fill(mean(x), r.data) + +function pullback_of_restructure(config::RuleConfig, x::Fill) + println("Option 2") + pullback_restructure_Fill(x̄::Tangent) = Fill(x̄.value / length(x), x.axes) + return pullback_restructure_Fill +end + + +let + A = randn(2, 3) + v = randn() + + # Build the Fill inside because FiniteDifferenes doesn't play nicely with Fills, even + # if one adds a `to_vec` call. + foo_my_mul(A, v) = my_mul(A, Fill(v, 3, 4)) + C, pb = Zygote.pullback(foo_my_mul, A, v) + + @assert C ≈ foo_my_mul(A, v) + + C̄ = randn(2, 4); + Ā, v̄ = pb(C̄); + + Ā_fd, v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_mul, C̄, A, v); + + test_approx(Ā, Ā_fd) + test_approx(v̄, v̄_fd) +end + +let + foo_my_scale(a, v) = my_sum(my_scale(a, Fill(v, 3, 4))) + a = randn() + v = randn() + + c, pb = Zygote.pullback(foo_my_scale, a, v) + @assert c ≈ foo_my_scale(a, v) + + c̄ = randn() + ā, v̄ = pb(c̄) + + ā_fd, v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, c̄, a, v); + + test_approx(ā, ā_fd) + test_approx(v̄, v̄_fd) +end From e691b6a5f612e88376fe9723b27e378fd543f953 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 29 Aug 2021 12:24:24 +0100 Subject: [PATCH 27/36] Add SArray example --- examples.jl | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/examples.jl b/examples.jl index 6f3cb52ef..dad5ab09d 100644 --- a/examples.jl +++ b/examples.jl @@ -136,7 +136,6 @@ let test_approx(y, (my_diag ∘ my_scale)(a, x)) test_approx(ā, ā_fd) test_approx(x̄, x̄_fd) - end @@ -312,7 +311,7 @@ function pullback_of_restructure(config::RuleConfig, x::Fill) return pullback_restructure_Fill end - +# An example which uses `pullback_of_destructure(::Fill)` because `Fill` is an input. let A = randn(2, 3) v = randn() @@ -333,6 +332,23 @@ let test_approx(v̄, v̄_fd) end +# Another example using `pullback_of_destructure(::Fill)`. +let + foo_my_sum(v) = my_sum(Fill(v, 4, 3)) + v = randn() + + c, pb = Zygote.pullback(foo_my_sum, v) + @assert c ≈ foo_my_sum(v) + + c̄ = randn() + v̄ = pb(c̄) + + v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_sum, c̄, v); + + test_approx(v̄, v̄_fd) +end + +# An example using `pullback_of_restructure(::Fill)`. 
let foo_my_scale(a, v) = my_sum(my_scale(a, Fill(v, 3, 4))) a = randn() @@ -349,3 +365,53 @@ let test_approx(ā, ā_fd) test_approx(v̄, v̄_fd) end + + + + +# Example 6: SArray +# This example demonstrates that within this framework we can easily work with structural +# tangents for `SArray`s. Unclear that we _want_ to do this, but it's nice to know that +# it's an option requiring minimal work. +# Notice that this should be performant, since `pullback_of_destructure` and +# `pullback_of_restructure` should be performant, and the operations in the pullback +# will all happen on `SArray`s. + +using StaticArrays + +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:SArray} + pullback_destructure_SArray(X̄::AbstractArray) = Tangent{P}(data=X̄) + return pullback_destructure_SArray +end + +function pullback_of_restructure( + config::RuleConfig, x::SArray{S, T, N, L}, +) where {S, T, N, L} + pullback_restructure_SArray(x̄::Tangent) = SArray{S, T, N, L}(x̄.data) + return pullback_restructure_SArray +end + +# destructure + restructure example with `my_mul`. +let + A = SMatrix{2, 2}(randn(4)...) + B = SMatrix{2, 1}(randn(2)...) + C, pb = Zygote.pullback(my_mul, A, B) + + @assert C ≈ my_mul(A, B) + + C̄ = Tangent{typeof(C)}(data=(randn(2)..., )) + Ā_, B̄_ = pb(C̄) + + # Manually convert Ā_ and B̄_ to Tangents from Zygote types. + Ā = Tangent{typeof(A)}(data=Ā_.data) + B̄ = Tangent{typeof(B)}(data=B̄_.data) + + Ā_fd_, B̄_fd_ = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, C̄, A, B) + + # Manually convert Ā_fd and B̄_fd into Tangents from to_vec output. + Ā_fd = Tangent{typeof(A)}(data=Ā_fd_) + B̄_fd = Tangent{typeof(B)}(data=B̄_fd_) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) +end From d0b0eb09672424067d9debfbf75edffa37de5212 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 29 Aug 2021 12:45:01 +0100 Subject: [PATCH 28/36] Add note on Symmetric restructure --- src/destructure.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/destructure.jl b/src/destructure.jl index fb6f7b988..45107d88b 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -205,6 +205,10 @@ function (r::Restructure{P})(X::Array) where {P<:Symmetric} return Symmetric(UpperTriangular(X) + strict_lower_triangle_of_data) end +# We get to assume that `issymmetric(X)` is (at least roughly) true, so we could also +# implement restructure as Symmetric((X + X') / 2), provided that we then take care of the +# lower triangle as above. 
+ function frule( (_, dX)::Tuple{Any, AbstractArray}, r::Restructure{P}, X::Array, ) where {P<:Symmetric} From d6eef62a25879600b74d4815f66538be843f9ddd Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 29 Aug 2021 12:57:22 +0100 Subject: [PATCH 29/36] Add UpperTriangular example with my_mul --- examples.jl | 30 ++++++++++++++++++++++++++++++ src/destructure.jl | 21 +++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/examples.jl b/examples.jl index dad5ab09d..36244d983 100644 --- a/examples.jl +++ b/examples.jl @@ -42,6 +42,36 @@ let test_approx(dB, dB_fd) end +# Zygote doesn't like the structural tangent here because of an @adjoint +my_upper_triangular(X) = UpperTriangular(X) + +function rrule(::typeof(my_upper_triangular), X::Matrix{<:Real}) + pullback_my_upper_triangular(Ū::Tangent) = NoTangent(), Ū.data + return my_upper_triangular(X), pullback_my_upper_triangular +end + +let + foo_my_mul(U_data, B) = my_mul(my_upper_triangular(U_data), B) + U_data = randn(3, 3) + B = randn(3, 4) + C, pb = Zygote.pullback(foo_my_mul, U_data, B) + + @assert C ≈ foo_my_mul(U_data, B) + + C̄ = randn(3, 4) + Ū_data, B̄ = pb(C̄) + + Ū_data_fd, B̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_mul, C̄, U_data, B) + + test_approx(Ū_data, Ū_data_fd) + test_approx(B̄, B̄_fd) + + display(Ū_data_fd) + println() + display(Ū_data) + println() +end + # pullbacks for `Real`s so that they play nicely with the utility functionality. diff --git a/src/destructure.jl b/src/destructure.jl index 45107d88b..5ce15cd88 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -227,6 +227,27 @@ end +# UpperTriangular +function pullback_of_destructure(U::P) where {P<:UpperTriangular{<:Real, <:StridedMatrix}} + pullback_destructure(Ū::AbstractArray) = Tangent{P}(data=UpperTriangular(Ū)) + return pullback_destructure +end + +function pullback_of_restructure(U::P) where {P<:UpperTriangular} + pullback_restructure(Ū::Tangent) = UpperTriangular(Ū.data) + return pullback_restructure +end + +function pullback_of_destructure(config::RuleConfig, U::UpperTriangular) + return pullback_of_destructure(U) +end + +function pullback_of_restructure(config::RuleConfig, U::UpperTriangular) + return pullback_of_restructure(U) +end + + + # Cholesky -- you get to choose whatever destructuring operation is helpful for a given # type. This one is helpful for writing generic pullbacks for `cholesky`, the output of # which is a Cholesky. Not completed. Probably won't be included in initial merge. From a4ce7d23ef6c87d34d30c676fdb1be4891149474 Mon Sep 17 00:00:00 2001 From: WT Date: Tue, 31 Aug 2021 16:52:36 +0100 Subject: [PATCH 30/36] Add WoodburyPDMat example --- examples.jl | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/examples.jl b/examples.jl index 36244d983..84617f3a5 100644 --- a/examples.jl +++ b/examples.jl @@ -445,3 +445,155 @@ let test_approx(Ā, Ā_fd) test_approx(B̄, B̄_fd) end + + + + + +# Example 7: WoodburyPDMat +# WoodburyPDMat doesn't currently know anything about AD. I have no intention of +# implementing any of the functionality here on it, because it's just fine as it is. +# However, it _is_ any interesting case-study, because it's an example where addition in +# natural (co)tangent space disagrees with addition in structural space. Since we know that +# the notion of addition we have on structural tangents is the desirable one, this indicates +# that we don't always want to add natural tangents. 
+# It's also interesting because, as with the `ScaledMatrix` example, I had no idea how to +# find a natural (co)tangent prior to this PR. It's a comparatively complicated example, +# and destructure and restructure are non-linear in the fields of `x`, which is another +# interesting property. +# I've only bothered deriving stuff for destructure because restructure is really quite +# complicated and I don't have the time right now to work through the example. +# It does serve to show that it's not always going to be easy for authors of complicated +# array types to make their type work with the natural pullback machinery. At least we +# understand what an author would have to do though, even if it's not straightforward to +# do all of the time. + +using PDMatsExtras: WoodburyPDMat +import ChainRulesCore: destructure, Restructure + +# What destructure would do if we actually implemented it. +# destructure(x::WoodburyPDMat) = x.A * x.D * x.A' + x.S + +# This is an interesting pullback, because it looks like the destructuring mechanism +# is required here to ensure that the fields `D` and `S` are handled appropriately. +# A would also be necessary in general, but I'm assuming it's just a `Matrix{<:Real}` for +# now. +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:WoodburyPDMat} + println("Hitting pullback") + pb_destructure_D = pullback_of_destructure(x.D) + pb_destructure_S = pullback_of_destructure(x.S) + function pullback_destructure_WoodburyPDMat(x̄::AbstractArray) + S̄ = pb_destructure_S(x̄) + D̄ = pb_destructure_D(x.A' * x̄ * x.A) + Ā = (x̄ + x̄') * x.A * x.D + return Tangent{P}(A=Ā, D=D̄, S=S̄) + end + return pullback_destructure_WoodburyPDMat +end + +# Check my_sum correctness. Doesn't require Restructure since the output is a Real. +let + A = randn(4, 2) + d = rand(2) .+ 1 + s = rand(4) .+ 1 + + foo(A, d, s) = my_sum(WoodburyPDMat(A, Diagonal(d), Diagonal(s))) + + Y, pb = Zygote.pullback(foo, A, d, s) + test_approx(Y, foo(A, d, s)) + + Ȳ = randn() + Ā, d̄, s̄ = pb(Ȳ) + + Ā_fd, d̄_fd, s̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Ȳ, A, d, s) + + test_approx(Ā, Ā_fd) + test_approx(d̄, d̄_fd) + test_approx(s̄, s̄_fd) +end + +# Check my_mul correctness. Doesn't require Restructure since the output is an Array. +# This is a truly awful implementation (generic fallback for my_mul, and getindex is really +# very expensive for WoodburyPDMat), but it ought to work. +let + A = randn(4, 2) + d = rand(2) .+ 1 + s = rand(4) .+ 1 + b = randn() + + # Multiply some interesting types together. + foo(A, d, s, b) = my_mul(WoodburyPDMat(A, Diagonal(d), Diagonal(s)), Fill(b, 4, 3)) + + Y, pb = Zygote.pullback(foo, A, d, s, b) + test_approx(Y, foo(A, d, s, b)) + + Ȳ = randn(4, 3) + Ā, d̄, s̄, b̄ = pb(Ȳ) + + Ā_fd, d̄_fd, s̄_fd, b̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Ȳ, A, d, s, b) + + test_approx(Ā, Ā_fd) + test_approx(d̄, d̄_fd) + test_approx(s̄, s̄_fd) + test_approx(b̄, b̄_fd) +end + + +# THIS EXAMPLE DOESN'T WORK. +# I'm not really sure why, but it's not the responsibility of this PR because I'm just +# trying to opt out of the generic rrule for my_scale, because there's a specialised +# implementation available. +# I've tried all of the opt-outs I can think of, but no luck -- it keeps hitting +# ChainRules for me :( + +# Opt-out and refresh. 
+ChainRulesCore.@opt_out rrule(::typeof(my_scale), ::Real, ::WoodburyPDMat) +ChainRulesCore.@opt_out rrule(::typeof(*), ::Real, ::WoodburyPDMat) + +ChainRulesCore.@opt_out rrule(::typeof(my_scale), ::WoodburyPDMat, ::Real) +ChainRulesCore.@opt_out rrule(::typeof(*), ::WoodburyPDMat, ::Real) + +ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(my_scale), ::Real, ::WoodburyPDMat) +ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(my_scale), ::WoodburyPDMat, ::Real) +ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(*), ::Real, ::WoodburyPDMat) +ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(*), ::WoodburyPDMat, ::Real) +Zygote.refresh() + +# Something currently produces a `Diagonal` cotangent somewhere, so have to add this +# accumulate rule. +Zygote.accum(x::NamedTuple{(:diag, )}, y::Diagonal) = (diag=x.diag + y.diag, ) + +# Something else is a producing a `Matrix`... +Zygote.accum(x::NamedTuple{(:diag, )}, y::Matrix) = (diag=x.diag + diag(y), ) + +# This should just hit [this code](https://github.com/invenia/PDMatsExtras.jl/blob/b7b3a2035682465f1471c2d2e1e017b9fd75cec0/src/woodbury_pd_mat.jl#L92) +let + α = rand() + A = randn(4, 2) + d = rand(2) .+ 10 + s = rand(4) .+ 1 + + # Multiply some interesting types together. + foo(α, A, d, s) = my_scale(α, WoodburyPDMat(A, Diagonal(d), Diagonal(s))) + + Y, pb = Zygote.pullback(foo, α, A, d, s) + test_approx(Y, foo(α, A, d, s)) + + Ȳ = ( + A=randn(4, 2), + D=(diag=randn(2),), + S=(diag=randn(4),), + ) + ᾱ, Ā, d̄, s̄ = pb(Ȳ) + + # FiniteDifferences doesn't play nicely with structural tangents for Diagonals, + # so I would have to do things manually to properly test this one. Not going to do that + # because I've not actually used any hand-written rules here, other than the + # Zygote.accum calls above, which look fine to me. + # ᾱ_fd, Ā_fd, d̄_fd, s̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Ȳ, α, A, d, s) + + # test_approx(ᾱ, ᾱ_fd) + # test_approx(Ā, Ā_fd) + # test_approx(d̄, d̄_fd) + # test_approx(s̄, s̄_fd) +end From a88075297b947fff05da90bba4c4429d6fbd6efa Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 1 Sep 2021 23:25:06 +0100 Subject: [PATCH 31/36] Note that natural tangent addition is fine --- examples.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples.jl b/examples.jl index 84617f3a5..b10452aba 100644 --- a/examples.jl +++ b/examples.jl @@ -453,10 +453,10 @@ end # Example 7: WoodburyPDMat # WoodburyPDMat doesn't currently know anything about AD. I have no intention of # implementing any of the functionality here on it, because it's just fine as it is. -# However, it _is_ any interesting case-study, because it's an example where addition in +# ~~However, it _is_ any interesting case-study, because it's an example where addition in # natural (co)tangent space disagrees with addition in structural space. Since we know that # the notion of addition we have on structural tangents is the desirable one, this indicates -# that we don't always want to add natural tangents. +# that we don't always want to add natural tangents.~~ See below. # It's also interesting because, as with the `ScaledMatrix` example, I had no idea how to # find a natural (co)tangent prior to this PR. It's a comparatively complicated example, # and destructure and restructure are non-linear in the fields of `x`, which is another @@ -467,6 +467,9 @@ end # array types to make their type work with the natural pullback machinery. 
At least we # understand what an author would have to do though, even if it's not straightforward to # do all of the time. +# edit: natural tangents should always add properly. Ignore what is said in the para above +# about them not adding properly here. It's still interesting for the other reasons listed +# though. using PDMatsExtras: WoodburyPDMat import ChainRulesCore: destructure, Restructure From 6cada13982db477961f7747a1e139057d2ab3030 Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 1 Sep 2021 23:26:35 +0100 Subject: [PATCH 32/36] Add another examples comment --- examples.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples.jl b/examples.jl index b10452aba..d11ab40c8 100644 --- a/examples.jl +++ b/examples.jl @@ -562,6 +562,8 @@ ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(*), ::Real, :: ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(*), ::WoodburyPDMat, ::Real) Zygote.refresh() +# Not sure why these accum rules are needed, as the code which follows doesn't look to me +# like it should have an accumulations in it... # Something currently produces a `Diagonal` cotangent somewhere, so have to add this # accumulate rule. Zygote.accum(x::NamedTuple{(:diag, )}, y::Diagonal) = (diag=x.diag + y.diag, ) From 8f063643cfdbe39340bf4878ed593cc9739f4fc7 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 5 Sep 2021 14:00:32 +0100 Subject: [PATCH 33/36] Add Kronecker example and fix WoodburPDMat example --- examples.jl | 151 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 125 insertions(+), 26 deletions(-) diff --git a/examples.jl b/examples.jl index d11ab40c8..96d90b617 100644 --- a/examples.jl +++ b/examples.jl @@ -16,6 +16,7 @@ using ChainRulesCore: RuleConfig, wrap_natural_pullback my_mul(A::AbstractMatrix, B::AbstractMatrix) = A * B function rrule(config::RuleConfig, ::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) + println("my_mul generic rrule") C = A * B natural_pullback_for_mul(C̄) = NoTangent(), C̄ * B', A' * C̄ return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B) @@ -550,27 +551,17 @@ end # ChainRules for me :( # Opt-out and refresh. -ChainRulesCore.@opt_out rrule(::typeof(my_scale), ::Real, ::WoodburyPDMat) -ChainRulesCore.@opt_out rrule(::typeof(*), ::Real, ::WoodburyPDMat) +using ChainRules: CommutativeMulNumber +@opt_out rrule(::RuleConfig, ::typeof(my_scale), ::Real, ::WoodburyPDMat) +@opt_out rrule(::typeof(*), ::WoodburyPDMat{<:CommutativeMulNumber}, ::CommutativeMulNumber) +@opt_out rrule(::typeof(*), ::CommutativeMulNumber, ::WoodburyPDMat{<:CommutativeMulNumber}) -ChainRulesCore.@opt_out rrule(::typeof(my_scale), ::WoodburyPDMat, ::Real) -ChainRulesCore.@opt_out rrule(::typeof(*), ::WoodburyPDMat, ::Real) +# This should probably go in ChainRules anyway. +@opt_out rrule(::typeof(*), ::CommutativeMulNumber, ::Diagonal{<:CommutativeMulNumber}) +@opt_out rrule(::typeof(*), ::Diagonal{<:CommutativeMulNumber}, ::CommutativeMulNumber) -ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(my_scale), ::Real, ::WoodburyPDMat) -ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(my_scale), ::WoodburyPDMat, ::Real) -ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(*), ::Real, ::WoodburyPDMat) -ChainRulesCore.@opt_out rrule(::Zygote.ZygoteRuleConfig, ::typeof(*), ::WoodburyPDMat, ::Real) Zygote.refresh() -# Not sure why these accum rules are needed, as the code which follows doesn't look to me -# like it should have an accumulations in it... 
-# Something currently produces a `Diagonal` cotangent somewhere, so have to add this -# accumulate rule. -Zygote.accum(x::NamedTuple{(:diag, )}, y::Diagonal) = (diag=x.diag + y.diag, ) - -# Something else is a producing a `Matrix`... -Zygote.accum(x::NamedTuple{(:diag, )}, y::Matrix) = (diag=x.diag + diag(y), ) - # This should just hit [this code](https://github.com/invenia/PDMatsExtras.jl/blob/b7b3a2035682465f1471c2d2e1e017b9fd75cec0/src/woodbury_pd_mat.jl#L92) let α = rand() @@ -591,14 +582,122 @@ let ) ᾱ, Ā, d̄, s̄ = pb(Ȳ) - # FiniteDifferences doesn't play nicely with structural tangents for Diagonals, - # so I would have to do things manually to properly test this one. Not going to do that - # because I've not actually used any hand-written rules here, other than the - # Zygote.accum calls above, which look fine to me. - # ᾱ_fd, Ā_fd, d̄_fd, s̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Ȳ, α, A, d, s) + # No need to test correctness with FiniteDifferences because I've just opted out of + # rules, so if the code runs, it ought to be correct. If it's not correct, it must be a + # bug from somewhere other than here. +end + + + +# Example 8: Kronecker.jl +# This is a good example of a package where you want to opt out of basically all rules +# which apply to functions that Kronecker has implemented, but you also really want to +# ensure that the generic fallbacks work for operations that you don't care about so much +# (as discussed in various locations, typically if someone cares about the performance of +# an operations on a Kronecker matrix, they'll have implemented a specialised method to do +# it.) +# I have absolutely no idea how to implement restructure for this type. + +using Kronecker +using Kronecker: KroneckerProduct + +# destructure(X::KroneckerProduct) = kron(X.A, X.B) + +# Just differentiate the implementation of `kron` from the stdlib to produce the pullback. +# Again, I'm assuming that X.A and X.B are dense matrix-like things for now, but will +# generalise at some point. +function pullback_of_destructure(config::RuleConfig, X::T) where {T<:KroneckerProduct} + A = X.A + B = X.B + function pullback_destructure_KroneckerProduct(X̄::AbstractMatrix) + println("pullback_destructure_KroneckerProduct") + m = length(X̄) + Ā = zero(A) + B̄ = zero(B) + for j in reverse(1:size(A, 2)) + for l in reverse(1:size(B, 2)) + for i in reverse(1:size(A, 1)) + for k in reverse(1:size(B, 1)) + Ā[i, j] += X̄[m] * B[k, l] + B̄[k, l] += X̄[m] * A[i, j] + m -= 1 + end + end + end + end + return Tangent{T}(A=Ā, B=B̄) + end + return pullback_destructure_KroneckerProduct +end + +# Check my_sum correctness. Doesn't require Restructure since the output is a Real. +let + A = randn(4, 2) + B = randn(3, 5) + c̄ = randn() + + foo(A, B) = my_sum(kronecker(A, B)) + + c, pb = Zygote.pullback(foo, A, B) + test_approx(c, foo(A, B)) + + Ā, B̄ = pb(c̄) + + Ā_fd, B̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, c̄, A, B) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) +end + +# Kronecker doesn't implement their efficient `KroneckerProduct * AbstractMatrix` operation +# in a mutation-free manner (at present), so the best we can do without implementing a rule +# is make use of the generic fallback. This is the kind of operation that you would probably +# want to implement a rule for / implement in an AD-friendly manner in reality. +let + A = randn(4, 2) + B = randn(3, 5) + X = randn(10, 2) + + # Multiply some interesting types together. 
+ foo(A, B, X) = my_mul(kronecker(A, B), X) + + C, pb = Zygote.pullback(foo, A, B, X) + test_approx(C, foo(A, B, X)) + + C̄ = randn(12, 2) + Ā, B̄, X̄ = pb(C̄) + + Ā_fd, B̄_fd, X̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, C̄, A, B, X) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) + test_approx(X̄, X̄_fd) +end - # test_approx(ᾱ, ᾱ_fd) - # test_approx(Ā, Ā_fd) - # test_approx(d̄, d̄_fd) - # test_approx(s̄, s̄_fd) +# Opt-out for multiplying two KroneckerProducts together. +using ChainRules: CommutativeMulNumber +@opt_out rrule(::typeof(*), ::AbstractKroneckerProduct{<:CommutativeMulNumber}, ::AbstractKroneckerProduct{<:CommutativeMulNumber}) +@opt_out rrule(::RuleConfig, ::typeof(my_mul), ::AbstractKroneckerProduct, ::AbstractKroneckerProduct) + +let + A = randn(4, 2) + B = randn(3, 5) + C = randn(2, 3) + D = randn(5, 3) + + # Multiply some interesting types together. + foo(A, B, C, D) = my_mul(kronecker(A, B), kronecker(C, D)) + + Z, pb = Zygote.pullback(foo, A, B, C, D) + test_approx(Z, foo(A, B, C, D)) + + Z̄ = (A=randn(size(Z.A)), B=randn(size(Z.B))) + Ā, B̄, C̄, D̄ = pb(Z̄) + + Ā_fd, B̄_fd, C̄_fd, D̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Z̄, A, B, C, D) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) + test_approx(C̄, C̄_fd) + test_approx(D̄, D̄_fd) end From 5647388d54dcba305e6f3da47999c027bcd45d23 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 5 Sep 2021 14:04:30 +0100 Subject: [PATCH 34/36] Add note on side-effects --- notes.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/notes.md b/notes.md index 5b5781eeb..f68580f27 100644 --- a/notes.md +++ b/notes.md @@ -233,6 +233,8 @@ This is a particularly interesting case because `parent` is exported from `Base` I'm not really sure how to think about this but, as I say, I suspect we're already suffering from it, so I'm not going to worry about it for now. +edit: I think it's best just to think of this "junk" data in the lower triangle as a side-effect. We can't generally guarantee anything in the presence of side effects, so it's no surprise that this is causing some problems. This is part of a more general problem (that we're currently dealing with on an ad-hoc basis in to_vec, and probably need to think more generally about). + From 6d9d01fdfba1e0cf5019d3db00ceb8cc09ec9196 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Tue, 14 Sep 2021 17:01:50 +0100 Subject: [PATCH 35/36] Update notes.md --- notes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notes.md b/notes.md index f68580f27..9a2548fee 100644 --- a/notes.md +++ b/notes.md @@ -59,7 +59,7 @@ Other than the headlines at the top, additional benefits of (a correct implement 1. no risk of obstructing AD, 1. rand_tangent can just use structural tangents, simplifying its implementation and improving its reliability, 1. we can probably finally make `to_vec` treat things structurally (although we also need to extend it in other ways), which will also deal with reliability / simplicity of implementation problems, -1. generic constructor for composite types easy to implement, +1. generic constructor for composite types easy to implement since we can be sure of obtaining a structural tangent, and rid ourselves of the "need to define an adjoint for constructor..." errors we see in Zygote, 1. due to the utility functionality, all of the examples that I've encountered so far are very concise. The potential downside is additional conversions between natural and structural tangents. Most of the time, these are free. 
When they're not, you ideally want to minimise them. I'm not sure how often this is going to be a problem, but it's something we're essentially ignoring at the minute (as far as I know), so we're probably going to have to incur some additional cost at some point even if we don't go down the route proposed here.

From 20f77272e1ef2e73ae120229f94ca8974ad6e036 Mon Sep 17 00:00:00 2001
From: willtebbutt
Date: Tue, 14 Sep 2021 17:10:40 +0100
Subject: [PATCH 36/36] Update notes.md

---
 notes.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/notes.md b/notes.md
index 9a2548fee..8805e2a06 100644
--- a/notes.md
+++ b/notes.md
@@ -213,7 +213,7 @@ end
 A few observations:
 1. All dense primals are gone. In the pullback, they only appeared in places where they can be safely replaced with the primals themselves because they're doing array-like things. `C_dense` appeared in the construction of `restructure_C_pb`; however, we were using a sub-optimal implementation of that function. Much of the time, `pullback_of_restructure` doesn't require `C_dense` in order to know what the pullback would look like and, if it does, it can be obtained from `C`.
 2. All direct calls to `rrule_via_ad` have been replaced with calls to functions which are defined to return the things we actually need (the pullbacks). These typically have efficient (and easy to write) implementations.
-3. `C̄_nat` could be any old `AbstractArray`, because it's conceptually the cotangent for an `Array`. For example, `pullback_of_restructure` for a `Diagonal` returns a `Diagonal`. This is good -- it means we might occasionally get faster computations in the pullback.
+3. It is permissible for `C̄_nat` to be any old `AbstractArray`, because it's conceptually the cotangent for an `Array`. For example, it is permissible for us to manually optimise `pullback_of_restructure` for a `Diagonal` to ensure that it returns a `Diagonal`. This is good -- it means we can occasionally get faster computations in the natural pullback.
 
 Roughly speaking, the above implementation has only one more operation than our existing rrules involving `ProjectTo`: a call to `restructure_C_pb`, which handles converting a structural tangent for `C̄` into the corresponding natural. Currently we require users to do this by hand, and no clear guidance is provided regarding the correct way to handle this conversion, in contrast to the clarity provided here. In this sense, all that the above is doing is providing a well-defined mechanism by which users can obtain natural cotangents from structural cotangents, so it should ease the burden on rule-implementers.
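To make observation 3 concrete, here is one possible `Diagonal` method, written as a sketch in the same style as the `Fill` and `SArray` methods in examples.jl (it is not something this PR necessarily ships). Because the natural cotangent only has to behave like the corresponding `Array`, it is free to be a `Diagonal`, which keeps subsequent computations in the natural pullback cheap.

```julia
# Sketch: a structurally-sparse natural cotangent for the restructure of a `Diagonal`.
function pullback_of_restructure(config::RuleConfig, x::P) where {P<:Diagonal}
    # x̄ is a structural cotangent for the `Diagonal` produced by (::Restructure{P}).
    # The corresponding natural cotangent is only non-zero on the diagonal, so it can be
    # returned as a `Diagonal` rather than a dense `Matrix`.
    pullback_restructure_Diagonal(x̄::Tangent) = Diagonal(x̄.diag)
    return pullback_restructure_Diagonal
end
```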