Rethink train design and better callbacks support #1461

@DhairyaLGandhi

Description

There are a few cases where I find myself wondering whether we should make it more explicit how the train loop design can be extended so that callbacks don't have to cheat to get things like the loss. Further, projects such as FluxTraining.jl show that we lack a set of preexisting callbacks that users shouldn't have to rewrite.

With that in mind, I think using `pullback` instead of `gradient` would be a step in that direction, as would not applying the optimiser step until a prehook has checked callback conditions (see the sketch below). This should also fit in nicely with how we want to set up schedulers. I would also like to figure out where distributed and multi-GPU training fit in, so we know how to proceed.
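
A minimal sketch of what that could look like, assuming the current implicit-params style; `train_step!`, `prehook`, and `callback` are illustrative names, not a proposed API:

```julia
using Flux, Zygote

# Hypothetical train step: `pullback` keeps the loss around for callbacks,
# and a prehook can veto the optimiser step before any parameters change.
function train_step!(loss, ps, data, opt; prehook = gs -> true, callback = l -> nothing)
    for d in data
        l, back = Zygote.pullback(() -> loss(d...), ps)  # forward pass, loss retained
        gs = back(one(l))                                # reverse pass -> Grads
        callback(l)                                      # callbacks see the actual loss, no recomputation
        prehook(gs) || continue                          # skip the update if the prehook says so
        Flux.Optimise.update!(opt, ps, gs)               # only now do we optimise
    end
end
```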

We don't necessarily want to return the losses etc., but perhaps a slightly more trained model (as in the sketch below)? This would fall in line with the direction Optimisers.jl is taking as well.
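
For illustration, a hedged sketch of a loop that returns the trained model, written against an explicit, Optimisers.jl-style interface (assuming a non-mutating `setup`/`update` pair; the `train` function here is hypothetical):

```julia
using Flux, Optimisers

# Hypothetical functional train loop: optimiser state is threaded through
# explicitly and the updated model is returned to the caller.
function train(model, data, opt; loss = Flux.Losses.mse)
    state = Optimisers.setup(opt, model)
    for (x, y) in data
        grads = Flux.gradient(m -> loss(m(x), y), model)[1]
        state, model = Optimisers.update(state, model, grads)
    end
    return model  # the caller gets back a "slightly more trained" model
end

# Usage: model = train(model, data, Optimisers.Descent(0.1))
```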

xref #1017 #1067 #1434

cc @lorenzoh
