diff --git a/Project.toml b/Project.toml index 245761cef..267c60845 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] BenchmarkTools = "0.5" Compat = "2, 3" -FiniteDifferences = "0.10" +FiniteDifferences = "0.12" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" julia = "1" diff --git a/examples.jl b/examples.jl new file mode 100644 index 000000000..96d90b617 --- /dev/null +++ b/examples.jl @@ -0,0 +1,703 @@ +using ChainRulesCore +using ChainRulesTestUtils +using FiniteDifferences +using LinearAlgebra +using Zygote + +import ChainRulesCore: rrule, pullback_of_destructure, pullback_of_restructure + +using ChainRulesCore: RuleConfig, wrap_natural_pullback + +# All of the examples here involve new functions (`my_mul` etc) so that it's possible to +# ensure that Zygote's / ChainRules' existing adjoints don't get in the way. + +# Example 1: matrix-matrix multiplication. + +my_mul(A::AbstractMatrix, B::AbstractMatrix) = A * B + +function rrule(config::RuleConfig, ::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix) + println("my_mul generic rrule") + C = A * B + natural_pullback_for_mul(C̄) = NoTangent(), C̄ * B', A' * C̄ + return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B) +end + +let + A = randn(4, 3); + B = Symmetric(randn(3, 3)); + C, pb = Zygote.pullback(my_mul, A, B); + + @assert C ≈ my_mul(A, B) + + dC = randn(4, 3); + dA, dB_zg = pb(dC); + dB = Tangent{typeof(B)}(data=dB_zg.data); + + # Test correctness. + dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B); + + # to_vec doesn't know how to make `Tangent`s, so translate manually. + dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data); + + test_approx(dA, dA_fd) + test_approx(dB, dB_fd) +end + +# Zygote doesn't like the structural tangent here because of an @adjoint +my_upper_triangular(X) = UpperTriangular(X) + +function rrule(::typeof(my_upper_triangular), X::Matrix{<:Real}) + pullback_my_upper_triangular(Ū::Tangent) = NoTangent(), Ū.data + return my_upper_triangular(X), pullback_my_upper_triangular +end + +let + foo_my_mul(U_data, B) = my_mul(my_upper_triangular(U_data), B) + U_data = randn(3, 3) + B = randn(3, 4) + C, pb = Zygote.pullback(foo_my_mul, U_data, B) + + @assert C ≈ foo_my_mul(U_data, B) + + C̄ = randn(3, 4) + Ū_data, B̄ = pb(C̄) + + Ū_data_fd, B̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_mul, C̄, U_data, B) + + test_approx(Ū_data, Ū_data_fd) + test_approx(B̄, B̄_fd) + + display(Ū_data_fd) + println() + display(Ū_data) + println() +end + + + +# pullbacks for `Real`s so that they play nicely with the utility functionality. + +ChainRulesCore.pullback_of_destructure(config::RuleConfig, x::Real) = identity + +ChainRulesCore.pullback_of_restructure(config::RuleConfig, x::Real) = identity + + +# Example 2: something where the output isn't a matrix. + +my_sum(x::AbstractArray) = sum(x) + +function ChainRulesCore.rrule(config::RuleConfig, ::typeof(my_sum), x::AbstractArray) + y = my_sum(x) + natural_pullback_my_sum(ȳ::Real) = NoTangent(), fill(ȳ, size(x)) # Fill also fine here. 
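+    # Natural cotangent for `x`: `sum` weights every element equally, so each entry of the dense
+    # cotangent is just ȳ; `wrap_natural_pullback` then uses `pullback_of_destructure` to turn
+    # this into a structural tangent whenever `x` is a structured array.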
+ return y, wrap_natural_pullback(config, natural_pullback_my_sum, y, x) +end + +let + A = Symmetric(randn(2, 2)) + y, pb = Zygote.pullback(my_sum, A) + + test_approx(y, my_sum(A)) + + dy = randn() + dA_zg, = pb(dy) + dA = Tangent{typeof(A)}(data=dA_zg.data) + + dA_fd_sym, = FiniteDifferences.j′vp(central_fdm(5, 1), my_sum, dy, A) + dA_fd = Tangent{typeof(A)}(data=dA_fd_sym.data) + + test_approx(dA, dA_fd) +end + + + + + +# Example 3: structured-input-structured-output + +my_scale(a::Real, x::AbstractMatrix) = a * x + +function ChainRulesCore.rrule( + config::RuleConfig, ::typeof(my_scale), a::Real, x::AbstractMatrix, +) + y = my_scale(a, x) + natural_pullback_my_scale(ȳ::AbstractMatrix) = NoTangent(), dot(ȳ, x), a * ȳ + return y, wrap_natural_pullback(config, natural_pullback_my_scale, y, a, x) +end + +# DENSE TEST +let + a = randn() + x = randn(2, 2) + y, pb = Zygote.pullback(my_scale, a, x) + + dy = randn(size(y)) + da, dx = pb(dy) + + da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) + + test_approx(y, my_scale(a, x)) + test_approx(da, da_fd) + test_approx(dx, dx_fd) +end + + +# DIAGONAL TEST + +# `diag` now returns a `Diagonal` as a tangnet, so have to define `my_diag` to make this +# work with Diagonal`s. + +my_diag(x) = diag(x) +function ChainRulesCore.rrule(::typeof(my_diag), D::P) where {P<:Diagonal} + my_diag_pullback(d) = NoTangent(), Tangent{P}(diag=d) + return diag(D), my_diag_pullback +end + +let + a = randn() + x = Diagonal(randn(2)) + y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x) + + ȳ = randn(2) + ā, x̄_zg = pb(ȳ) + x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag) + + ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x) + x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag) + + test_approx(y, (my_diag ∘ my_scale)(a, x)) + test_approx(ā, ā_fd) + test_approx(x̄, x̄_fd) +end + + +# SYMMETRIC TEST - FAILS BECAUSE PRIVATE ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL! +# I would be surprised if we're doing this consistently at the minute though. +let + a = randn() + x = Symmetric(randn(2, 2)) + y, pb = Zygote.pullback(my_scale, a, x) + + dy = Tangent{typeof(y)}(data=randn(2, 2)) + da, dx_zg = pb(dy) + dx = Tangent{typeof(x)}(data=dx_zg.data) + + da_fd, dx_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x) + dx_fd = Tangent{typeof(x)}(data=dx_fd_sym.data) + + test_approx(y.data, my_scale(a, x).data) + test_approx(da, da_fd) + test_approx(dx, dx_fd) +end + + + + +# Example 4: ScaledVector. This is an interesting example because I truly had no idea how to +# specify a natural tangent prior to this work. + +# Implement AbstractArray interface. +struct ScaledMatrix <: AbstractMatrix{Float64} + v::Matrix{Float64} + α::Float64 +end + +Base.getindex(x::ScaledMatrix, p::Int, q::Int) = x.α * x.v[p, q] + +Base.size(x::ScaledMatrix) = size(x.v) + + +# Implement destructure and restructure pullbacks. + +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:ScaledMatrix} + function pullback_destructure_ScaledMatrix(X̄::AbstractArray) + return Tangent{P}(v = X̄ * x.α, α = dot(X̄, x.v)) + end + return pullback_destructure_ScaledMatrix +end + +function pullback_of_restructure(config::RuleConfig, x::ScaledMatrix) + function pullback_restructure_ScaledMatrix(x̄::Tangent) + return x̄.v / x.α + end + return pullback_restructure_ScaledMatrix +end + +# What destructure and restructure would look like if implemented. pullbacks were derived +# based on these. 
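+# (Derivation: destructure(x) = x.α * x.v, so the pullback of X̄ is v̄ = x.α * X̄ and ᾱ = dot(X̄, x.v),
+# as returned by pullback_of_destructure above; likewise (::Restructure)(X) = ScaledMatrix(X ./ α, α)
+# for fixed α, so the pullback of a Tangent x̄ is x̄.v / x.α, as in pullback_of_restructure.)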
+# ChainRulesCore.destructure(x::ScaledMatrix) = x.α * x.v + +# ChainRulesCore.Restructure(x::P) where {P<:ScaledMatrix} = Restructure{P, Float64}(x.α) + +# (r::Restructure{<:ScaledMatrix})(x::AbstractArray) = ScaledMatrix(x ./ r.data, r.data) + + + + +# Define a function on the type. + +my_dot(x::AbstractArray, y::AbstractArray) = dot(x, y) + +function ChainRulesCore.rrule( + config::RuleConfig, ::typeof(my_dot), x::AbstractArray, y::AbstractArray, +) + z = my_dot(x, y) + natural_pullback_my_dot(z̄::Real) = NoTangent(), z̄ * y, z̄ * x + return z, wrap_natural_pullback(config, natural_pullback_my_dot, z, x, y) +end + +let + # Check correctness of `my_dot` rrule. Build `ScaledMatrix` internally to avoid + # technical issues with FiniteDifferences. + V = randn(2, 2) + α = randn() + z̄ = randn() + + foo_scal(V, α) = my_dot(ScaledMatrix(V, α), V) + + z, pb = Zygote.pullback(foo_scal, V, α) + dx_ad = pb(z̄) + + dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α) + + test_approx(dx_ad, dx_fd) +end + + +# Specialised method of my_scale for ScaledMatrix +my_scale(a::Real, X::ScaledMatrix) = ScaledMatrix(X.v, X.α * a) + +let + # Verify correctness. + a = randn() + V = randn(2, 2) + α = randn() + z̄ = randn() + + # A more complicated programme involving `my_scale`. + B = randn(2, 2) + foo_my_scale(a, V, α) = my_dot(B, my_scale(a, ScaledMatrix(V, α))) + + z, pb = Zygote.pullback(foo_my_scale, a, V, α) + da, dV, dα = pb(z̄) + + da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, z̄, a, V, α) + + test_approx(da, da_fd) + test_approx(dV, dV_fd) + test_approx(dα, dα_fd) +end + + + + +# Example 5: Fill + +using FillArrays + +# What you would implement: +# destucture(x::Fill) = collect(x) + +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:Fill} + pullback_destructure_Fill(X̄::AbstractArray) = Tangent{P}(value=sum(X̄)) + return pullback_destructure_Fill +end + +# There are multiple equivalent choices for Restructure here. I present two options below, +# both yield the correct answer. +# To understand why there is a range of options here, recall that the input to +# (::Restructure)(x::AbstractArray) must be an array that is equal (in the `getindex` sense) +# to a `Fill`. Moreover, `(::Restructure)` simply has to promise that the `Fill` it outputs +# is equal to the `Fill` from which the `Restructure` is constructed (in the structural +# sense). Since the elements of `x` must all be equal, any affine combination will do ( +# weighted sum, whose weights sum to 1). +# While there is no difference in the primal for `(::Restructure)`, the pullback is quite +# different, depending upon your choice. Since we don't ever want to evaluate the primal for +# `(::Restructure)`, just the pullback, we are free to choose whatever definition of +# `(::Restructure)` makes its pullback pleasant. In particular, defining `(::Restructure)` +# to take the mean of its argument yields a pleasant pullback (see below). 
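+# (Concretely: x[1], mean(x), or any weighted sum whose weights sum to 1 all reconstruct the same
+# `Fill` in the primal; only the pullbacks of those choices differ.)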
+ +# Restucture option 1: + +# Restructure(x::P) where {P<:Fill} = Restructure{P, typeof(x.axes)}(x.axes) +# (r::Restructure{<:Fill})(x::AbstractArray) = Fill(x[1], r.data) + +function pullback_of_restructure(config::RuleConfig, x::Fill) + println("Option 1") + function pullback_restructure_Fill(x̄::Tangent) + X̄ = zeros(size(x)) + X̄[1] = x̄.value + return X̄ + end + return pullback_restructure_Fill +end + +# Restructure option 2: + +# Restructure(x::P) where {P<:Fill} = Restructure{P, typeof(x.axes)}(x.axes) +# (r::Restructure{<:Fill})(x::AbstractArray) = Fill(mean(x), r.data) + +function pullback_of_restructure(config::RuleConfig, x::Fill) + println("Option 2") + pullback_restructure_Fill(x̄::Tangent) = Fill(x̄.value / length(x), x.axes) + return pullback_restructure_Fill +end + +# An example which uses `pullback_of_destructure(::Fill)` because `Fill` is an input. +let + A = randn(2, 3) + v = randn() + + # Build the Fill inside because FiniteDifferenes doesn't play nicely with Fills, even + # if one adds a `to_vec` call. + foo_my_mul(A, v) = my_mul(A, Fill(v, 3, 4)) + C, pb = Zygote.pullback(foo_my_mul, A, v) + + @assert C ≈ foo_my_mul(A, v) + + C̄ = randn(2, 4); + Ā, v̄ = pb(C̄); + + Ā_fd, v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_mul, C̄, A, v); + + test_approx(Ā, Ā_fd) + test_approx(v̄, v̄_fd) +end + +# Another example using `pullback_of_destructure(::Fill)`. +let + foo_my_sum(v) = my_sum(Fill(v, 4, 3)) + v = randn() + + c, pb = Zygote.pullback(foo_my_sum, v) + @assert c ≈ foo_my_sum(v) + + c̄ = randn() + v̄ = pb(c̄) + + v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_sum, c̄, v); + + test_approx(v̄, v̄_fd) +end + +# An example using `pullback_of_restructure(::Fill)`. +let + foo_my_scale(a, v) = my_sum(my_scale(a, Fill(v, 3, 4))) + a = randn() + v = randn() + + c, pb = Zygote.pullback(foo_my_scale, a, v) + @assert c ≈ foo_my_scale(a, v) + + c̄ = randn() + ā, v̄ = pb(c̄) + + ā_fd, v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, c̄, a, v); + + test_approx(ā, ā_fd) + test_approx(v̄, v̄_fd) +end + + + + +# Example 6: SArray +# This example demonstrates that within this framework we can easily work with structural +# tangents for `SArray`s. Unclear that we _want_ to do this, but it's nice to know that +# it's an option requiring minimal work. +# Notice that this should be performant, since `pullback_of_destructure` and +# `pullback_of_restructure` should be performant, and the operations in the pullback +# will all happen on `SArray`s. + +using StaticArrays + +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:SArray} + pullback_destructure_SArray(X̄::AbstractArray) = Tangent{P}(data=X̄) + return pullback_destructure_SArray +end + +function pullback_of_restructure( + config::RuleConfig, x::SArray{S, T, N, L}, +) where {S, T, N, L} + pullback_restructure_SArray(x̄::Tangent) = SArray{S, T, N, L}(x̄.data) + return pullback_restructure_SArray +end + +# destructure + restructure example with `my_mul`. +let + A = SMatrix{2, 2}(randn(4)...) + B = SMatrix{2, 1}(randn(2)...) + C, pb = Zygote.pullback(my_mul, A, B) + + @assert C ≈ my_mul(A, B) + + C̄ = Tangent{typeof(C)}(data=(randn(2)..., )) + Ā_, B̄_ = pb(C̄) + + # Manually convert Ā_ and B̄_ to Tangents from Zygote types. + Ā = Tangent{typeof(A)}(data=Ā_.data) + B̄ = Tangent{typeof(B)}(data=B̄_.data) + + Ā_fd_, B̄_fd_ = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, C̄, A, B) + + # Manually convert Ā_fd and B̄_fd into Tangents from to_vec output. 
+ Ā_fd = Tangent{typeof(A)}(data=Ā_fd_) + B̄_fd = Tangent{typeof(B)}(data=B̄_fd_) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) +end + + + + + +# Example 7: WoodburyPDMat +# WoodburyPDMat doesn't currently know anything about AD. I have no intention of +# implementing any of the functionality here on it, because it's just fine as it is. +# ~~However, it _is_ any interesting case-study, because it's an example where addition in +# natural (co)tangent space disagrees with addition in structural space. Since we know that +# the notion of addition we have on structural tangents is the desirable one, this indicates +# that we don't always want to add natural tangents.~~ See below. +# It's also interesting because, as with the `ScaledMatrix` example, I had no idea how to +# find a natural (co)tangent prior to this PR. It's a comparatively complicated example, +# and destructure and restructure are non-linear in the fields of `x`, which is another +# interesting property. +# I've only bothered deriving stuff for destructure because restructure is really quite +# complicated and I don't have the time right now to work through the example. +# It does serve to show that it's not always going to be easy for authors of complicated +# array types to make their type work with the natural pullback machinery. At least we +# understand what an author would have to do though, even if it's not straightforward to +# do all of the time. +# edit: natural tangents should always add properly. Ignore what is said in the para above +# about them not adding properly here. It's still interesting for the other reasons listed +# though. + +using PDMatsExtras: WoodburyPDMat +import ChainRulesCore: destructure, Restructure + +# What destructure would do if we actually implemented it. +# destructure(x::WoodburyPDMat) = x.A * x.D * x.A' + x.S + +# This is an interesting pullback, because it looks like the destructuring mechanism +# is required here to ensure that the fields `D` and `S` are handled appropriately. +# A would also be necessary in general, but I'm assuming it's just a `Matrix{<:Real}` for +# now. +function pullback_of_destructure(config::RuleConfig, x::P) where {P<:WoodburyPDMat} + println("Hitting pullback") + pb_destructure_D = pullback_of_destructure(x.D) + pb_destructure_S = pullback_of_destructure(x.S) + function pullback_destructure_WoodburyPDMat(x̄::AbstractArray) + S̄ = pb_destructure_S(x̄) + D̄ = pb_destructure_D(x.A' * x̄ * x.A) + Ā = (x̄ + x̄') * x.A * x.D + return Tangent{P}(A=Ā, D=D̄, S=S̄) + end + return pullback_destructure_WoodburyPDMat +end + +# Check my_sum correctness. Doesn't require Restructure since the output is a Real. +let + A = randn(4, 2) + d = rand(2) .+ 1 + s = rand(4) .+ 1 + + foo(A, d, s) = my_sum(WoodburyPDMat(A, Diagonal(d), Diagonal(s))) + + Y, pb = Zygote.pullback(foo, A, d, s) + test_approx(Y, foo(A, d, s)) + + Ȳ = randn() + Ā, d̄, s̄ = pb(Ȳ) + + Ā_fd, d̄_fd, s̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Ȳ, A, d, s) + + test_approx(Ā, Ā_fd) + test_approx(d̄, d̄_fd) + test_approx(s̄, s̄_fd) +end + +# Check my_mul correctness. Doesn't require Restructure since the output is an Array. +# This is a truly awful implementation (generic fallback for my_mul, and getindex is really +# very expensive for WoodburyPDMat), but it ought to work. +let + A = randn(4, 2) + d = rand(2) .+ 1 + s = rand(4) .+ 1 + b = randn() + + # Multiply some interesting types together. 
+ foo(A, d, s, b) = my_mul(WoodburyPDMat(A, Diagonal(d), Diagonal(s)), Fill(b, 4, 3)) + + Y, pb = Zygote.pullback(foo, A, d, s, b) + test_approx(Y, foo(A, d, s, b)) + + Ȳ = randn(4, 3) + Ā, d̄, s̄, b̄ = pb(Ȳ) + + Ā_fd, d̄_fd, s̄_fd, b̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Ȳ, A, d, s, b) + + test_approx(Ā, Ā_fd) + test_approx(d̄, d̄_fd) + test_approx(s̄, s̄_fd) + test_approx(b̄, b̄_fd) +end + + +# THIS EXAMPLE DOESN'T WORK. +# I'm not really sure why, but it's not the responsibility of this PR because I'm just +# trying to opt out of the generic rrule for my_scale, because there's a specialised +# implementation available. +# I've tried all of the opt-outs I can think of, but no luck -- it keeps hitting +# ChainRules for me :( + +# Opt-out and refresh. +using ChainRules: CommutativeMulNumber +@opt_out rrule(::RuleConfig, ::typeof(my_scale), ::Real, ::WoodburyPDMat) +@opt_out rrule(::typeof(*), ::WoodburyPDMat{<:CommutativeMulNumber}, ::CommutativeMulNumber) +@opt_out rrule(::typeof(*), ::CommutativeMulNumber, ::WoodburyPDMat{<:CommutativeMulNumber}) + +# This should probably go in ChainRules anyway. +@opt_out rrule(::typeof(*), ::CommutativeMulNumber, ::Diagonal{<:CommutativeMulNumber}) +@opt_out rrule(::typeof(*), ::Diagonal{<:CommutativeMulNumber}, ::CommutativeMulNumber) + +Zygote.refresh() + +# This should just hit [this code](https://github.com/invenia/PDMatsExtras.jl/blob/b7b3a2035682465f1471c2d2e1e017b9fd75cec0/src/woodbury_pd_mat.jl#L92) +let + α = rand() + A = randn(4, 2) + d = rand(2) .+ 10 + s = rand(4) .+ 1 + + # Multiply some interesting types together. + foo(α, A, d, s) = my_scale(α, WoodburyPDMat(A, Diagonal(d), Diagonal(s))) + + Y, pb = Zygote.pullback(foo, α, A, d, s) + test_approx(Y, foo(α, A, d, s)) + + Ȳ = ( + A=randn(4, 2), + D=(diag=randn(2),), + S=(diag=randn(4),), + ) + ᾱ, Ā, d̄, s̄ = pb(Ȳ) + + # No need to test correctness with FiniteDifferences because I've just opted out of + # rules, so if the code runs, it ought to be correct. If it's not correct, it must be a + # bug from somewhere other than here. +end + + + +# Example 8: Kronecker.jl +# This is a good example of a package where you want to opt out of basically all rules +# which apply to functions that Kronecker has implemented, but you also really want to +# ensure that the generic fallbacks work for operations that you don't care about so much +# (as discussed in various locations, typically if someone cares about the performance of +# an operations on a Kronecker matrix, they'll have implemented a specialised method to do +# it.) +# I have absolutely no idea how to implement restructure for this type. + +using Kronecker +using Kronecker: KroneckerProduct + +# destructure(X::KroneckerProduct) = kron(X.A, X.B) + +# Just differentiate the implementation of `kron` from the stdlib to produce the pullback. +# Again, I'm assuming that X.A and X.B are dense matrix-like things for now, but will +# generalise at some point. 
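+# Since kron(A, B) has entries A[i, j] * B[k, l], the pullback accumulates
+# Ā[i, j] += X̄[m] * B[k, l] and B̄[k, l] += X̄[m] * A[i, j]; the reversed loops below visit the
+# linear index m in the opposite order to the one in which `kron` fills its output.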
+function pullback_of_destructure(config::RuleConfig, X::T) where {T<:KroneckerProduct} + A = X.A + B = X.B + function pullback_destructure_KroneckerProduct(X̄::AbstractMatrix) + println("pullback_destructure_KroneckerProduct") + m = length(X̄) + Ā = zero(A) + B̄ = zero(B) + for j in reverse(1:size(A, 2)) + for l in reverse(1:size(B, 2)) + for i in reverse(1:size(A, 1)) + for k in reverse(1:size(B, 1)) + Ā[i, j] += X̄[m] * B[k, l] + B̄[k, l] += X̄[m] * A[i, j] + m -= 1 + end + end + end + end + return Tangent{T}(A=Ā, B=B̄) + end + return pullback_destructure_KroneckerProduct +end + +# Check my_sum correctness. Doesn't require Restructure since the output is a Real. +let + A = randn(4, 2) + B = randn(3, 5) + c̄ = randn() + + foo(A, B) = my_sum(kronecker(A, B)) + + c, pb = Zygote.pullback(foo, A, B) + test_approx(c, foo(A, B)) + + Ā, B̄ = pb(c̄) + + Ā_fd, B̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, c̄, A, B) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) +end + +# Kronecker doesn't implement their efficient `KroneckerProduct * AbstractMatrix` operation +# in a mutation-free manner (at present), so the best we can do without implementing a rule +# is make use of the generic fallback. This is the kind of operation that you would probably +# want to implement a rule for / implement in an AD-friendly manner in reality. +let + A = randn(4, 2) + B = randn(3, 5) + X = randn(10, 2) + + # Multiply some interesting types together. + foo(A, B, X) = my_mul(kronecker(A, B), X) + + C, pb = Zygote.pullback(foo, A, B, X) + test_approx(C, foo(A, B, X)) + + C̄ = randn(12, 2) + Ā, B̄, X̄ = pb(C̄) + + Ā_fd, B̄_fd, X̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, C̄, A, B, X) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) + test_approx(X̄, X̄_fd) +end + +# Opt-out for multiplying two KroneckerProducts together. +using ChainRules: CommutativeMulNumber +@opt_out rrule(::typeof(*), ::AbstractKroneckerProduct{<:CommutativeMulNumber}, ::AbstractKroneckerProduct{<:CommutativeMulNumber}) +@opt_out rrule(::RuleConfig, ::typeof(my_mul), ::AbstractKroneckerProduct, ::AbstractKroneckerProduct) + +let + A = randn(4, 2) + B = randn(3, 5) + C = randn(2, 3) + D = randn(5, 3) + + # Multiply some interesting types together. + foo(A, B, C, D) = my_mul(kronecker(A, B), kronecker(C, D)) + + Z, pb = Zygote.pullback(foo, A, B, C, D) + test_approx(Z, foo(A, B, C, D)) + + Z̄ = (A=randn(size(Z.A)), B=randn(size(Z.B))) + Ā, B̄, C̄, D̄ = pb(Z̄) + + Ā_fd, B̄_fd, C̄_fd, D̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo, Z̄, A, B, C, D) + + test_approx(Ā, Ā_fd) + test_approx(B̄, B̄_fd) + test_approx(C̄, C̄_fd) + test_approx(D̄, D̄_fd) +end diff --git a/notes.md b/notes.md new file mode 100644 index 000000000..8805e2a06 --- /dev/null +++ b/notes.md @@ -0,0 +1,255 @@ +# A General Mechanism for Generic Rules for AbstractArrays + +That we don't have a general formalism for deriving natural (co)tangents has been discussed quite a bit recently. As has our lack of understanding of the precise relationship between the generic rrules we're writing, and what AD would do. This PR proposes a recipe for deriving generic rules, which leads to a possible formalism for natural derivatives. This formalism can be applied to any AbstractArray, and AD can be in principle be used to obtain default values for the natural tangent. Moreover, there's some utility functionality proposed to make working with this formalism straightforward for rule-writers. + + +I want reviewers to determine whether they agree that the proposed recipe +1. 
is sufficient for implementing generic rrules on AbstractArrays, +2. is correct, in the sense that it produces the same answer as AD, and +3. the definition of natural tangents proposed indeed applies to any AbstractArray, and broadly agrees with our intuitions about what a natural tangent should be. + +I think what is proposed should be doable without making breaking changes since it just involves a changing the output types of some rules, which isn't something that we consider breaking provided that they represent the same thing. I'd prefer we worry about this later if we think this is a good idea though. + +Apologies in advance for the lenght. I've tried to condense where possible, but there's a decent amount of content to get through. + + + + + +## Cutting to the Chase + +Under the proposed system, rule-implementers would write rules that look like this: +```julia +function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + + # Do the primal computation. + C = A * B + + # "natural pullback": write intuitive pullback, closing over stuff in the usual manner. + function natural_pullback_for_mul(C̄_natural) + Ā_natural = C̄_natural * B' + B̄_natural = A' * C̄_natural + return NoTangent(), Ā_natural, B̄_natural + end + + # Make a call to utility functionality which transforms cotangents of C, A, and B. + # Rule-writing without this utility has similar requirements to `ProjectTo`. + return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B) +end +``` +I'm proposing to coin the term "natural pullback" for pullbacks written within this system, as they're rules written involving natural (co)tangents. + +Authors will have to implement two functions for their `AbstractArray` type `P`: +```julia +pullback_of_destructure(::P) +pullback_of_restructure(::P) +``` +which are the pullbacks of two functions, `destructure` and `(::Restructure)`, that we'll define later. + +The proposed candidates for natural (co)tangents are obtained as follows: +1. natural tangents are obtained from structural tangents via the pushforward of `destructure`, +2. natural cotangents are obtained from structural cotangents via the pullback of `(::Restructure)`. + +Our current `ProjectTo` functionality is roughly the same as the pullback of `destructure`. + +In the proposed system, natural (co)tangents remain confined to `rrule`s, and rule authors can choose to work with either natural, structural, or a mixture of (co)tangents. + +Other than the headlines at the top, additional benefits of (a correct implementation of) this approach include: +1. no chance of trying to sum tangents failing because one is natural and the other structural, +1. no risk of obstructing AD, +1. rand_tangent can just use structural tangents, simplifying its implementation and improving its reliability, +1. we can probably finally make `to_vec` treat things structurally (although we also need to extend it in other ways), which will also deal with reliability / simplicity of implementation problems, +1. generic constructor for composite types easy to implement since we can be sure of obtaining a structural tangent, and rid ourselves of the "need to define an adjoint for constructor..." errors we see in Zygote, +1. due to the utility functionality, all of the examples that I've encountered so far are very concise. + +The potential downside is additional conversions between natural and structural tangents. Most of the time, these are free. When they're not, you ideally want to minimise them. 
I'm not sure how often this is going to be a problem, but it's something we're essentially ignoring at the minute (as far as I know), so we're probably going to have to incur some additional cost at some point even if we don't go down the route proposed here. + + + + +## Starting Point + +Imagine a clean-slate version of Zygote / Diffractor, in which any rules written are either +1. necessary to define what AD should do (e.g. `+` or `*` on `Float64`s, `getindex` on `Array`s, `getfield` on composite types, etc), or +2. produce _exactly_ the same answer that AD would produce -- crucially, they always return a structural tangent for a non-primitive composite type. + +Moreover, assume that users provide structural tangents -- this restriction could be removed by AD authors by shoving a call to `destructure` at the end of their AD if that's something they want to do. + +The consequence of the above is that we can safely assume that the input and output of any rule will be either +1. a structural tangent (if the primal is a non-primitive composite type), or +2. whatever we have defined its tangent type to be (if it's a primitive type like `Float64` or `Array`). + +I'm intentionally not trying to define precisely what a primitive is, but will assume that everyone agrees that `Array`s are an example of a primitive, and that we are happy with representing the tangent of an `Array` by another `Array`. + +Assume that structural tangents are a valid representation of a tangent of any non-primitive composite type, convenience for rule-writing aside. + +More generally, assume that if an AD successfully runs on a given function under the above assumptions, it gives the desired (the "correct") answer. Consequently, two core goals of the proposed recipe are to make it +1. easy to never write a rule which prevents an AD from differentiating a programme that it already knows how to differentiate, and +2. easy to write rules using intuitive representations of tangents. + + + + + +## The Formalism + +First consider a specific case -- we'll optimise the implementation later, and provide more general examples in `examples.jl`. + +Consider a function `f(x::AbstractArray) -> y::AbstractArray`. Let's assume that there's just one method, so we can be sure that a generic fallback will be hit, regardless of the concrete type of the argument. + +The intuition behind the recipe is to find a function which is equivalent to `f`, whose rules we know how to write safely. If we can find such a function, AD-ing it will clearly give the correct answer (the same as running AD on `f` itself) -- the following lays out an approach to doing this. + +The recipe is: +1. Map `x` to an `Array`, `x_dense`, using `getindex`. Call this operation `destructure` (it's essentially `collect`). +2. Apply `f` to `x_dense` to obtain `y_dense`. +3. Map `y_dense` onto `y`. Call this operation `(::Restructure)`. + +I'm going to define equivalence of output structurally -- two `AbstractArray`s are equal if +1. they are primitives, and whatever notion of equality we have for them holds, or +2. they are composite types, and each of their fields is equal under this definition. + +The reason for this notion of equality is that AD (as proposed above) treats concrete subtypes of AbstractArray no differently from any other composite type.
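+
+For concreteness, here is roughly what steps 1 and 3 look like for `Diagonal` (a sketch of the
+definitions that appear in `src/destructure.jl` in this PR):
+```julia
+# Step 1: densify via getindex (essentially collect).
+destructure(X::Diagonal) = collect(X)
+
+# Step 3: rebuild the structured primal from an Array that is elementwise equal to it.
+Restructure(X::P) where {P<:Diagonal} = Restructure{P, Nothing}(nothing)
+(r::Restructure{P})(X_dense::Array) where {P<:Diagonal} = Diagonal(diag(X_dense))
+```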
+ +Applying this recipe, a very literal implementation for the `rrule` of the equivalent function for `*` is something like the following: +```julia +function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + + # Produce dense versions of A and B, and the pullbacks of this operation. + A_dense, destructure_A_pb = rrule(destructure, A) + B_dense, destructure_B_pb = rrule(destructure, B) + + # Compute dense primal. + C_dense = A_dense * B_dense + + # Compute structured primal without densifying to ensure that we get structured `C` back + # if that's what the primal would do. + C = A * B + + # Construct pullback of Restructure. We generally need to extract some information from + # C in order to find the structured version. + _, restructure_C_pb = rrule_via_ad(config, Restructure(C), C_dense) + + function my_mul_generic_pullback(C̄) + + # Recover natural cotangent. + _, C̄_nat = restructure_C_pb(C̄) + + # Compute pullback using natural cotangent of C. + Ā_nat = C̄_nat * B_dense' + B̄_nat = A_dense' * C̄_nat + + # Transform natural cotangents w.r.t. A and B into structural (if non-primitive). + _, Ā = destructure_A_pb(Ā_nat) + _, B̄ = destructure_B_pb(B̄_nat) + return NoTangent(), Ā, B̄ + end + + # The output that we want is `C`, not `C_dense`, so return `C`. + return C, my_mul_generic_pullback +end +``` +All I've done is write out the rrule for differentiating through the equivalent function by hand. +We'll optimise this implementation shortly to avoid e.g. having to densify primals, and computing the same function twice. +`my_mul` in `examples.jl` verifies the correctness of the above implementation. + + +`destructure` is quite straightforward to define -- essentially equivalent to `collect`. I'm confident that this is always going to be simple to define, because `collect` is always easy to define. + +`Restructure(C)(C_dense)` is a bit trickier. It's the function which takes an `Array` `C_dense` and transforms it into `C`. This feels like a slightly odd thing to do, since we already have `C`, but it's necessary to already know what `C` is in order to construct this function in general -- for example, over-parametrised matrices require this (see the `ScaledMatrix` example in the tests / examples). I'm _reasonably_ confident that this is always going to be possible to define, but I might have missed something. + +The PR shows how to implement steps 1 and 3 for `Array`s, `Diagonal`s, `Symmetric`s, and a custom `AbstractArray` `ScaledMatrix`. + + + + + +## Acceptable (Co)Tangents for `Array` + +Any `AbstractArray` is an acceptable (co)tangent for an `Array` (provided it's the right size, and its elements are appropriate (co)tangents for the elements of the primal `Array`). +I'm going to assume this is true, because I can't see any obvious reason why it wouldn't be. +If anyone feels otherwise, please say. + +For example, this means that a `Diagonal{Float64}` is a valid (co)tangent for an `Array{Float64}`. + + + + + +## Optimising rrules for the equivalent function + +The basic example layed out above was very sub-optimal. Consider the following (equivalent) re-write +```julia +function rrule(config::RuleConfig, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + + # Produce pullbacks of destructure. + destructure_A_pb = pullback_of_destructure(config, A) + destructure_B_pb = pullback_of_destructure(config, B) + + # Primal computation. + C = A * B + + # Find pullback of restructure. 
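+    # (Identity for `Array` outputs; for a structured output such as `Diagonal` this converts the
+    # structural cotangent of `C` into a natural, dense-style one.)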
+ restructure_C_pb = pullback_of_restructure(config, C) + + function my_mul_generic_pullback(C̄) + + # Recover natural cotangent. + C̄_nat = restructure_C_pb(C̄) + + # Compute pullback using natural cotangent of C. + Ā_nat = C̄_nat * B' + B̄_nat = A' * C̄_nat + + # Transform natural cotangents w.r.t. A and B into structural (if non-primitive). + Ā = destructure_A_pb(Ā_nat) + B̄ = destructure_B_pb(B̄_nat) + return NoTangent(), Ā, B̄ + end + + return C, my_mul_generic_pullback +end +``` +A few observations: +1. All dense primals are gone. In the pullback, they only appeared in places where they can be safely replaced with the primals themselves because they're doing array-like things. `C_dense` appeared in the construction of `restructure_C_pb`, however, we were using a sub-optimal implementation of that function. Much of the time, `restructure_of_pb` doesn't require `C_dense` in order to know what the pullback would look like and, if it does, it can be obtained from `C`. +2. All direct calls to `rrule_via_ad` have been replaced with calls to functions which are defined to return the things we actually need (the pullbacks). These typically have efficient (and easy to write) implementations. +3. It is permissible for `C̄_nat` to be any old `AbstractArray`, because it's conceptually the cotangent for an `Array`. For example, it is permissible for us to manually optimise `pullback_of_restructure` for a `Diagonal` to ensure that it returns a `Diagonal`. This is good -- it means we can occassionally get faster computations in the natural pullback. + +Roughly speaking, the above implementation has only one additional operation than our existing rrules involving `ProjectTo`, which is a call to `restructure_C_pb`, which handles converting a structural tangent for `C̄` into the corresponding natural. Currently we require users to do this by hand, and no clear guidance is provided regarding the correct way to handle this conversion, in contrast to the clarity provided here. In this sense, all that the above is doing is providing a well-defined mechanism by which users can obtain natural cotangents from structural cotangents, so it should ease the burden on rule-implementers. + +Almost all of the boilerplate in the above example can be removed by utilising the `wrap_natural_pullback` utility function defined in the PR, as in the example at the top of this note. + + + + + +## Gotchas + +There does seem to be something that goes wrong when primals access non-public fields of types (array authors are obviously allowed to do this), but the generic rrules assume that only the AbstractArray API is used. +I don't think this differs from what we're doing at the minute, so probably we're suffering from this already and just haven't hit it yet. +See the third example in `examples.jl`, involving a `Symmetric`. + +This is a particularly interesting case because `parent` is exported from `Base`, effectively making the field of a `Symmetric` part of its public API. + +I'm not really sure how to think about this but, as I say, I suspect we're already suffering from it, so I'm not going to worry about it for now. + +edit: I think it's best just to think of this "junk" data in the lower triangle as a side-effect. We can't generally guarantee anything in the presence of side effects, so it's no surprise that this is causing some problems. This is part of a more general problem (that we're currently dealing with on an ad-hoc basis in to_vec, and probably need to think more generally about). 
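+
+To make the failure mode concrete, here is a minimal illustration (hypothetical, not part of this
+PR) of a primal that breaks the getindex-only assumption behind the generic rules:
+```julia
+# Reads the non-public storage directly, so the "junk" in the strict lower triangle of the parent
+# array affects the primal even though it is invisible through getindex.
+peek_storage(S::Symmetric) = S.data[2, 1]
+```
+A generic rrule built from `destructure` / `Restructure` only ever sees the symmetrised values, so
+it cannot account for sensitivities that flow through this kind of field access.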
+ + + + + +## Summary + +The above lays out a mechanism for writing generic rrules for AbstractArrays, out of which drops what I believe to be a good candidate for a precise definition of the natural (co)tangent of any particular AbstractArray. + +There are a lot more examples in `examples.jl` that I would encourage people to work through. Moreover, the `Symmetric` results are a little odd, but I think make sense. +Additionally, implementations of `destructure` and `Restructure` can be moved to the tests because they're really just used to verify the correctness of manually implementations of their pullbacks. + +I've presented this work in the context of `AbstractArray`s, but the general scheme could probably be extended to other types by finding other canonical types (like `Array`) on which people's intuition about what ought to happen holds. + +The implementation are also limited to arrays of real numbers to avoid the need to recursively apply `destructure` / `restructure`. This restriction could be dropped in practice, and recursive definitions applied. + +I'm sure there's stuff above which is unclear -- please let me know if so. There's more to say about a lot of this stuff, but I'll stop here in the interest of keeping this concise. + +Please now go and look at `examples.jl`. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 08ac3847c..e3191795f 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -33,6 +33,8 @@ include("config.jl") include("rules.jl") include("rule_definition_tools.jl") +include("destructure.jl") + include("deprecated.jl") end # module diff --git a/src/destructure.jl b/src/destructure.jl new file mode 100644 index 000000000..5ce15cd88 --- /dev/null +++ b/src/destructure.jl @@ -0,0 +1,256 @@ +# Fallbacks for destructure +destructure(X::AbstractArray) = collect(X) + +pushforward_of_destructure(X) = dX -> frule((NoTangent(), dX), destructure, X)[2] + +function pullback_of_destructure(config::RuleConfig, X) + return dY -> rrule_via_ad(config, destructure, X)[2](dY)[2] +end + +# Restructure machinery. +struct Restructure{P, D} + data::D +end + +function pullback_of_restructure(config::RuleConfig, X) + return dY -> rrule_via_ad(config, Restructure(X), destructure(X))[2](dY)[2] +end + + + + +# Array + +function pullback_of_destructure(X::Array{<:Real}) + pullback_destructure_Array(X̄::AbstractArray{<:Real}) = X̄ + return pullback_destructure_Array +end + +function pullback_of_restructure(X::Array{<:Real}) + pullback_restructure_Array(X̄::AbstractArray{<:Real}) = X̄ + return pullback_restructure_Array +end + +function pullback_of_destructure(config::RuleConfig, X::Array{<:Real}) + return pullback_of_destructure(X) +end + +function pullback_of_restructure(config::RuleConfig, X::Array{<:Real}) + return pullback_of_restructure(X) +end + + +# Stuff below here for Array to move to tests. 
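+# For `Array`, both `destructure` and `(::Restructure)` are the identity, so the frules and rrules
+# below simply pass tangents through unchanged.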
+destructure(X::Array) = X + +frule((_, dX)::Tuple{Any, AbstractArray}, ::typeof(destructure), X::Array) = X, dX + +function rrule(::typeof(destructure), X::Array) + destructure_pullback(dXm::AbstractArray) = NoTangent(), dXm + return X, destructure_pullback +end + +Restructure(X::P) where {P<:Array} = Restructure{P, Nothing}(nothing) + +(r::Restructure{P})(X::Array) where {P<:Array} = X + +function frule( + (_, dX)::Tuple{Any, AbstractArray}, ::Restructure{P}, X::Array, +) where {P<:Array} + return X, dX +end + +function rrule(::Restructure{P}, X::Array) where {P<:Array} + restructure_pullback(dY::AbstractArray) = NoTangent(), dY + return X, restructure_pullback +end + + + + + +# Diagonal +function pullback_of_destructure(D::P) where {P<:Diagonal} + pullback_destructure_Diagonal(D̄::AbstractArray) = Tangent{P}(diag=diag(D̄)) + return pullback_destructure_Diagonal +end + +function pullback_of_restructure(D::P) where {P<:Diagonal} + pullback_restructure_Diagonal(D̄::Tangent) = Diagonal(D̄.diag) + return pullback_restructure_Diagonal +end + +function pullback_of_destructure(config::RuleConfig, X::Diagonal) + return pullback_of_destructure(X) +end + +function pullback_of_restructure(config::RuleConfig, X::Diagonal) + return pullback_of_restructure(X) +end + +# Stuff below here for Diagonal to move to tests. +destructure(X::Diagonal) = collect(X) + +function frule((_, dX)::Tuple{Any, Tangent}, ::typeof(destructure), X::Diagonal) + des_diag, d_des_diag = frule((NoTangent(), dX.diag), destructure, X.diag) + return collect(X), collect(Diagonal(d_des_diag)) +end + +function rrule(::typeof(destructure), X::P) where {P<:Diagonal} + _, des_diag_pb = rrule(destructure, X.diag) + function destructure_pullback(dY::AbstractMatrix) + d_des_diag = diag(dY) + _, d_diag = des_diag_pb(d_des_diag) + return NoTangent(), Tangent{P}(diag=d_diag) + end + return destructure(X), destructure_pullback +end + +Restructure(X::P) where {P<:Diagonal} = Restructure{P, Nothing}(nothing) + +function (r::Restructure{P})(X::Array) where {P<:Diagonal} + @assert isdiag(X) # for illustration. Remove in actual because numerics. + return Diagonal(diag(X)) +end + +function frule( + (_, dX)::Tuple{Any, AbstractArray}, ::Restructure{P}, X::Array, +) where {P<:Diagonal} + return Diagonal(diag(X)), Tangent{P}(diag=diag(dX)) +end + +function rrule(::Restructure{P}, X::Array) where {P<:Diagonal} + restructure_pullback(dY::Tangent) = NoTangent(), Diagonal(dY.diag) + return X, restructure_pullback +end + + + + + +# Symmetric +function pullback_of_destructure(S::P) where {P<:Symmetric} + function destructure_pullback_Symmetric(dXm::AbstractMatrix) + U = UpperTriangular(dXm) + L = LowerTriangular(dXm) + if S.uplo == 'U' + return Tangent{P}(data=U + L' - Diagonal(dXm)) + else + return Tangent{P}(data=U' + L - Diagonal(dXm)) + end + end + return destructure_pullback_Symmetric +end + +# Assume upper-triangular for now. +function pullback_of_restructure(S::P) where {P<:Symmetric} + function restructure_pullback_Symmetric(dY::Tangent) + return collect(UpperTriangular(dY.data)) + end + return restructure_pullback_Symmetric +end + +function pullback_of_destructure(config::RuleConfig, X::Symmetric) + return pullback_of_destructure(X) +end + +function pullback_of_restructure(config::RuleConfig, X::Symmetric) + return pullback_of_restructure(X) +end + +# Stuff below here for Symmetric to move to tests. 
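+# `destructure(::Symmetric)` mirrors the stored triangle across the diagonal. Accordingly, the
+# frule below symmetrises the tangent of `data` (U + U' - Diagonal), which is the natural tangent,
+# and the rrule's pullback folds a dense cotangent back onto the stored triangle.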
+function destructure(X::Symmetric) + des_data = destructure(X.data) + if X.uplo == 'U' + U = UpperTriangular(des_data) + return U + U' - Diagonal(des_data) + else + L = LowerTriangular(des_data) + return L' + L - Diagonal(X.data) + end +end + +# This gives you the natural tangent! +function frule((_, dx)::Tuple{Any, Tangent}, ::typeof(destructure), x::Symmetric) + des_data, d_des_data = frule((NoTangent(), dx.data), destructure, x.data) + + if x.uplo == 'U' + dU = UpperTriangular(d_des_data) + return destructure(x), dU + dU' - Diagonal(d_des_data) + else + dL = LowerTriangular(d_des_data) + return destructure(x), dL + dL' - Diagonal(d_des_data) + end +end + +function rrule(::typeof(destructure), X::P) where {P<:Symmetric} + function destructure_pullback(dXm::AbstractMatrix) + U = UpperTriangular(dXm) + L = LowerTriangular(dXm) + if X.uplo == 'U' + return NoTangent(), Tangent{P}(data=U + L' - Diagonal(dXm)) + else + return NoTangent(), Tangent{P}(data=U' + L - Diagonal(dXm)) + end + end + return destructure(X), destructure_pullback +end + +Restructure(X::P) where {P<:Symmetric} = Restructure{P, P}(X) + +# In generic-abstractarray-rrule land, assume getindex was used, so the +# strict-lower-triangle was never touched. +function (r::Restructure{P})(X::Array) where {P<:Symmetric} + strict_lower_triangle_of_data = LowerTriangular(r.data.data) - Diagonal(r.data.data) + return Symmetric(UpperTriangular(X) + strict_lower_triangle_of_data) +end + +# We get to assume that `issymmetric(X)` is (at least roughly) true, so we could also +# implement restructure as Symmetric((X + X') / 2), provided that we then take care of the +# lower triangle as above. + +function frule( + (_, dX)::Tuple{Any, AbstractArray}, r::Restructure{P}, X::Array, +) where {P<:Symmetric} + return r(X), Tangent{P}(data=UpperTriangular(dX)) +end + +function rrule(r::Restructure{P}, X::Array) where {P<:Symmetric} + function restructure_pullback(dY::Tangent) + d_restructure = Tangent{Restructure{P}}(data=Tangent{P}(data=tril(dY.data))) + return d_restructure, collect(UpperTriangular(dY.data)) + end + return r(X), restructure_pullback +end + + + + + +# UpperTriangular +function pullback_of_destructure(U::P) where {P<:UpperTriangular{<:Real, <:StridedMatrix}} + pullback_destructure(Ū::AbstractArray) = Tangent{P}(data=UpperTriangular(Ū)) + return pullback_destructure +end + +function pullback_of_restructure(U::P) where {P<:UpperTriangular} + pullback_restructure(Ū::Tangent) = UpperTriangular(Ū.data) + return pullback_restructure +end + +function pullback_of_destructure(config::RuleConfig, U::UpperTriangular) + return pullback_of_destructure(U) +end + +function pullback_of_restructure(config::RuleConfig, U::UpperTriangular) + return pullback_of_restructure(U) +end + + + +# Cholesky -- you get to choose whatever destructuring operation is helpful for a given +# type. This one is helpful for writing generic pullbacks for `cholesky`, the output of +# which is a Cholesky. Not completed. Probably won't be included in initial merge. 
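+# Note that the destructured form here is not an `Array`: it is another `Cholesky` whose `factors`
+# field has been densified, which is the representation a generic `cholesky` pullback would want.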
+destructure(C::Cholesky) = Cholesky(destructure(C.factors), C.uplo, C.info) + +# Restructure(C::P) where {P<:Cholesky} = Restructure{P, Nothing}() diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 10912ce61..9d66700b2 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -547,3 +547,45 @@ function _constrain_and_name(arg::Expr, _) error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type + + + +# This comment is for reviewing purposes, and will need to be replaced later. +# This will often make life really easy. Just requires that pullback_of_restructure is +# defined for C, and pullback_of_destructure for A and B. Could be generalised to make +# different assumptions (e.g. some arguments don't require destructuring, output doesn't +# require restructuring, etc). Would need to be generalised to arbitrary numbers of +# arguments (clearly doable -- at worst requires a generated function). +# I'm assuming that functions don't need to have the destructure pullback applied to them, +# but this probably won't always be true. +function wrap_natural_pullback(config, natural_pullback, C, A, B) + + # Generate enclosing pullbacks. Notice that C / A / B only appear here, and aren't + # part of the closure returned. This means that they don't need to be carried around, + # which is good. + destructure_A_pb = pullback_of_destructure(config, A) + destructure_B_pb = pullback_of_destructure(config, B) + restructure_C_pb = pullback_of_restructure(config, C) + + # Wrap natural_pullback to make it play nicely with AD. + function generic_pullback(C̄) + C̄_natural = restructure_C_pb(C̄) + f̄, Ā_natural, B̄_natural = natural_pullback(C̄_natural) + Ā = destructure_A_pb(Ā_natural) + B̄ = destructure_B_pb(B̄_natural) + return f̄, Ā, B̄ + end + return generic_pullback +end + +function wrap_natural_pullback(config, natural_pullback, B, A) + destructure_input_pb = pullback_of_destructure(config, A) + restructure_output_pb = pullback_of_restructure(config, B) + function generic_pullback(B̄) + B̄_natural = restructure_output_pb(B̄) + f̄, Ā_natural = natural_pullback(B̄_natural) + Ā = destructure_input_pb(Ā_natural) + return f̄, Ā + end + return generic_pullback +end diff --git a/test/destructure.jl b/test/destructure.jl new file mode 100644 index 000000000..5c37d2a13 --- /dev/null +++ b/test/destructure.jl @@ -0,0 +1,92 @@ +using ChainRulesCore: + destructure, + Restructure, + pushforward_of_destructure, + pullback_of_destructure, + pullback_of_restructure + +# Need structural versions for tests, rather than the thing we currently have in +# FiniteDifferences. +function FiniteDifferences.to_vec(X::Diagonal) + diag_vec, diag_from_vec = to_vec(X.diag) + Diagonal_from_vec(diag_vec) = Diagonal(diag_from_vec(diag_vec)) + return diag_vec, Diagonal_from_vec +end + +function FiniteDifferences.to_vec(X::Symmetric) + data_vec, data_from_vec = to_vec(X.data) + Symmetric_from_vec(data_vec) = Symmetric(data_from_vec(data_vec)) + return data_vec, Symmetric_from_vec +end + +interpret_as_Tangent(x::Array) = x + +interpret_as_Tangent(d::Diagonal) = Tangent{Diagonal}(diag=d.diag) + +interpret_as_Tangent(s::Symmetric) = Tangent{Symmetric}(data=s.data) + +Base.isapprox(t::Tangent{<:Diagonal}, d::Diagonal) = isapprox(t.diag, d.diag) + +Base.isapprox(t::Tangent{<:Symmetric}, s::Symmetric) = isapprox(t.data, s.data) + +function check_destructure(x::AbstractArray, ȳ, ẋ) + + # Verify correctness of frule. 
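+    # `ẋ` is a structural tangent whenever `x` is structured; the frule should map it to the dense
+    # tangent of `destructure(x)`, which is compared against a finite-difference jvp below.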
+ yf, ẏ = frule((NoTangent(), ẋ), destructure, x) + @test yf ≈ destructure(x) + + ẏ_fd = jvp(central_fdm(5, 1), destructure, (x, ẋ)) + @test ẏ ≈ ẏ_fd + + yr, pb = rrule(destructure, x) + _, x̄_r = pb(ȳ) + @test yr ≈ destructure(x) + + # Use inner product relationship to avoid needing CRTU. + @test dot(ȳ, ẏ_fd) ≈ dot(x̄_r, ẋ) + + # Check that the round-trip in tangent / cotangent space is the identity function. + pf_des = pushforward_of_destructure(x) + pb_des = pullback_of_destructure(x) + + # I thought that maybe the pushforward of destructure would be equivalent to the + # pullback of restructure, but that doesn't seem to hold. Not sure why / whether I + # should have thought it might be a thing in the first place. + ẋ_dense = pf_des(ẋ) + # @test ẋ_dense ≈ pf_des(pb_des(ẋ_dense)) + + # Check that the round-trip is the identity function. + @test x ≈ Restructure(x)(destructure(x)) + + # Verify frule of restructure. + x_dense = destructure(x) + x_re, ẋ_re = frule((NoTangent(), ẋ_dense), Restructure(x), x_dense) + @test x_re ≈ Restructure(x)(x_dense) + + ẋ_re_fd = FiniteDifferences.jvp(central_fdm(5, 1), Restructure(x), (x_dense, ẋ_dense)) + @test ẋ_re ≈ ẋ_re_fd + + # Verify rrule of restructure. + x_re_r, pb_r = rrule(Restructure(x), x_dense) + _, x̄_dense = pb_r(ẋ) + + # ẋ serves as the cotangent for the reconstructed x + @test dot(ẋ, interpret_as_Tangent(ẋ_re_fd)) ≈ dot(x̄_dense, ẋ_dense) + + # Check the rrule for destructure. + pb_des = pullback_of_destructure(x) + pb_res = pullback_of_restructure(x) + x̄_des = pb_res(ẋ) + + # I don't have access to test_approx here. + # Since I need to test these, maybe chunks of this PR belong in ChainRules.jl. + @test x̄_des ≈ pb_res(pb_des(x̄_des)) +end + +@testset "destructure" begin + check_destructure(randn(3, 3), randn(3, 3), randn(3, 3)) + check_destructure(Diagonal(randn(3)), randn(3, 3), Tangent{Diagonal}(diag=randn(3))) + check_destructure( + Symmetric(randn(3, 3)), randn(3, 3), Tangent{Symmetric}(data=randn(3, 3)), + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index e4499b0ff..199a0e078 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Base.Broadcast: broadcastable using BenchmarkTools using ChainRulesCore +using FiniteDifferences using LinearAlgebra using LinearAlgebra.BLAS: ger!, gemv!, gemv, scal! using StaticArrays @@ -8,19 +9,20 @@ using SparseArrays using Test @testset "ChainRulesCore" begin - @testset "differentials" begin - include("differentials/abstract_zero.jl") - include("differentials/thunks.jl") - include("differentials/composite.jl") - include("differentials/notimplemented.jl") - end + # @testset "differentials" begin + # include("differentials/abstract_zero.jl") + # include("differentials/thunks.jl") + # include("differentials/composite.jl") + # include("differentials/notimplemented.jl") + # end - include("accumulation.jl") - include("projection.jl") + # include("accumulation.jl") - include("rules.jl") - include("rule_definition_tools.jl") - include("config.jl") + # include("rules.jl") + # include("rule_definition_tools.jl") + # include("config.jl") - include("deprecated.jl") + include("destructure.jl") + + # include("deprecated.jl") end