diff --git a/src/adjust.jl b/src/adjust.jl
index 78b3d452..5ff6b7f9 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -9,7 +9,7 @@ through training.
 To change just the learning rate, provide a number `η::Real`.
 
 # Example
-```jldoctest
+```jldoctest adjust
 julia> m = (vec = rand(Float32, 2), fun = sin);
 
 julia> st = Optimisers.setup(Nesterov(), m)  # stored momentum is initialised to zero
@@ -27,18 +27,18 @@ julia> st = Optimisers.adjust(st, 0.123)  # change learning rate, stored momentu
 To change other parameters, `adjust` also accepts keyword arguments matching the field
 names of the optimisation rule's type.
 
-```
+```jldoctest adjust
 julia> fieldnames(Adam)
 (:eta, :beta, :epsilon)
 
 julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
-(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
+(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
 
 julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1)  # delta acts on ClipGrad
-(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
+(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
 
 julia> Optimisers.adjust(st; beta = "no such field")  # silently ignored!
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
 ```
 """
 adjust(tree, eta::Real) = map(st -> adjust(st, eta), tree)
diff --git a/src/rules.jl b/src/rules.jl
index c024feb2..0b366faa 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -607,7 +607,7 @@ julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
 julia> m = (zeros(3),);
 
 julia> s = Optimisers.setup(o, m)
-(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), [nothing, nothing]),)
+(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), (nothing, nothing)),)
 
 julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2]  # clips before discounting
 ([-0.03, -0.1, -0.1],)
@@ -618,15 +618,15 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
 end
 OptimiserChain(opts...) = OptimiserChain(opts)
 
-init(o::OptimiserChain, x::AbstractArray) = [init(opt, x) for opt in o.opts]
+@functor OptimiserChain
+
+init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
 
 function apply!(o::OptimiserChain, states, x, dx, dxs...)
-  new_states = similar(states)
-  for (i, (opt, state)) in enumerate(zip(o.opts, states))
-    new_states[i], dx = apply!(opt, state, x, dx, dxs...)
+  foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state)
+    state′, dx′ = apply!(opt, state, x, dx′, dxs...)
+    return (states′..., state′), dx′
   end
-
-  return new_states, dx
 end
 
 function Base.show(io::IO, c::OptimiserChain)
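
For reference, a minimal standalone sketch of the `foldl` pattern adopted in the new `apply!` above. The names `clip`, `descent`, `rules`, and `states` here are hypothetical stand-ins, not the Optimisers.jl rules or API: each stand-in returns its (unchanged) state and a transformed gradient, the states accumulate into a tuple, and the gradient threads through the chain, mirroring the tuple-valued chain state introduced by this diff.

```julia
# Hypothetical stand-ins for two chained rules: each takes (state, gradient)
# and returns (new_state, new_gradient).
clip(state, dx; delta = 1.0) = state, clamp.(dx, -delta, delta)
descent(state, dx; eta = 0.1) = state, eta .* dx

rules  = (clip, descent)
states = (nothing, nothing)   # stateless rules, like ClipGrad and Descent
dx     = [0.3, 1.0, 7.0]

# Fold over (rule, state) pairs: collect each rule's new state into a growing
# tuple while passing the gradient from one rule to the next.
newstates, dx′ = foldl(tuple.(rules, states); init = ((), dx)) do (acc, g), (rule, st)
    st′, g′ = rule(st, g)
    (acc..., st′), g′
end

# newstates == (nothing, nothing);  dx′ ≈ [0.03, 0.1, 0.1]
# (clipped to ±1.0 first, then scaled by 0.1: "clips before discounting")
```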