Description
There are a few cases where I find myself wondering whether we should make it more explicit how the train loop design can be extended to be friendlier to callbacks, so they don't have to cheat to get at things like the loss. Packages like FluxTraining.jl also show that we lack a stock of preexisting callbacks that shouldn't need to be rewritten downstream.
With that in mind, I think using `pullback` instead of `gradient` would be a step in that direction, as would not optimising until after a prehook has had a chance to check for callback conditions, etc. This should also fit in nicely with how we want to set up schedulers. I would also like to figure out where distributed and multi-GPU training fall in this, so we know how to proceed.
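Roughly what I have in mind (a sketch only; `train_step!` and the `(event, payload)` callback signature are placeholders, not a proposed API): with `pullback`, the loss is already in hand for callbacks without recomputing the forward pass, and the update can be deferred until after a prehook runs.

```julia
using Flux, Zygote

# Hypothetical train step using pullback instead of gradient.
function train_step!(loss, ps, data, opt; callbacks = [])
    for d in data
        # Forward pass: the loss value and the pullback closure.
        l, back = Zygote.pullback(() -> loss(d...), ps)
        foreach(cb -> cb(:loss, l), callbacks)       # callbacks see the loss "for free"
        gs = back(one(l))                            # reverse pass -> Grads
        foreach(cb -> cb(:pregrad, gs), callbacks)   # prehook before optimising
        Flux.Optimise.update!(opt, ps, gs)           # only then apply the update
    end
end
```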
We don't necessarily want to return the losses etc., but perhaps a slightly more trained model? That would also fall in line with where Optimisers.jl is heading.
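A minimal sketch of that, assuming the explicit `setup`/`update` interface from Optimisers.jl (the `train` name and keyword are illustrative only): the loop returns the updated model rather than mutating in place, so callers compose it as `model = train(model, data, loss)`.

```julia
using Flux, Optimisers, Zygote

# Explicit-parameter loop that returns a "slightly more trained" model.
function train(model, data, loss; opt = Optimisers.Adam())
    state = Optimisers.setup(opt, model)
    for d in data
        l, back = Zygote.pullback(m -> loss(m, d...), model)
        grads = back(one(l))[1]                       # gradient w.r.t. the model
        state, model = Optimisers.update(state, model, grads)
    end
    return model
end
```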
cc @lorenzoh