
Commit d8512a3

Added script for comparing pullbacks and added an explanation for why we implemented the pullback the way we did.
1 parent ac49890 commit d8512a3

2 files changed: +108 −2 lines changed

scripts/pullback_comparison.jl

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
using SymbolicNeuralNetworks
using AbstractNeuralNetworks
using GeometricMachineLearning
using AbstractNeuralNetworks: FeedForwardLoss
using GeometricMachineLearning: ZygotePullback
import Random
Random.seed!(123)

c = Chain(Dense(2, 3, tanh), Dense(3, 1, tanh))
nn = SymbolicNeuralNetwork(c)
nn_cpu = NeuralNetwork(c, CPU())
loss = FeedForwardLoss()
spb = SymbolicPullback(nn, loss)
zpb = ZygotePullback(loss)

batch_size = 10000
input = rand(2, batch_size)
output = rand(1, batch_size)
# output sensitivities
_do = 1.0

# call both pullbacks once so that compilation time is not included in the timings below
spb(nn_cpu.params, nn.model, (input, output))[2](_do)
zpb(nn_cpu.params, nn.model, (input, output))[2](_do)
# time the symbolic pullback against the Zygote-based one
@time spb_evaluated = spb(nn_cpu.params, nn.model, (input, output))[2](_do)
@time zpb_evaluated = zpb(nn_cpu.params, nn.model, (input, output))[2](_do)[1].params
# @assert values(spb_evaluated) .≈ values(zpb_evaluated)
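# A possible explicit check (a sketch; it assumes both `spb_evaluated` and
# `zpb_evaluated` are nested `NamedTuple`s with the same keys):
for layer in keys(spb_evaluated)
    for arr in keys(spb_evaluated[layer])
        @assert spb_evaluated[layer][arr] ≈ zpb_evaluated[layer][arr]
    end
end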

src/pullback.jl

Lines changed: 82 additions & 2 deletions
@@ -1,4 +1,4 @@
-"""
+@doc raw"""
     SymbolicPullback <: AbstractPullback

`SymbolicPullback` computes the *symbolic pullback* of a loss function.

@@ -22,6 +22,63 @@ pv_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof

@NamedTuple{L1::@NamedTuple{W::Matrix{Float64}, b::Vector{Float64}}}
```

# Implementation

An instance of `SymbolicPullback` stores
- `loss`: an instance of a `NetworkLoss`,
- `fun`: a function that is used to compute the pullback.

If we call the functor of an instance of `SymbolicPullback` on `ps`, `model` and `input`, it returns:
```julia
_pullback.loss(model, ps, input...), _pullback.fun(input..., ps)
```
where the second output argument is again a function.
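Schematically, the two return values can be used as follows (a sketch continuing the example above):
```julia
loss_value, pullback_fun = pb(ps, nn.model, (rand(2), rand(1)))
grads = pullback_fun(1) # the argument of `pullback_fun` is the output sensitivity
```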

# Extended help

We note the following apparent peculiarity:

```jldoctest
using SymbolicNeuralNetworks
using AbstractNeuralNetworks
using Symbolics
import Random
Random.seed!(123)

c = Chain(Dense(2, 1, tanh))
nn = SymbolicNeuralNetwork(c)
loss = FeedForwardLoss()
pb = SymbolicPullback(nn, loss)
ps = initialparameters(c) |> NeuralNetworkParameters
input_output = (rand(2), rand(1))
loss_and_pullback = pb(ps, nn.model, input_output)
pv_values = loss_and_pullback[2](1)

@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn)
pv_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)

pv_values == (pv_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)

# output

true
```

See the docstrings for [`symbolic_pullback`](@ref), [`build_nn_function`](@ref), [`_get_params`](@ref) and [`_get_contents`](@ref) for more information on the functions used here.

The noteworthy thing in the expression above is that the functor of `SymbolicPullback` returns two objects: the first is the loss value evaluated for the relevant parameters and input; the second is a function that takes a further input argument and only then returns the partial derivatives. But why do we need this extra step with another function?

!!! info "Reverse Accumulation"
    In machine learning we typically use [reverse accumulation](https://en.wikipedia.org/wiki/Automatic_differentiation#Forward_and_reverse_accumulation) to perform automatic differentiation (AD).
    Assuming we are given a function that is the composition of simpler functions ``f = f_n\circ\cdots\circ{}f_2\circ{}f_1:\mathbb{R}^n\to\mathbb{R}^m``, *reverse differentiation* starts with *output sensitivities* and then successively feeds them through ``f_n``, ``f_{n-1}`` etc., i.e. it computes:
    ```math
    (\nabla_xf)^T do = (\nabla_{x}f_1)^T(\nabla_{f_1(x)}f_2)^T\cdots(\nabla_{f_{n-1}(\cdots{}f_1(x))}f_n)^T do,
    ```
    where ``do\in\mathbb{R}^m`` are the *output sensitivities* and the Jacobians are multiplied onto them stepwise from the left, so we propagate from the output back to the input. If ``m = 1``, i.e. if the output is one-dimensional, the *output sensitivities* may simply be taken to be ``do = 1``.
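
As a small hand-worked illustration of the formula above (a sketch with hand-written Jacobians, independent of this package; here ``f_1`` is an elementwise `tanh` and ``f_2`` sums the components, so ``m = 1``):
```julia
using LinearAlgebra

f₁(x) = tanh.(x)                    # ℝ² → ℝ²
f₂(y) = sum(y)                      # ℝ² → ℝ, i.e. m = 1
x = [0.5, -0.25]
J₁ᵀ = Diagonal(1 .- tanh.(x) .^ 2)  # (∇ₓf₁)ᵀ, written out by hand
J₂ᵀ = ones(2)                       # (∇_{f₁(x)}f₂)ᵀ, written out by hand
d_o = 1.0                           # output sensitivity; do = 1 since m = 1
J₁ᵀ * (J₂ᵀ * d_o)                   # (∇ₓ(f₂∘f₁))ᵀ do, i.e. the gradient of f₂∘f₁ at x
```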

In theory we could therefore leave out this extra step: returning a function (which is stored in `pb.fun`) may seem unnecessary, since we could simply store the equivalent of `pb.fun(1.)` in an instance of `SymbolicPullback`. It is however customary for a pullback to return a callable function that depends on the *output sensitivities*, which is why we also do so here, even though the *output sensitivities* are a scalar quantity in our case.
"""
struct SymbolicPullback{NNLT, FT} <: AbstractPullback{NNLT}
    loss::NNLT
@@ -38,17 +95,40 @@ function SymbolicPullback(nn::SymbolicNeuralNetwork, loss::NetworkLoss)
    symbolic_pullbacks = symbolic_pullback(symbolic_loss, nn)
    pbs_executable = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)
    function pbs(input, output, params)
-       _ -> (pbs_executable(input, output, params) |> _get_params |> _get_contents)
+       pullback(::Union{Real, AbstractArray{<:Real}}) = _get_contents(_get_params(pbs_executable(input, output, params)))
+       pullback
    end
    SymbolicPullback(loss, pbs)
end

SymbolicPullback(nn::SymbolicNeuralNetwork) = SymbolicPullback(nn, AbstractNeuralNetworks.FeedForwardLoss())
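# Schematic use of `pbs` (a sketch, not package code): the returned closure has to be
# called on the output sensitivities before it returns the partial derivatives, e.g.
#   pb = SymbolicPullback(nn)            # uses `FeedForwardLoss` by default
#   pb.fun(input, output, ps)(1.0)       # returns the gradients as a `NamedTuple`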

"""
    _get_params(ps::NeuralNetworkParameters)

Return the `NamedTuple` that's equivalent to the `NeuralNetworkParameters`.
"""
_get_params(nt::NamedTuple) = nt
_get_params(ps::NeuralNetworkParameters) = ps.params
_get_params(ps::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}) = [_get_params(nt) for nt in ps]
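# For instance (a sketch): `_get_params(NeuralNetworkParameters((L1 = (W = rand(1, 2), b = rand(1)),)))`
# returns the underlying `NamedTuple` `(L1 = (W = …, b = …),)`.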

"""
    _get_contents(nt::AbstractArray{<:NamedTuple})

Return the contents of a one-dimensional vector. If the vector contains a single `NamedTuple`, that `NamedTuple` itself is returned.

# Examples

```jldoctest
using SymbolicNeuralNetworks: _get_contents

_get_contents([(a = "element_contained_in_vector", )])

# output

(a = "element_contained_in_vector",)
```
"""
_get_contents(nt::NamedTuple) = nt
function _get_contents(nt::AbstractVector{<:NamedTuple})
    length(nt) == 1 ? nt[1] : __get_contents(nt)
