
How to update parameters using gradients #19

@lorenzoh

Description


I am trying to put together an example of using FastAI.jl to fine-tune a pretrained ResNet from torchvision, and I am unsure how to use the output of Zygote.gradient on a TorchModuleWrapper to optimize its parameters.

So the question: what is the best way to update the parameters of a nested module? Have you tried this successfully, and could you share a minimal example?

I've tried applying Flux.Optimise.update! to the README example by iterating over the params and gradients manually, but I am not sure whether this will work for more deeply nested structures.

using Flux, Zygote

model = TorchModuleWrapper(torch_module)
optim = ADAM()

for i in 1:10
    # Zygote returns a NamedTuple whose `params` field mirrors `model.params`
    grad, = Zygote.gradient(m -> loss(m, input, target), model)
    # Update each parameter array in place with its matching gradient
    for (p, g) in zip(model.params, grad.params)
        Flux.Optimise.update!(optim, p, g)
    end
end

The above works, but Flux.Optimise.update!(optim, model, grad) does not. Perhaps an overload of Flux.Optimise.update!(::AbstractOptimiser, ::TorchModuleWrapper, _) is needed?
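For reference, a minimal sketch of what such an overload might look like, assuming TorchModuleWrapper keeps its trainable tensors in a params field and that the gradient Zygote returns is a NamedTuple with a matching params field (as in the loop above). This is not part of the package, just an illustration of the idea:

import Flux

# Hypothetical overload: assumes `model.params` and `grad.params` line up one-to-one,
# exactly as in the manual loop above.
function Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser,
                               model::TorchModuleWrapper, grad)
    for (p, g) in zip(model.params, grad.params)
        g === nothing && continue            # skip parameters Zygote gave no gradient for
        Flux.Optimise.update!(opt, p, g)     # in-place update of each parameter array
    end
    return model
end

With that in place, the training loop body would reduce to:

grad, = Zygote.gradient(m -> loss(m, input, target), model)
Flux.Optimise.update!(optim, model, grad)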
