Open · Labels: help wanted

Description
I am trying to put together an example of using FastAI.jl to finetune a pretrained ResNet from torchvision, and I am unsure how to use the output of `Zygote.gradient` on a `TorchModuleWrapper` to optimize its parameters. So the question: what is the best way to update the parameters of a nested module? Have you tried this successfully, and could you share a minimal example?
I've tried applying `Flux.Optimise.update!` to the README example by iterating over the params and gradients manually, but I am not sure whether this approach will carry over to more deeply nested structures:
```julia
using Flux, Zygote

# `torch_module`, `loss`, `input`, and `target` are as in the README example.
model = TorchModuleWrapper(torch_module)
optim = ADAM()
for i in 1:10
    # Zygote returns a NamedTuple gradient mirroring the wrapper's fields.
    grad, = Zygote.gradient(m -> loss(m, input, target), model)
    # Update each parameter array with its matching gradient.
    for (p, g) in zip(model.params, grad.params)
        Flux.Optimise.update!(optim, p, g)
    end
end
```
The above works, while `Flux.Optimise.update!(optim, model, grad)` does not. Maybe an overload for `Flux.Optimise.update!(::AbstractOptimiser, ::TorchModuleWrapper, _)` is needed?
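For what it's worth, here is a minimal sketch of such an overload, assuming `TorchModuleWrapper` keeps its parameter arrays in a `params` field and that the Zygote gradient is a NamedTuple with a matching `params` field (as in the loop above). This is hypothetical, not part of the package:

```julia
import Flux

# Hypothetical overload (not in the package): applies the manual loop from
# above so `update!` can be called on the wrapper directly.
function Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser,
                               m::TorchModuleWrapper, grad)
    # Walk the flat collection of parameter arrays and matching gradients.
    for (p, g) in zip(m.params, grad.params)
        Flux.Optimise.update!(opt, p, g)
    end
    return m
end
```

This would only cover the flat case; for wrappers nested inside larger Julia structs, something recursive (e.g. via Functors/`fmap`, assuming the wrapper is registered as a functor) would presumably still be needed.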