From 41a0dbbf181f6e4dbe22b61482b724b5c56dae67 Mon Sep 17 00:00:00 2001 From: Jonathan Doucette Date: Tue, 25 Oct 2022 15:23:55 -0700 Subject: [PATCH 1/3] mark `OptimiserChain` with `@functor`, and improve type inference for `apply!(o::OptimiserChain, ...)` --- src/rules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index c024feb2..3ca7245e 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -618,15 +618,15 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule end OptimiserChain(opts...) = OptimiserChain(opts) -init(o::OptimiserChain, x::AbstractArray) = [init(opt, x) for opt in o.opts] +@functor OptimiserChain + +init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts) function apply!(o::OptimiserChain, states, x, dx, dxs...) - new_states = similar(states) - for (i, (opt, state)) in enumerate(zip(o.opts, states)) - new_states[i], dx = apply!(opt, state, x, dx, dxs...) + foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state) + state′, dx′ = apply!(opt, state, x, dx′, dxs...) + return (states′..., state′), dx′ end - - return new_states, dx end function Base.show(io::IO, c::OptimiserChain) From 1735625db97c7119899cc6149280e178f3d0560a Mon Sep 17 00:00:00 2001 From: Jonathan Doucette Date: Tue, 25 Oct 2022 16:08:51 -0700 Subject: [PATCH 2/3] fix doc tests --- src/adjust.jl | 6 +++--- src/rules.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/adjust.jl b/src/adjust.jl index 78b3d452..cf000f74 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -32,13 +32,13 @@ julia> fieldnames(Adam) (:eta, :beta, :epsilon) julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m) -(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing) +(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ()) julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad -(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing) +(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ()) julia> Optimisers.adjust(st; beta = "no such field") # silently ignored! -(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ()) ``` """ adjust(tree, eta::Real) = map(st -> adjust(st, eta), tree) diff --git a/src/rules.jl b/src/rules.jl index 3ca7245e..0b366faa 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -607,7 +607,7 @@ julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1)); julia> m = (zeros(3),); julia> s = Optimisers.setup(o, m) -(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), [nothing, nothing]),) +(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), (nothing, nothing)),) julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting ([-0.03, -0.1, -0.1],) From 9ac27ffe801b80749244e01fe679d22663829503 Mon Sep 17 00:00:00 2001 From: Jonathan Doucette Date: Tue, 25 Oct 2022 19:37:55 -0700 Subject: [PATCH 3/3] fix more doctests --- src/adjust.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/adjust.jl b/src/adjust.jl index cf000f74..5ff6b7f9 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -9,7 +9,7 @@ through training. To change just the learning rate, provide a number `η::Real`. # Example -```jldoctest +```jldoctest adjust julia> m = (vec = rand(Float32, 2), fun = sin); julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero @@ -27,7 +27,7 @@ julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentu To change other parameters, `adjust` also accepts keyword arguments matching the field names of the optimisation rule's type. -``` +```jldoctest adjust julia> fieldnames(Adam) (:eta, :beta, :epsilon) @@ -38,7 +38,7 @@ julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts (vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ()) julia> Optimisers.adjust(st; beta = "no such field") # silently ignored! -(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ()) +(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ()) ``` """ adjust(tree, eta::Real) = map(st -> adjust(st, eta), tree)