Description
I was wondering if you would be open to an API improvement that would be completely optional and minimal. The goal would be to reduce the potential for bugs and make code more intuitive.
The following code is how one currently initializes an optimizer:
p = params(model)
opt = Adam(1e-3)
for i = 1:1000
  ...
  update!(opt, p, grad)
end
It is not until the `update!` step that the optimizer actually initializes its state and reads the parameters.
I think there are a couple of potential issues with this:
- `params` and `Adam` are initialized quite far apart from each other, even though they are closely related objects. This distance makes it easier for me to change the variable name for `p` but not for `opt`. I could be running `update!` with an old optimizer (with its old state) and not even realize it.
- Knowing that `Adam` actually records information about the parameters, it seems unintuitive for it not to know about them when the object is created. A beginner might think that just by initializing `Adam` after `params`, some global method would connect them to each other, when the actual loading happens at the `update!` step.
I am wondering what you think about the following minimal (and optional) change to remedy this. I think it could be good for each optimizer to record the `objectid` of the parameter object at initialization, and throw an error if the user attempts to update with a different set of parameters.
For example:
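mutable struct Descent <: AbstractOptimiser
  eta::Float64
  # objectid returns a UInt; `nothing` means no parameters were given
  param_id::Union{UInt,Nothing}
end
# Typed outer constructors, so the default two-argument constructor is not overwritten
Descent(eta::Real = 0.1) = Descent(Float64(eta), nothing)
Descent(eta::Real, parameters) = Descent(Float64(eta), objectid(parameters))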
Then, in the `update!` code, you would check whether `param_id` is `nothing`, and if it is not, you would verify that `objectid(parameters)` indeed matches what is stored.
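As a rough sketch of what that check could look like: the `check_params` helper and the simplified `update!` below are hypothetical stand-ins written for plain arrays, not Flux's actual implementation, and `AbstractOptimiser` is redefined locally only so the snippet runs on its own.

```julia
# Local stand-in so the sketch runs without Flux loaded.
abstract type AbstractOptimiser end

mutable struct Descent <: AbstractOptimiser
    eta::Float64
    param_id::Union{UInt,Nothing}
end
Descent(eta::Real = 0.1) = Descent(Float64(eta), nothing)
Descent(eta::Real, parameters) = Descent(Float64(eta), objectid(parameters))

# Hypothetical helper: throws if the optimizer was bound to a different parameter object.
# An optimizer created without parameters (param_id === nothing) skips the check.
function check_params(opt, parameters)
    id = opt.param_id
    if id !== nothing && id != objectid(parameters)
        throw(ArgumentError("optimizer was created for a different parameter object"))
    end
    return nothing
end

# Simplified gradient-descent step over a collection of arrays (assumption: not Flux's update!).
function update!(opt::Descent, parameters, grads)
    check_params(opt, parameters)
    for (x, g) in zip(parameters, grads)
        x .-= opt.eta .* g
    end
    return parameters
end
```

Because the check only runs when `param_id` is set, an optimizer built with `Descent()` or `Descent(eta)` would behave exactly as it does today.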
Thus, it wouldn't affect any current code, but in the future it would let people start writing safer and more intuitive code, like:
p = params(model)
opt = Adam(1e-3, p)
for i = 1:1000
  ...
  update!(opt, p, grad)
end
And, if I forget to change the name of `p` when copying the loop:
p2 = params(model)
opt2 = Adam(1e-3, p2)
for i = 1:1000
  ...
  update!(opt2, p, grad) # Throws an error!
end
In the future you could also have the constructor initialize the state of the optimizer, rather than just recording the `objectid`. I think just having this is a good start, though.
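To illustrate that longer-term idea, here is a minimal sketch of an Adam-style optimizer that allocates its moment buffers in the constructor. `EagerAdam` and this `update!` are hypothetical names used only for illustration; they assume the parameters are a `Vector{Vector{Float64}}` and are not how Flux's `Adam` is implemented.

```julia
# Hypothetical optimizer that builds its state when it is constructed.
mutable struct EagerAdam
    eta::Float64
    beta::Tuple{Float64,Float64}
    param_id::UInt
    m::Vector{Vector{Float64}}   # first-moment estimates, one per parameter array
    v::Vector{Vector{Float64}}   # second-moment estimates, one per parameter array
    t::Int                       # number of steps taken so far
end

function EagerAdam(eta::Real, parameters::Vector{Vector{Float64}}; beta = (0.9, 0.999))
    m = [zero(x) for x in parameters]
    v = [zero(x) for x in parameters]
    return EagerAdam(Float64(eta), beta, objectid(parameters), m, v, 0)
end

function update!(opt::EagerAdam, parameters, grads; eps = 1e-8)
    objectid(parameters) == opt.param_id ||
        throw(ArgumentError("optimizer was created for a different parameter object"))
    opt.t += 1
    b1, b2 = opt.beta
    eta = opt.eta
    c1 = 1 - b1^opt.t   # bias correction for the first moment
    c2 = 1 - b2^opt.t   # bias correction for the second moment
    for (x, g, m, v) in zip(parameters, grads, opt.m, opt.v)
        @. m = b1 * m + (1 - b1) * g
        @. v = b2 * v + (1 - b2) * g^2
        @. x -= eta * (m / c1) / (sqrt(v / c2) + eps)
    end
    return parameters
end
```

With this shape the state exists before the first `update!` call, so there is no lazy initialization step in which an optimizer could silently pick up a different parameter object.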
Thoughts on this?
Cheers,
Miles