
[API] Preventing errors from misplaced optimizer objects #2106

@MilesCranmer


I was wondering if you would be open to an API improvement that would be completely optional and also minimal. The goal would be to reduce the potential for bugs and to make code more intuitive.

The following code is how one currently initializes an optimizer:

```julia
p = params(model)
opt = Adam(1e-3)

for i=1:1000
    ...
    update!(opt, p, grad)
end
```

It is not until the update! step that the optimizer actually initializes its state and reads the parameters.
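To make this lazy behaviour concrete, here is a minimal, dependency-free sketch of a momentum-style optimiser whose per-parameter state only comes into existence inside the update call. It is loosely modelled on how Flux's built-in optimisers keep their state in an `IdDict` keyed by the parameter array, but `LazyMomentum` and `lazy_update!` are hypothetical names, and a plain `Vector` of arrays stands in for Flux's `Params`.

```julia
# Hypothetical momentum optimiser: nothing about the parameters is recorded
# at construction; state is created the first time an array is updated.
mutable struct LazyMomentum
    eta::Float64
    rho::Float64
    velocity::IdDict{Any,Any}   # keyed by the identity of each parameter array
end
LazyMomentum(eta = 0.01, rho = 0.9) = LazyMomentum(eta, rho, IdDict{Any,Any}())

function lazy_update!(opt::LazyMomentum, params::Vector, grads::Vector)
    for (x, g) in zip(params, grads)
        # The velocity buffer for `x` is allocated here, on first use.
        v = get!(() -> zero(x), opt.velocity, x)
        @. v = opt.rho * v - opt.eta * g
        x .+= v
    end
    return params
end

w = [1.0, 2.0]
opt = LazyMomentum()
println(length(opt.velocity))          # 0: constructing the optimiser recorded nothing
lazy_update!(opt, [w], [[1.0, 1.0]])
println(length(opt.velocity))          # 1: state appeared during the update
```

Because the optimiser only learns about `w` inside `lazy_update!`, nothing at construction time ties it to a particular set of parameters.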

I think there are a couple of potential issues with this:

  1. params and Adam are initialized quite far apart from each other, even though they are closely related objects. This distance makes it easy for me to change the variable name for p but forget to change it for opt, so I could end up running update! with an old optimizer (and its old state) without even realizing it.
  2. Since Adam actually records information about the parameters, it seems unintuitive that it knows nothing about them when the object is created. A beginner might assume that simply initializing Adam after params connects the two through some global mechanism, when in fact the parameters are only read at the update! step.

I am wondering what you think about the following minimal (and optional) change to remedy this: each optimizer could record the objectid of the parameter object at initialization and throw an error if the user later attempts to update a different set of parameters with it.

For example:

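```julia
mutable struct Descent <: AbstractOptimiser
  eta::Float64
  param_id::Union{UInt,Nothing}  # objectid of the bound parameters, or nothing if unbound
end
Descent(eta::Real = 0.1) = Descent(Float64(eta), nothing)
# Restricting the second argument to Params keeps this method from replacing
# the default two-argument constructor that Julia generates for the struct.
Descent(eta::Real, parameters::Params) = Descent(Float64(eta), objectid(parameters))
```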

Then, in the update! code, you would check whether param_id is nothing, and if it is not, you would verify that objectid(parameters) indeed matches what is stored.
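The check described above could look roughly like the following. This is a self-contained sketch rather than Flux code: `BoundDescent` and `checked_update!` are hypothetical names, and a plain `Vector` of arrays stands in for the `Params` object whose `objectid` would be recorded. The id field is typed `UInt` because that is what `objectid` returns.

```julia
# Hypothetical gradient-descent optimiser that can optionally be bound to one
# specific parameter collection via that collection's objectid.
mutable struct BoundDescent
    eta::Float64
    param_id::Union{UInt,Nothing}   # `nothing` means unbound: the current behaviour
end
BoundDescent(eta::Real = 0.1) = BoundDescent(Float64(eta), nothing)
BoundDescent(eta::Real, parameters::Vector) = BoundDescent(Float64(eta), objectid(parameters))

function checked_update!(opt::BoundDescent, parameters::Vector, grads::Vector)
    # Only bound optimisers are checked, so existing (unbound) code is unaffected.
    if opt.param_id !== nothing && objectid(parameters) != opt.param_id
        throw(ArgumentError("optimiser was constructed for a different parameter collection"))
    end
    for (x, g) in zip(parameters, grads)
        x .-= opt.eta .* g
    end
    return parameters
end

w = [1.0, 2.0]
p = [w]
opt = BoundDescent(0.5, p)
checked_update!(opt, p, [[1.0, 1.0]])       # allowed: w is now [0.5, 1.5]

p2 = [w]                                    # a new collection holding the same array
opt2 = BoundDescent(0.5, p2)
try
    checked_update!(opt2, p, [[1.0, 1.0]])  # the copy-paste mistake: wrong collection
catch err
    println(err isa ArgumentError)          # true
end
```

The comparison is on the identity of the collection, not its contents: `p2` holds the very same array as `p` and is still rejected, which is the kind of mismatch the proposal aims to surface.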

Thus, it wouldn't affect any current code, but in the future it would let people start writing safer and more intuitive code, like:

```julia
p = params(model)
opt = Adam(1e-3, p)

for i=1:1000
    ...
    update!(opt, p, grad)
end
```

And, if I forget to change the name of p when copying the loop:

```julia
p2 = params(model)
opt2 = Adam(1e-3, p2)

for i=1:1000
    ...
    update!(opt2, p, grad)  # Throws an error!
end
```

In the future you could also have the constructor initialize the state of the optimizer, rather than just recording the objectid. I think just recording the objectid would be a good start, though.
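As a rough illustration of that further step, a constructor that receives the parameters could allocate the optimiser state immediately. In this sketch the safety check is per array (is the array already in the state dictionary?) rather than a comparison of the collection's `objectid`; that is a design choice of the sketch, not part of the proposal. `EagerMomentum` and `eager_update!` are hypothetical names, and plain arrays again stand in for Flux's `Params`.

```julia
# Hypothetical momentum optimiser whose state is allocated in the constructor,
# so it knows its parameters from the moment it is created.
struct EagerMomentum
    eta::Float64
    rho::Float64
    velocity::IdDict{Any,Any}   # filled in the constructor, keyed by array identity
end

function EagerMomentum(parameters; eta = 0.01, rho = 0.9)
    velocity = IdDict{Any,Any}()
    for x in parameters
        velocity[x] = zero(x)   # allocate state up front instead of inside the update
    end
    return EagerMomentum(eta, rho, velocity)
end

function eager_update!(opt::EagerMomentum, parameters, grads)
    # Validate everything first so a bad call leaves all arrays untouched.
    for x in parameters
        haskey(opt.velocity, x) || throw(ArgumentError("array was not registered with this optimiser"))
    end
    for (x, g) in zip(parameters, grads)
        v = opt.velocity[x]
        @. v = opt.rho * v - opt.eta * g
        x .+= v
    end
    return parameters
end

w = [1.0, 2.0]
opt = EagerMomentum([w])
println(length(opt.velocity))             # 1: state exists before any update
eager_update!(opt, [w], [[1.0, 1.0]])     # w is updated in place
try
    eager_update!(opt, [[3.0]], [[1.0]])  # an array this optimiser never saw
catch err
    println(err isa ArgumentError)        # true
end
```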

Thoughts on this?

Cheers,
Miles
