See the docstrings for [`symbolic_pullback`](@ref), [`build_nn_function`](@ref), [`_get_params`](@ref) and [`_get_contents`](@ref) for more info on the functions that we used here.
The noteworthy thing in the expression above is that the functor of `SymbolicPullback` returns two objects. The first is the loss value evaluated for the given parameters and inputs. The second is a function that takes one further input argument (the *output sensitivities* discussed below) and only then returns the partial derivatives. But why do we need this extra step with another function?

!!! info "Reverse Accumulation"
    In machine learning we typically use [reverse accumulation](https://en.wikipedia.org/wiki/Automatic_differentiation#Forward_and_reverse_accumulation) to perform automatic differentiation (AD).
    Assume we are given a function that is the composition of simpler functions, ``f = f_n\circ{}f_{n-1}\circ\cdots\circ{}f_1:\mathbb{R}^n\to\mathbb{R}^m``. *Reverse differentiation* starts with *output sensitivities* and then successively feeds them through ``f_n``, ``f_{n-1}`` etc. So it computes ``J_f^T\,do = J_1^TJ_2^T\cdots{}J_n^T\,do``, where ``J_i`` is the Jacobian of ``f_i`` evaluated at the output of ``f_{i-1}`` (at the input ``x`` for ``i = 1``), ``do\in\mathbb{R}^m`` are the *output sensitivities* and the transposed Jacobians are multiplied onto ``do`` from the left, one at a time, starting with ``J_n^T``. So we propagate from the output stepwise back to the input. If ``m=1``, i.e. if the output is one-dimensional, then the *output sensitivities* may simply be taken to be ``do = 1``.
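The reverse pass described in the box above can be sketched in plain Julia. This is only an illustration under stated assumptions: the maps `f1`, `f2`, `f3`, their hand-written Jacobians `J1`, `J2`, `J3` and the helper `reverse_accumulate` are hypothetical names that are not part of the package, and `f1` is taken to be applied first and `f3` last.

```julia
using LinearAlgebra

# Hypothetical building blocks (f1 is applied first, f3 last).
A = [1.0 2.0; 3.0 4.0]
f1(x) = A * x                 # R^2 -> R^2
f2(x) = tanh.(x)              # R^2 -> R^2 (elementwise)
f3(x) = sum(abs2, x) / 2      # R^2 -> R

# Hand-written Jacobians of the three maps, evaluated at their inputs.
J1(x) = A
J2(x) = Diagonal(1 .- tanh.(x) .^ 2)
J3(x) = x'

function reverse_accumulate(x, output_sensitivity)
    # Forward pass: keep the intermediate values.
    x1 = f1(x)
    x2 = f2(x1)
    y = f3(x2)
    # Reverse pass: multiply transposed Jacobians from the left,
    # starting with the last map.
    s = J3(x2)' * output_sensitivity
    s = J2(x1)' * s
    s = J1(x)' * s
    return y, s
end

x = [0.1, -0.2]
y, grad = reverse_accumulate(x, 1.0)
```

Because the output is one-dimensional here, passing `1.0` as the output sensitivity yields the ordinary gradient with respect to `x`.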
So in theory we could leave out this extra step: returning a callable object (the one stored in `pb.fun`) can be seen as unnecessary, since we could simply store the equivalent of `pb.fun(1.)` in an instance of `SymbolicPullback`.
It is, however, customary for a pullback to return a callable function (that depends on the *output sensitivities*), which is why we do the same here, even when the *output sensitivities* are a scalar quantity.
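The convention can be illustrated with a hand-written pullback in plain Julia. The sketch below is not the implementation of `SymbolicPullback`: the function `loss_with_pullback` and the quadratic loss are hypothetical and only mimic the shape of the return value, namely a loss value together with a closure that maps output sensitivities to partial derivatives.

```julia
# Hypothetical loss L(W, x) = ‖W * x‖² / 2 (for illustration only).
function loss_with_pullback(W, x)
    y = W * x
    value = sum(abs2, y) / 2
    # The closure reuses the forward-pass result `y`; the gradient of
    # the loss with respect to W is the outer product y * x'.
    pullback(output_sensitivity) = output_sensitivity * (y * x')
    return value, pullback
end

W = [1.0 2.0; 0.0 1.0]
x = [1.0, -1.0]

value, pb = loss_with_pullback(W, x)
dW = pb(1.0)   # output sensitivity 1 gives the plain gradient
```

Calling `pb(1.0)` gives the quantity one could have stored directly, while any other scalar sensitivity, e.g. `pb(2.0)`, merely rescales it. Returning the closure keeps the interface the same as for losses whose output is not one-dimensional.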