Utility for walking a tree (e.g. gradients) w.r.t. a model #143

Closed
@darsnack

Motivation and description

Using trainable, we can walk a model and apply a function only to its trainable parameters. But the gradient from Zygote is a plain named tuple that does not carry this information.

Normally, this is fine for optimizers, because the update function is applied at every leaf, so we only need a single pass over the model. But it is fairly common to first walk the entire tree of gradients to compute something (e.g. a global norm term). In this case, we need a separate pass over the gradient outside of the update context.

Possible Implementation

We could add a maptrainable(f, model, [gradient]) function (or a better name) that maps f over the trainable parameters of model.

  • If another tree, such as gradient, is passed, then f is applied to the leaves of that tree (i.e. approximately fmap(TrainableWalk(f), gradient, model), using the last argument to filter the walk).
  • If no other tree is passed, f is applied to the trainable leaves of model (this is a simple walk, but it may be worth having for consistency).
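
To make the proposal concrete, here is a minimal sketch of how such a function might behave. It uses only Base Julia, and everything in it is an assumption for illustration: the Dense struct, the local trainable and isleaf definitions, and this maptrainable are hypothetical stand-ins, not the Optimisers.jl or Functors.jl API. The gradient is written by hand in the nested named tuple shape that Zygote returns for structs.

```julia
# Hypothetical layer with one non-trainable buffer (`stats`).
struct Dense
    weight::Matrix{Float64}
    bias::Vector{Float64}
    stats::Vector{Float64}
end

# Local stand-in for `trainable`: the children that should be visited.
trainable(m::Dense) = (; weight = m.weight, bias = m.bias)
trainable(m::NamedTuple) = m

# Assumption for this sketch: numeric arrays are the leaves.
isleaf(x) = x isa AbstractArray{<:Number}

# Apply `f` to the leaves of `tree`, visiting only the fields that
# `trainable` exposes on the matching node of `model`.
function maptrainable(f, model, tree)
    tree === nothing && return nothing   # missing gradient branch
    isleaf(model) && return f(tree)
    kids = trainable(model)
    vals = map(k -> maptrainable(f, getproperty(kids, k), getproperty(tree, k)), keys(kids))
    return NamedTuple{keys(kids)}(vals)
end

# With no second tree, walk the model itself.
maptrainable(f, model) = maptrainable(f, model, model)

model = (; layer = Dense([1.0 2.0; 3.0 4.0], [0.5, 0.5], [9.0]))

# Hand-written gradient; `stats` has an entry that must be ignored.
grad = (; layer = (; weight = [3.0 0.0; 0.0 0.0], bias = [4.0, 0.0], stats = [100.0]))

# Separate pass over the gradient to compute a global norm.
sumsq = Ref(0.0)
maptrainable(g -> (sumsq[] += sum(abs2, g); g), model, grad)
global_norm = sqrt(sumsq[])   # 5.0; the `stats` entry is skipped

# Single-tree form: only trainable fields appear in the result.
doubled = maptrainable(x -> 2 .* x, model)
```

The model is passed before the tree being mapped so that the two-argument form is just the three-argument form with the model walked against itself; the argument order and the handling of nothing are design choices of this sketch, not something the issue settles.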

Labels: enhancement (new feature or request)