Motivation and description
Using `trainable`, we can walk a model and apply a function only to its trainable parameters. But the gradient from Zygote is a named tuple without this information.

Normally this is fine for optimizers, because our function is applied at every leaf, so we only need a single pass over the model. But it is fairly common to first walk the entire tree of gradients to compute something (e.g. a global norm term). In this case, we need a pass over the `gradient` outside of the update context.
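For concreteness, here is a rough sketch of the kind of pre-update pass this is about (computing a global gradient norm), where the walk currently has no way to tell trainable leaves apart from the rest; the model, loss, and accumulator are purely illustrative:

```julia
using Flux, Functors

# Illustrative model; in general `trainable` may mark only some of these arrays.
model = Chain(Dense(4 => 3, relu), Dense(3 => 2))
x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)

# The gradient is a plain named tuple mirroring the model's structure;
# it carries no information about which leaves are trainable.
grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)[1]

# A global-norm term needs a full pass over the gradient tree *before*
# any update, i.e. outside the optimizer's own walk over the model.
total = Ref(0.0)
fmap(grads) do g
    g isa AbstractArray && (total[] += sum(abs2, g))
    g
end
gnorm = sqrt(total[])
```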
Possible Implementation
We can include a `maptrainable(f, model, [gradient])` (or better name) function that maps a function w.r.t. the trainable parameters of `model`.

- If another tree like `gradient` is passed, then `f` is applied to the leaves of `gradient` (i.e. approximately `fmap(TrainableWalk(f), gradient, model)`, using the last argument to filter the walk).
- If no other tree is passed, we just apply `f` to `model` (this is a simple walk, but maybe it is good for consistency).
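A minimal sketch of how such a helper could behave, assuming `Optimisers.trainable` and `Functors.isleaf` as the building blocks (the name `maptrainable`, the argument order, and the handling of `nothing` gradients are assumptions of this proposal, not an existing API):

```julia
using Functors, Optimisers

# Hypothetical `maptrainable`: recurse through `model`, descending only into
# children returned by `Optimisers.trainable`, and apply `f` to the matching
# leaves of the gradient tree. Non-trainable fields of `grad` are never visited.
function maptrainable(f, model, grad)
    grad === nothing && return nothing          # Zygote's representation of a zero gradient
    Functors.isleaf(model) && return f(grad)
    tr = Optimisers.trainable(model)            # trainable children of this node
    if tr isa NamedTuple
        return NamedTuple{keys(tr)}(map(k -> maptrainable(f, tr[k], grad[k]), keys(tr)))
    else                                        # e.g. the Tuple of layers inside a Chain
        return ntuple(i -> maptrainable(f, tr[i], grad[i]), length(tr))
    end
end

# Single-tree form: just walk the trainable parameters of the model itself.
maptrainable(f, model) = maptrainable(f, model, model)
```

With something along these lines, the global-norm example above becomes a single pass such as `maptrainable(g -> sum(abs2, g), model, grads)`, and gradient clipping could be written as `maptrainable(g -> g .* factor, model, grads)`.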