diff --git a/NEWS.md b/NEWS.md
index 9bf97ddb3b..5140435114 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,11 +1,22 @@
 # Flux Release Notes

+## v0.14
+
+* The use of Zygote's implicit parameters (with `Flux.params` and global variables) is deprecated in favour of the explicit style.
+  The function `train!` has new methods (accepting the model itself) to handle this.
+
+* Sub-module `Flux.Optimise` has been removed, in favour of using [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) more deeply.
+  The function `train!` now lives in sub-module `Flux.Train`, and has re-written internals.
+
+* One-hot arrays have moved to a new package, [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl).
+
 ## v0.13.4

 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)

-## v0.13
+## v0.13 (April 2022)
+
 * After a deprecations cycle, the datasets in `Flux.Data` have
-been removed in favour of MLDatasets.jl.
+  been removed in favour of [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl).
 * `params` is not exported anymore since it is a common name and is also exported by Distributions.jl
 * `flatten` is not exported anymore due to clash with Iterators.flatten.
 * Remove Juno.jl progress bar support as it is now obsolete.
@@ -48,7 +59,7 @@ been removed in favour of MLDatasets.jl.
 * CUDA.jl 3.0 support
 * Bug fixes and optimizations.

-## v0.12.0
+## v0.12.0 (March 2021)

 * Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).
 * Add [Orthogonal Matrix initialization](https://github.com/FluxML/Flux.jl/pull/1496) as described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural networks](https://arxiv.org/abs/1312.6120).
@@ -73,7 +84,7 @@ been removed in favour of MLDatasets.jl.
 * Adds the [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser.
 * Other new features and bug fixes (see GitHub releases page)

-## v0.11
+## v0.11 (July 2020)

 * Moved CUDA compatibility to use [CUDA.jl instead of CuArrays.jl](https://github.com/FluxML/Flux.jl/pull/1204)
 * Add [kaiming initialization](https://arxiv.org/abs/1502.01852) methods: [kaiming_uniform and kaiming_normal](https://github.com/FluxML/Flux.jl/pull/1243)
@@ -101,7 +112,7 @@ keyword argument. The `Dropout` struct (whose behavior is left unchanged) is the

 See GitHub's releases.

-## v0.10.0
+## v0.10.0 (November 2019)

 * The default AD engine has switched from [Tracker to Zygote.jl](https://github.com/FluxML/Flux.jl/pull/669)
   - The dependency on Tracker.jl has been removed.
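To make the NEWS entry concrete before diving into the code: below is a minimal sketch of the two training styles it refers to. The `model`, `x`, `y` and `loss` names here are placeholders, not from this PR; the explicit-style `train!` methods themselves are added under `src/train/` further down.

```julia
using Flux, Zygote

model = Dense(2 => 1)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# Implicit style (deprecated in v0.14): gradients are keyed by a Params set,
# and the model enters only as a global captured by a zero-argument closure.
ps = Flux.params(model)
gs = Zygote.gradient(() -> loss(model, x, y), ps)   # returns a Zygote.Grads

# Explicit style (preferred): differentiate with respect to the model itself,
# giving a NamedTuple whose structure mirrors the model's fields.
g = Zygote.gradient(m -> loss(m, x, y), model)[1]
```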
diff --git a/Project.toml b/Project.toml
index 07a7098b01..86ad76dd52 100644
--- a/Project.toml
+++ b/Project.toml
@@ -39,6 +39,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
+Yota = "0.7.4"
 Zygote = "0.6.34"
 julia = "1.6"

@@ -49,6 +50,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Yota = "cd998857-8626-517d-b929-70ad188a48f0"

 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Yota"]
diff --git a/src/Flux.jl b/src/Flux.jl
index 0cacbd419a..c8ae8153fb 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,7 @@ import Optimisers: Optimisers, trainable, destructure  # before v0.13, Flux owne
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback, @nograd

-export gradient
+# export gradient  # stop exporting this, to make people say "using Zygote", and to make it easier to replace

 # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
 Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")
@@ -25,14 +25,15 @@ export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
   fmap, cpu, gpu, f32, f64,
   testmode!, trainmode!

-include("optimise/Optimise.jl")
-using .Optimise
-using .Optimise: @epochs
-using .Optimise: skip
-export Descent, Adam, Momentum, Nesterov, RMSProp,
-  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam,
-  AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
-  WeightDecay, ClipValue, ClipNorm
+include("train/Train.jl")
+using .Train
+export train!
+# Stop exporting these, since Optimisers.jl exports the same names,
+# and with this PR, Flux.Adam() is literally a wrapper around Adam().
+# export Descent, Adam, Momentum, Nesterov, RMSProp,
+#   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam,
+#   AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
+#   WeightDecay, ClipValue, ClipNorm

 using CUDA
 const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 6719bd39e2..6cb73d2cf2 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -34,11 +34,6 @@ struct Zeros end
 Zeros(args...) = Zeros()  # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())

-function Optimise.update!(x::AbstractArray, x̄)
-  Base.depwarn("`Flux.Optimise.update!(x, x̄)` was not used internally and has been removed. Please write `x .-= x̄` instead.", :update!)
-  x .-= x̄
-end
-
 function Diagonal(size::Integer...; kw...)
   Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
   Scale(size...; kw...)
 end
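One consequence of the `export gradient` change above is worth spelling out: the function itself is still imported into Flux (via `using Zygote: ... gradient ...`), so only the bare exported name disappears. A small sketch of the two ways user code can adapt:

```julia
using Flux
import Zygote

f(x) = sum(abs2, x)

# Either take `gradient` from Zygote directly, as the new comment suggests...
g1 = Zygote.gradient(f, [1.0, 2.0])[1]

# ...or call the unexported but still-defined name, qualified through Flux.
g2 = Flux.gradient(f, [1.0, 2.0])[1]
```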
@@ -80,3 +75,6 @@ Base.@deprecate_binding RADAM RAdam
 Base.@deprecate_binding OADAM OAdam
 Base.@deprecate_binding ADAGrad AdaGrad
 Base.@deprecate_binding ADADelta AdaDelta
+
+# What remains from the Optimise sub-module has moved to Train:
+Base.@deprecate_binding Optimise Train
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
deleted file mode 100644
index e691ce0170..0000000000
--- a/src/optimise/Optimise.jl
+++ /dev/null
@@ -1,15 +0,0 @@
-module Optimise
-
-using LinearAlgebra
-import ArrayInterface
-
-export train!, update!,
-  Descent, Adam, Momentum, Nesterov, RMSProp,
-  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW,RAdam, OAdam, AdaBelief,
-  InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
-  ClipValue, ClipNorm
-
-include("optimisers.jl")
-include("train.jl")
-
-end
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
deleted file mode 100644
index ce72a4b0ce..0000000000
--- a/src/optimise/optimisers.jl
+++ /dev/null
@@ -1,724 +0,0 @@
-using Flux
-using MacroTools: @forward
-
-abstract type AbstractOptimiser end
-
-const EPS = 1e-8
-
-# TODO: should use weak refs
-
-"""
-    Descent(η = 0.1)
-
-Classic gradient descent optimiser with learning rate `η`.
-For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-
-# Examples
-```julia
-opt = Descent()
-
-opt = Descent(0.3)
-
-ps = Flux.params(model)
-
-gs = gradient(ps) do
-  loss(x, y)
-end
-
-Flux.Optimise.update!(opt, ps, gs)
-```
-"""
-mutable struct Descent <: AbstractOptimiser
-  eta::Float64
-end
-
-Descent() = Descent(0.1)
-
-function apply!(o::Descent, x, Δ)
-  Δ .*= o.eta
-end
-
-"""
-    Momentum(η = 0.01, ρ = 0.9)
-
-Gradient descent optimizer with learning rate `η` and momentum `ρ`.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
-  prominent direction, in effect damping oscillations.
-
-# Examples
-```julia
-opt = Momentum()
-
-opt = Momentum(0.01, 0.99)
-```
-"""
-mutable struct Momentum <: AbstractOptimiser
-  eta::Float64
-  rho::Float64
-  velocity::IdDict
-end
-
-Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
-
-function apply!(o::Momentum, x, Δ)
-  η, ρ = o.eta, o.rho
-  v = get!(() -> zero(x), o.velocity, x)::typeof(x)
-  @. v = ρ * v - η * Δ
-  @. Δ = -v
-end
-
-"""
-    Nesterov(η = 0.001, ρ = 0.9)
-
-Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the
-  prominent direction, in effect damping oscillations.
-
-# Examples
-```julia
-opt = Nesterov()
-
-opt = Nesterov(0.003, 0.95)
-```
-"""
-mutable struct Nesterov <: AbstractOptimiser
-  eta::Float64
-  rho::Float64
-  velocity::IdDict
-end
-
-Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
-
-function apply!(o::Nesterov, x, Δ)
-  η, ρ = o.eta, o.rho
-  v = get!(() -> zero(x), o.velocity, x)::typeof(x)
-  d = @. ρ^2 * v - (1+ρ) * η * Δ
-  @. v = ρ*v - η*Δ
-  @. Δ = -d
-end
-
-"""
-    RMSProp(η = 0.001, ρ = 0.9, ϵ = $EPS)
-
-Optimizer using the
-[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
-generally don't need tuning.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
-  prominent direction, in effect damping oscillations.
-
-# Examples
-```julia
-opt = RMSProp()
-
-opt = RMSProp(0.002, 0.95)
-```
-"""
-mutable struct RMSProp <: AbstractOptimiser
-  eta::Float64
-  rho::Float64
-  epsilon::Float64
-  acc::IdDict
-end
-RMSProp(η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = EPS) = RMSProp(η, ρ, ϵ, IdDict())
-RMSProp(η::Real, ρ::Real, acc::IdDict) = RMSProp(η, ρ, EPS, acc)
-
-function apply!(o::RMSProp, x, Δ)
-  η, ρ = o.eta, o.rho
-  acc = get!(() -> zero(x), o.acc, x)::typeof(x)
-  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
-  @. Δ *= η / (√acc + o.epsilon)
-end
-
-"""
-    Adam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
-
-[Adam](https://arxiv.org/abs/1412.6980) optimiser.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = Adam()
-
-opt = Adam(0.001, (0.9, 0.8))
-```
-"""
-mutable struct Adam <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64,Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-Adam(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = Adam(η, β, ϵ, IdDict())
-Adam(η::Real, β::Tuple, state::IdDict) = Adam(η, β, EPS, state)
-
-function apply!(o::Adam, x, Δ)
-  η, β = o.eta, o.beta
-
-  mt, vt, βp = get!(o.state, x) do
-    (zero(x), zero(x), Float64[β[1], β[2]])
-  end :: Tuple{typeof(x),typeof(x),Vector{Float64}}
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
-  @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η
-  βp .= βp .* β
-
-  return Δ
-end
-
-"""
-    RAdam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
-
-[Rectified Adam](https://arxiv.org/abs/1908.03265) optimizer.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = RAdam()
-
-opt = RAdam(0.001, (0.9, 0.8))
-```
-"""
-mutable struct RAdam <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64,Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-RAdam(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = RAdam(η, β, ϵ, IdDict())
-RAdam(η::Real, β::Tuple, state::IdDict) = RAdam(η, β, EPS, state)
-
-function apply!(o::RAdam, x, Δ)
-  η, β = o.eta, o.beta
-  ρ∞ = 2/(1-β[2])-1
-
-  mt, vt, βp, t = get!(o.state, x) do
-    (zero(x), zero(x), Float64[β[1], β[2]], Ref(1))
-  end :: Tuple{typeof(x),typeof(x),Vector{Float64},Base.RefValue{Int}}
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
-  ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2])
-  if ρ > 4
-    r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
-    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η * r
-  else
-    @. Δ = mt / (1 - βp[1]) * η
-  end
-  βp .= βp .* β
-  t[] += 1
-
-  return Δ
-end
-
-"""
-    AdaMax(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
-
-[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of Adam based on the ∞-norm.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = AdaMax()
-
-opt = AdaMax(0.001, (0.9, 0.995))
-```
-"""
-mutable struct AdaMax <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64,Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-AdaMax(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = AdaMax(η, β, ϵ, IdDict())
-AdaMax(η::Real, β::Tuple, state::IdDict) = AdaMax(η, β, EPS, state)
-
-function apply!(o::AdaMax, x, Δ)
-  η, β = o.eta, o.beta
-
-  mt, ut, βp = get!(o.state, x) do
-    (zero(x), zero(x), Float64[β[1], β[2]])
-  end :: Tuple{typeof(x),typeof(x),Vector{Float64}}
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. ut = max(β[2] * ut, abs(Δ))
-  @. Δ = (η/(1 - βp[1])) * mt/(ut + o.epsilon)
-  βp .= βp .* β
-
-  return Δ
-end
-
-"""
-    OAdam(η = 0.0001, β::Tuple = (0.5, 0.9), ϵ = $EPS)
-
-[OAdam](https://arxiv.org/abs/1711.00141) (Optimistic Adam)
-is a variant of Adam adding an "optimistic" term suitable for adversarial training.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = OAdam()
-
-opt = OAdam(0.001, (0.9, 0.995))
-```
-"""
-mutable struct OAdam <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64,Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-OAdam(η::Real = 0.001, β::Tuple = (0.5, 0.9), ϵ::Real = EPS) = OAdam(η, β, ϵ, IdDict())
-OAdam(η::Real, β::Tuple, state::IdDict) = RMSProp(η, β, EPS, state)
-
-function apply!(o::OAdam, x, Δ)
-  η, β = o.eta, o.beta
-
-  mt, vt, Δ_, βp = get!(o.state, x) do
-    (zero(x), zero(x), zero(x), Float64[β[1], β[2]])
-  end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}}
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
-  @. Δ = -Δ_
-  @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon)
-  @. Δ += 2Δ_
-  βp .= βp .* β
-
-  return Δ
-end
-
-"""
-    AdaGrad(η = 0.1, ϵ = $EPS)
-
-[AdaGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
-parameter specific learning rates based on how frequently it is updated.
-Parameters don't need tuning.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-
-# Examples
-```julia
-opt = AdaGrad()
-
-opt = AdaGrad(0.001)
-```
-"""
-mutable struct AdaGrad <: AbstractOptimiser
-  eta::Float64
-  epsilon::Float64
-  acc::IdDict
-end
-AdaGrad(η::Real = 0.1, ϵ::Real = EPS) = AdaGrad(η, ϵ, IdDict())
-AdaGrad(η::Real, state::IdDict) = AdaGrad(η, EPS, state)
-
-function apply!(o::AdaGrad, x, Δ)
-  η = o.eta
-  acc = get!(() -> fill!(similar(x), o.epsilon), o.acc, x)::typeof(x)
-  @. acc += Δ * conj(Δ)
-  @. Δ *= η / (√acc + o.epsilon)
-end
-
-"""
-    AdaDelta(ρ = 0.9, ϵ = $EPS)
-
-[AdaDelta](https://arxiv.org/abs/1212.5701) is a version of AdaGrad adapting its learning
-rate based on a window of past gradient updates.
-Parameters don't need tuning.
-
-# Parameters
-- Rho (`ρ`): Factor by which the gradient is decayed at each time step.
-
-# Examples
-```julia
-opt = AdaDelta()
-
-opt = AdaDelta(0.89)
-```
-"""
-mutable struct AdaDelta <: AbstractOptimiser
-  rho::Float64
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-AdaDelta(ρ::Real = 0.9, ϵ::Real = EPS) = AdaDelta(ρ, ϵ, IdDict())
-AdaDelta(ρ::Real, state::IdDict) = AdaDelta(ρ, EPS, state)
-
-function apply!(o::AdaDelta, x, Δ)
-  ρ = o.rho
-  acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
-  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
-  # DON'T remove epsilon from numerator
-  # or even out of the square roots
-  @. Δ *= √(Δacc + o.epsilon) / √(acc + o.epsilon)
-  @. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ)
-  return Δ
-end
-
-"""
-    AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
-
-The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the Adam
-optimiser. Parameters don't need tuning.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = AMSGrad()
-
-opt = AMSGrad(0.001, (0.89, 0.995))
-```
-"""
-mutable struct AMSGrad <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64, Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-AMSGrad(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = AMSGrad(η, β, ϵ, IdDict())
-AMSGrad(η::Real, β::Tuple, state::IdDict) = AMSGrad(η, β, EPS, state)
-
-function apply!(o::AMSGrad, x, Δ)
-  η, β = o.eta, o.beta
-
-  mt, vt, v̂t = get!(o.state, x) do
-    (fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))
-  end :: NTuple{3,typeof(x)}
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
-  @. v̂t = max(v̂t, vt)
-  @. Δ = η * mt / (√v̂t + o.epsilon)
-end
-
-"""
-    NAdam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
-
-[NAdam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of Adam.
-Parameters don't need tuning.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = NAdam()
-
-opt = NAdam(0.002, (0.89, 0.995))
-```
-"""
-mutable struct NAdam <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64, Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-NAdam(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = NAdam(η, β, ϵ, IdDict())
-NAdam(η::Real, β::Tuple, state::IdDict) = NAdam(η, β, EPS, state)
-
-function apply!(o::NAdam, x, Δ)
-  η, β = o.eta, o.beta
-
-  mt, vt, βp = get!(o.state, x) do
-    (zero(x), zero(x), Float64[o.beta[1], o.beta[2]])
-  end :: Tuple{typeof(x),typeof(x),Vector{Float64}}
-  β1p, β2p = βp
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
-  @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + o.epsilon) * η
-  βp .= βp .* β
-
-  return Δ
-end
-
-"""
-    AdamW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
-
-[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
-weight decay regularization.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-- `decay`: Decay applied to weights during optimisation.
-
-# Examples
-```julia
-opt = AdamW()
-
-opt = AdamW(0.001, (0.89, 0.995), 0.1)
-```
-"""
-AdamW(η = 0.001, β = (0.9, 0.999), decay = 0) =
-  Optimiser(Adam(η, β), WeightDecay(decay))
-
-"""
-    AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
-
-The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
-Adam optimiser.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
-  second (β2) momentum estimate.
-
-# Examples
-```julia
-opt = AdaBelief()
-
-opt = AdaBelief(0.001, (0.9, 0.8))
-```
-"""
-mutable struct AdaBelief <: AbstractOptimiser
-  eta::Float64
-  beta::Tuple{Float64,Float64}
-  epsilon::Float64
-  state::IdDict{Any, Any}
-end
-AdaBelief(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = AdaBelief(η, β, ϵ, IdDict())
-AdaBelief(η::Real, β::Tuple, state::IdDict) = AdaBelief(η, β, EPS, state)
-
-function apply!(o::AdaBelief, x, Δ)
-  η, β = o.eta, o.beta
-
-  mt, st, βp = get!(o.state, x) do
-    (zero(x), zero(x), Float64[β[1], β[2]])
-  end :: Tuple{typeof(x), typeof(x), Vector{Float64}}
-
-  #= st is a variance and can go to zero. This is in contrast to Adam, which uses the
-  second moment which is usually far enough from zero. This is problematic, since st
-  can be slightly negative due to numerical error, and the square root below will fail.
-  Also, if we want to differentiate through the optimizer, √0 is not differentiable.
-  To protect against this, we add a small number, st -> st + eps2.
-  The original implementation (https://github.com/juntang-zhuang/Adabelief-Optimizer)
-  uses the square of Adam's epsilon, which we do here.
-  See also: https://github.com/juntang-zhuang/Adabelief-Optimizer/issues/61 =#
-  eps2 = o.epsilon^2  # TODO: make epsilon^2 the default in next breaking release
-
-  @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt) + eps2
-  @. Δ = η * mt / (1 - βp[1]) / (√(st / (1 - βp[2])) + eps2)
-  βp .= βp .* β
-
-  return Δ
-end
-
-
-# Compose optimizers
-
-"""
-    Optimiser(a, b, c...)
-
-Combine several optimisers into one; each optimiser produces a modified gradient
-that will be fed into the next, and this is finally applied to the parameter as
-usual.
-"""
-mutable struct Optimiser <: AbstractOptimiser
-  os::Vector{Any}
-end
-
-Optimiser(opts::AbstractOptimiser...) = Optimiser(Any[opts...])
-
-@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
-@forward Optimiser.os Base.iterate
-
-Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
-
-function apply!(o::Optimiser, x, Δ)
-  for opt in o.os
-    Δ = apply!(opt, x, Δ)
-  end
-  return Δ
-end
-
-"""
-    InvDecay(γ = 0.001)
-
-Apply inverse time decay to an optimiser, so that the effective step size at
-iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size.
-The wrapped optimiser's step size is not modified.
-
-See also the [Scheduling Optimisers](@ref) section of the docs
-for more general scheduling techniques.
-
-# Examples
-
-`InvDecay` is typically composed with other optimizers
-as the last transformation of the gradient:
-
-```julia
-# Inverse decay of the learning rate
-# with starting value 0.001 and decay coefficient 0.01.
-opt = Optimiser(Adam(1f-3), InvDecay(1f-2))
-```
-"""
-mutable struct InvDecay <: AbstractOptimiser
-  gamma::Float64
-  state::IdDict{Any, Int}
-end
-
-InvDecay(γ = 0.001) = InvDecay(γ, IdDict{Any, Int}())
-
-function apply!(o::InvDecay, x, Δ)
-  γ = o.gamma
-  n = get!(o.state, x, 1)
-  Δ .*= 1 / (1 + γ * n)
-  o.state[x] = n + 1
-  return Δ
-end
-
-"""
-    ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4, start = 1)
-
-Discount the learning rate `η` by the factor `decay` every `decay_step` steps till
-a minimum of `clip`.
-
-# Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
-  the weights.
-- `decay`: Factor by which the learning rate is discounted.
-- `decay_step`: Schedule decay operations by setting the number of steps between
-  two decay operations.
-- `clip`: Minimum value of learning rate.
-- 'start': Step at which the decay starts.
-
-
-See also the [Scheduling Optimisers](@ref) section of the docs
-for more general scheduling techniques.
-
-# Examples
-
-`ExpDecay` is typically composed with other optimizers
-as the last transformation of the gradient:
-```julia
-opt = Optimiser(Adam(), ExpDecay(1.0))
-```
-Note: you may want to start with `η=1` in `ExpDecay` when combined with other
-optimizers (`Adam` in this case) that have their own learning rate.
-"""
-mutable struct ExpDecay <: AbstractOptimiser
-  eta::Float64
-  decay::Float64
-  step::Int64
-  clip::Float64
-  start::Int64
-  current::IdDict
-end
-
-ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4, start = 0) =
-  ExpDecay(opt, decay, decay_step, clip, start, IdDict())
-
-function apply!(o::ExpDecay, x, Δ)
-  η, s, decay, start = o.eta, o.step, o.decay, o.start
-  n = o.current[x] = get(o.current, x, 0) + 1
-  if n > start && n % s == 0 && count(x -> x > start && x % s == 0, values(o.current)) == 1
-    η = max(η * decay, o.clip)
-    o.eta = η
-  end
-  @. Δ *= η
-end
-
-"""
-    WeightDecay(λ = 0)
-
-Decay weights by ``λ``.
-Typically composed with other optimizers as the first transformation to the gradient,
-making it equivalent to adding ``L_2`` regularization
-with coefficient ``λ`` to the loss.
-
-# Examples
-
-```julia
-opt = Optimiser(WeightDecay(1f-4), Adam())
-```
-"""
-mutable struct WeightDecay <: AbstractOptimiser
-  wd::Real
-end
-
-WeightDecay() = WeightDecay(0)
-
-function apply!(o::WeightDecay, x, Δ)
-  wd = o.wd
-  @. Δ += wd * x
-end
-
-"""
-    ClipValue(thresh)
-
-Clip gradients when their absolute value exceeds `thresh`.
-"""
-mutable struct ClipValue{T} <: AbstractOptimiser
-  thresh::T
-end
-
-apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
-
-"""
-    ClipNorm(thresh)
-
-Clip gradients when their L2 norm exceeds `thresh`.
-"""
-mutable struct ClipNorm{T} <: AbstractOptimiser
-  thresh::T
-end
-
-function apply!(o::ClipNorm, x, Δ)
-  Δnrm = norm(Δ)
-  if Δnrm > o.thresh
-    rmul!(Δ, o.thresh / Δnrm)
-  end
-  return Δ
-end
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
deleted file mode 100644
index b6dac7951b..0000000000
--- a/src/optimise/train.jl
+++ /dev/null
@@ -1,157 +0,0 @@
-using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient
-
-
-"""
-    update!(opt, p, g)
-    update!(opt, ps::Params, gs)
-
-Perform an update step of the parameters `ps` (or the single parameter `p`)
-according to optimizer `opt` and the gradients `gs` (the gradient `g`).
-
-As a result, the parameters are mutated and the optimizer's internal state may change.
-The gradient could be mutated as well.
-"""
-function update!(opt::AbstractOptimiser, x, x̄)
-  x̄r = ArrayInterface.restructure(x, x̄)  # address some cases where Zygote's
-                                          # output are not mutable, see #1510
-  x .-= apply!(opt, x, x̄r)
-end
-
-function update!(opt::AbstractOptimiser, xs::Params, gs)
-  for x in xs
-    isnothing(gs[x]) && continue
-    update!(opt, x, gs[x])
-  end
-end
-
-# Callback niceties
-call(f, xs...) = f(xs...)
-runall(f) = f
-runall(fs::AbstractVector) = () -> foreach(call, fs)
-
-struct SkipException <: Exception end
-
-"""
-    skip()
-
-Call `Flux.skip()` in a callback to indicate when a callback condition is met.
-This will trigger the train loop to skip the current data point and not update with the calculated gradient.
-
-# Examples
-```julia
-cb = function ()
-  loss() > 1e7 && Flux.skip()
-end
-```
-"""
-function skip()
-  throw(SkipException())
-end
-
-
-struct StopException <: Exception end
-
-"""
-    stop()
-
-Call `Flux.stop()` in a callback to indicate when a callback condition is met.
-This will trigger the train loop to stop and exit.
-
-# Examples
-```julia
-cb = function ()
-  accuracy() > 0.9 && Flux.stop()
-end
-```
-"""
-function stop()
-  throw(StopException())
-end
-
-batchmemaybe(x) = tuple(x)
-batchmemaybe(x::Tuple) = x
-
-"""
-    train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
-
-Uses a `loss` function and training `data` to improve the
-model's parameters according to a particular optimisation rule `opt`.
-
-For each `d in data`, first the gradient of the `loss` is computed like this:
-```
-    gradient(() -> loss(d...), pars)  # if d isa Tuple
-    gradient(() -> loss(d), pars)     # otherwise
-```
-Here `pars` is produced by calling [`Flux.params`](@ref) on your model.
-(Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.)
-This is the "implicit" style of parameter handling.
-
-Then, this gradient is used by optimizer `opt` to update the paramters:
-```
-    update!(opt, pars, grads)
-```
-The optimiser should be from the `Flux.Optimise` module (see [Optimisers](@ref)).
-Different optimisers can be combined using [`Flux.Optimise.Optimiser`](@ref Flux.Optimiser).
-
-This training loop iterates through `data` once.
-You can use [`@epochs`](@ref) to do this several times, or
-use for instance `Iterators.repeat` to make a longer `data` iterator.
-
-## Callbacks
-
-[Callbacks](@ref) are given with the keyword argument `cb`.
-For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)):
-```
-    train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
-```
-
-The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
-
-Multiple callbacks can be passed to `cb` as array.
-"""
-function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
-  cb = runall(cb)
-  itrsz = Base.IteratorSize(typeof(data))
-  n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
-  @withprogress for (i, d) in enumerate(data)
-    try
-      gs = gradient(ps) do
-        loss(batchmemaybe(d)...)
-      end
-      update!(opt, ps, gs)
-      cb()
-    catch ex
-      if ex isa StopException
-        break
-      elseif ex isa SkipException
-        continue
-      else
-        rethrow(ex)
-      end
-    end
-    @logprogress iszero(n) ? nothing : i / n
-  end
-end
-
-"""
-    @epochs N body
-
-Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
-training in a REPL.
-
-# Examples
-```jldoctest
-julia> Flux.@epochs 2 println("hello")
-[ Info: Epoch 1
-hello
-[ Info: Epoch 2
-hello
-```
-"""
-macro epochs(n, ex)
-  :(@progress for i = 1:$(esc(n))
-      @info "Epoch $i"
-      $(esc(ex))
-    end)
-end
diff --git a/src/train/Train.jl b/src/train/Train.jl
new file mode 100644
index 0000000000..44a210df8b
--- /dev/null
+++ b/src/train/Train.jl
@@ -0,0 +1,180 @@
+module Train
+
+using LinearAlgebra
+using Optimisers: Optimisers
+using Functors: fmap
+
+export train!, update!, adjust!, FluxState, @train_autodiff,
+  Descent, Adam, Momentum, Nesterov, RMSProp,
+  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
+  # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
+  # ClipValue, ClipNorm
+
+
+### Mutable state storage, to wrap Optimisers.jl
+
+"""
+    FluxState(rule, state=missing)
+
+This is an interface between the all-mutable world Flux.jl likes,
+and the could-be-immutable world that Optimisers.jl inhabits.
+
+`state` can be either the whole state tree which Optimisers.jl builds,
+or else (for Zygote's implicit mode) an IdDict of such states.
+Once initialised, it cannot change between these two modes.
+"""
+mutable struct FluxState{T<:Optimisers.AbstractRule}
+  rule::T
+  state::Any
+end
+
+function Base.show(io::IO, opt::FluxState)
+  print(io, "FluxState(")
+  show(io, opt.rule)
+  if opt.state isa Missing
+    print(io, ", <uninitialised>)")
+  elseif opt.state isa IdDict
+    n = length(keys(opt.state))
+    print(io, ", <implicit IdDict over ", n, " arrays>)")
+  else
+    rn = Ref(0)
+    fmap(x -> (rn[] += 1; x), opt.state, exclude = (x -> x isa Optimisers.Leaf))
+    print(io, ", <explicit state tree with ", rn[], " leaves>)")
+  end
+end
+
+_DESCENT_EXAMPLE = """# Implicit-style example
+This usage matches Flux ≤ v0.13:
+```
+opt = Flux.Descent(0.3)
+
+ps = Flux.params(model)  # returns a Zygote.Params object
+
+gs = gradient(ps) do     # gradient takes a zero-argument anonymous function
+  loss3(model, x, y)     # ... which depends on the global model
+end                      # ... and returns a Zygote.Grads object
+
+Flux.update!(opt, ps, gs)
+```
+New in Flux v0.14 is a method `train!(loss, ps, opt)` which performs one step,
+rather than iterating over `data`. This is equivalent to `gradient` and `update!` above:
+```
+Flux.train!(ps, opt) do
+  loss3(model, x, y)
+end
+```
+
+# Explicit-style example
+
+This no longer uses `Flux.params`, but instead the model itself:
+```
+opt = Flux.Descent(0.3)  # the same FluxState object
+
+Flux.train!(model, opt) do m  # now explicitly depends on the model
+  loss3(m, x, y)
+end
+```
+"""
+for opt in [
+  :Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
+  :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :AdamW, :RAdam, :OAdam, :AdaBelief,
+  # :InvDecay, :ExpDecay, :WeightDecay, :Optimiser,
+  :ClipGrad, :ClipNorm,
+  # TODO sort out the remaining rules
+]
+  @eval begin
+    $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+    str = string("""    Flux.$($opt)(args...)
+
+    Returns a `FluxState` wrapper around the following rule definition from Optimisers.jl,
+    allowing its use with `Flux.train!` (in the same manner as `Flux.AbstractOptimiser` objects on Flux ≤ v0.13).
+    Accepts the same arguments, with the same defaults, as the underlying rule:
+
+    """, @doc(Optimisers.$opt), $opt == Descent ? _DESCENT_EXAMPLE : "")
+    @doc str $opt
+  end
+end
+
+@deprecate ClipValue ClipGrad
+
+
+### Two styles of gradient, and their `train!` functions
+
+using ProgressLogging: @progress, @withprogress, @logprogress  # TODO add progress logging again
+using Zygote: Zygote, Params
+
+include("explicit_train.jl")  # new!
+include("implicit_train.jl")  # Params etc, Zygote only
+
+explicit_withgradient(f, args...) = Zygote.withgradient(f, args...)  # can overload this to use e.g. Yota / Diffractor
+
+"""
+    Flux.@train_autodiff Zygote
+    Flux.@train_autodiff Yota
+    Flux.@train_autodiff Diffractor
+
+This macro allows the use of `train!` with various automatic differentiation packages,
+instead of the default Zygote.jl.
+You should load the package, then call this macro.
+
+Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
+since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
+
+Only works with [Yota.jl](https://github.com/dfdx/Yota.jl) and [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl),
+and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
+
+!!! note
+    This is experimental!
+"""
+macro train_autodiff(pkg)
+  if pkg == :Diffractor
+    return quote
+      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
+      function Flux.Train.explicit_withgradient(f, args...)
+        y, back = Diffractor.∂⃖¹(f, args...)
+        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
+        return (; value = y, gradient = Base.tail(back(dy1)))
+      end
+    end |> esc
+  elseif pkg == :Yota
+    return quote
+      Yota.grad(sin, 0.0)  # [2][1] ≈ 1.0
+      function Flux.Train.explicit_withgradient(f, args...)
+        value, (_, gradient...) = Yota.grad(f, args...)
+        return (; value, gradient)
+      end
+    end |> esc
+  elseif pkg == :Zygote
+    return quote
+      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
+    end |> esc
+  else
+    throw(ArgumentError("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood."))
+  end
+end
+
+
+### Misc. related utilities
+
+"""
+    Flux.adjust!(opt::FluxState, η::Real)
+
+Alters the learning rate of the optimiser,
+without resetting its stored momentum state, etc.
+"""
+function adjust!(opt::FluxState, eta::Real)
+  opt.rule = Optimisers.adjust(opt.rule, eta)
+  s = opt.state
+  if s isa Missing
+    # not yet initialised, nothing more to adjust
+  elseif s isa IdDict
+    for k in keys(s)
+      s[k] = Optimisers.adjust(s[k], eta)
+    end
+  else
+    s = Optimisers.adjust(s, eta)
+  end
+  opt.state = s
+  return opt
+end
+
+end # module
diff --git a/src/train/explicit_train.jl b/src/train/explicit_train.jl
new file mode 100644
index 0000000000..cb1e2aea5d
--- /dev/null
+++ b/src/train/explicit_train.jl
@@ -0,0 +1,129 @@
+"""
+    train!(loss, model, data, opt::FluxState)
+
+Flux 0.14 no longer uses Zygote's implicit parameter dictionary `Flux.params`.
+
+The major change to `train!` is that instead of `loss` being a function which typically accepts
+two arguments (the input `x` and expected output `y` from each element of `data`),
+it should now typically accept three, the first of which is the `model` itself.
+
+For example, with these definitions...
+```
+data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple (or NamedTuple)
+
+loss(m, x, y) = Flux.crossentropy(m(x), y)  # the model is the first argument
+
+opt = Flux.Adam()  # now returns a FluxState
+```
+...calling `train!(loss, model, data, opt)` runs a loop like this:
+```
+for d in data
+  ∂L∂m = Zygote.gradient(loss, model, d...)[1]
+  # update the model using opt & ∂L∂m
+end
+```
+which evaluates the gradient of `loss(model, x1, y1)` with respect to `model`,
+to know how to update the parameters stored within `model`.
+
+To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
+to use [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
+
+It is often convenient to provide the function `loss` using `do` block syntax,
+instead of defining a named function:
+```
+Flux.train!(model, Iterators.take(Iterators.cycle(data), 10), Flux.Adam()) do m, x, y
+  Flux.crossentropy(m(x), y)  # this does not depend on global variables!
+end
+```
+Here `Iterators.take ∘ Iterators.cycle` uses the same `data` for 10 epochs.
+
+Callback functions are not supported. But see the 3-argument method `train!(loss, model, opt)` for an
+easy way to construct more complicated training loops. For example, this
+adds printing & an early stop to the above:
+```
+for (i, d) in enumerate(data)
+  x, y = d
+  ell = Flux.train!(model, opt) do m
+    Flux.crossentropy(m(x), y)
+  end
+  i % 10 == 0 && println("on step \$i, the loss was \$ell")  # prints every 10th step
+  ell < 0.1 && break  # stops training
+end
+```
+"""
+function train!(loss::Function, model, data, opt::FluxState)
+  _initialise!(opt, model)
+  losses = Float32[]
+  s = opt.state
+  s isa IdDict && error("""Can't mix explicit & implicit modes!
+    Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""")
+  # TODO check whether this loop ought to be in another function, for performance / type-stability.
+  for d in data
+    l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
+    s, model = Optimisers.update!(s, model, g)
+    push!(losses, l)
+    opt.state = s
+  end
+  return losses  # Not entirely sure returning losses is a good idea. Flux 0.13 returns `nothing`.
+end
+
+data_splat(x::T) where T = error("""train! expects every d in data to be a Tuple or a NamedTuple, got $T
+  To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
+data_splat(x::Tuple) = x
+data_splat(x::NamedTuple) = x
+
+function _initialise!(opt::FluxState, model)
+  if opt.state isa Missing
+    opt.state = Optimisers.setup(opt.rule, model)
+    fmap(model, exclude = Optimisers.isnumeric) do x
+      Optimisers.maywrite(x) || error("""model must be fully mutable for train! to work, got x::$(typeof(x))
+        If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
+    end
+  end
+  opt
+end
+
+"""
+    train!(loss, model, opt)
+
+While the 4-argument method of `train!` iterates over a dataset,
+calling `gradient` many times, this 3-argument version is for a single datapoint,
+and calls `gradient` just once.
+
+It expects a function `loss` which takes just one argument, the model.
+For instance:
+```
+opt = Flux.Adam()
+train!(model, opt) do m         # the model is explicitly passed to the function as `m`
+  Flux.crossentropy(m(x1), y1)  # but the data point `(x1, y1)` is closed over.
+end
+```
+This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
+(The `do` block is another syntax for this anonymous function.)
+Then it updates the parameters contained within `model` according
+to the chosen `opt`imiser.
+Finally it returns the value of the loss function.
+
+To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
+to use [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
+"""
+function train!(loss::Function, model, opt::FluxState)
+  _initialise!(opt, model)
+  s = opt.state
+  s isa IdDict && error("""Can't mix explicit & implicit modes!
+    Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""")
+  l, (g, _...) = explicit_withgradient(loss, model)
+  opt.state, model = Optimisers.update!(s, model, g)
+  l
+end
+
+# This method lets you use Optimisers.Descent() instead of Flux.Descent(), when there is no state
+function train!(loss::Function, model, data, rule::Optimisers.AbstractRule)
+  opt = FluxState(rule, missing)
+  _initialise!(opt, model)
+  @gensym warn_id
+  fmap(opt.state, exclude = x -> x isa Optimisers.Leaf) do leaf
+    leaf.state isa Nothing || @warn "Optimiser state will be discarded! Please wrap optimisation rule from Optimisers.jl in `FluxState`, e.g. by using `Flux.Adam()`" leaf maxlog=1 _id=warn_id
+    leaf
+  end
+  train!(loss, model, data, opt)
+end
diff --git a/src/train/implicit_train.jl b/src/train/implicit_train.jl
new file mode 100644
index 0000000000..eb2068eaa0
--- /dev/null
+++ b/src/train/implicit_train.jl
@@ -0,0 +1,87 @@
+"""
+    train!(loss, pars::Params, data, opt::FluxState)
+
+Legacy method, mimicking the behaviour of Flux <= 0.13.
+(Note that the implementation is different, using Optimisers.jl internally.)
+
+For each `d in data`, first the gradient of the `loss` is computed like this:
+```
+    gradient(() -> loss(d...), pars)  # if d isa Tuple
+    gradient(() -> loss(d), pars)     # otherwise
+```
+Here `pars` is produced by calling [`Flux.params`](@ref) on your model.
+This is Zygote's "implicit" style of parameter handling.
+
+Then, this gradient is used by optimizer `opt` to update the parameters:
+```
+    update!(opt, pars, grads)
+```
+The `data` is iterated through once in this manner.
+
+Typically `data` contains tuples, like `data = [(x1, y1), (x2, y2), (x3, y3)]`.
+In this case the function might be `loss(x, y) = mse(model(x), y)`, accepting two arguments.
+Notice that it closes over the `model`, which is a global variable.
+"""
+function train!(loss::Function, pars::Params, data, opt::FluxState)
+  Base.depwarn("""`Flux.train!` accepting implicit `Params` is a legacy method in Flux 0.14.
+    Explicit parameters are now preferred, see `train!(loss, model, data, opt)`""", :train!, force=true)
+  _initialise!(opt, pars)
+  losses = Float32[]
+  for d in data
+    l, grads = Zygote.withgradient(() -> loss(batchmemaybe(d)...), pars)
+    _update!(opt, pars, grads)
+    push!(losses, l)
+  end
+  return losses
+end
+
+batchmemaybe(x) = tuple(x)
+batchmemaybe(x::Tuple) = x
+
+"""
+    train!(loss, pars::Params, opt::FluxState)
+
+This 3-arg method is a bit of a hybrid. With no `data` to iterate over,
+it calls `gradient(() -> loss(), pars)` just once, then updates parameters.
+"""
+function train!(loss::Function, pars::Params, opt::FluxState)
+  Base.depwarn("""`Flux.train!` accepting implicit `Params` is a legacy method in Flux 0.14.
+    Explicit parameters are now preferred, see `train!(loss, model, data, opt)`""", :train!, force=true)
+  _initialise!(opt, pars)
+  l, grads = Zygote.withgradient(() -> loss(), pars)
+  _update!(opt, pars, grads)
+  return l
+end
+
+function _initialise!(opt::FluxState, pars::Params)
+  if opt.state isa Missing  # only initialise once, so that momenta survive repeated calls
+    dict = IdDict()
+    for p in pars
+      dict[p] = Optimisers.setup(opt.rule, p)
+    end
+    opt.state = dict
+  end
+  opt
+end
+
+"""
+    Flux.update!(opt::FluxState, ps::Params, gs)
+
+Legacy method, mimicking the behaviour of Flux <= 0.13.
+"""
+function update!(opt::FluxState, xs::Params, gs)
+  Base.depwarn("Flux.update! is a legacy function", :update!)
+  _initialise!(opt, xs)
+  _update!(opt, xs, gs)
+end
+# This _update! exists only so that train! above gives one depwarn, not two!
+# ... and also to call _initialise!
+function _update!(opt::FluxState, xs::Params, gs)
+  for x in xs
+    isnothing(gs[x]) && continue
+    update!(opt, x, gs[x])
+  end
+end
+
+function update!(opt::FluxState, x::AbstractArray, dx)
+  opt.state[x], xnew = Optimisers.update!(opt.state[x], x, dx)
+  xnew === x || error("failed to mutate x!")
+  nothing
+end
\ No newline at end of file
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 019f3fd603..32e40f4186 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -55,13 +55,13 @@ end
   bias = Conv((2, 2), 1=>3, bias = false);
   ip = zeros(Float32, 28,28,1,1)
   op = zeros(Float32, 27,27,3,1) .+ 2.f0
-  opt = Descent()
+  opt = Flux.Descent()

   for _ = 1:10^3
     gs = gradient(Flux.params(bias)) do
       Flux.Losses.mse(bias(ip), op)
     end
-    Flux.Optimise.update!(opt, params(bias), gs)
+    Flux.Optimise.update!(opt, Flux.params(bias), gs)
   end

   @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0
@@ -168,7 +168,7 @@ end
   x = zeros(Float32, 5, 5, 2, 4)
   m = ConvTranspose((3,3), 2=>3)
-  @test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
+  @test gradient(()->sum(m(x)), Flux.params(m)) isa Flux.Zygote.Grads

   # test ConvTranspose supports groups argument
   x = randn(Float32, 10, 10, 2, 3)
@@ -178,7 +178,7 @@ end
   m2 = ConvTranspose((3,3), 2=>4, groups=2, pad=SamePad())
   @test size(m2.weight) == (3,3,2,2)
   @test size(m1(x)) == size(m2(x))
-  @test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads
+  @test gradient(()->sum(m2(x)), Flux.params(m2)) isa Flux.Zygote.Grads

   x = randn(Float32, 10, 2,1)
   m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2)
diff --git a/test/optimise.jl b/test/optimise.jl
deleted file mode 100644
index e922d3c0b8..0000000000
--- a/test/optimise.jl
+++ /dev/null
@@ -1,239 +0,0 @@
-using Flux.Optimise
-using Flux.Optimise: runall
-using Flux: Params, gradient
-import FillArrays, ComponentArrays
-using Test
-using Random
-
-@testset "Optimise" begin
-  # Ensure rng has different state inside and outside the inner @testset
-  # so that w and w' are different
-  Random.seed!(84)
-  w = randn(10, 10)
-  @testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
-                       NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
-                       Nesterov(), RMSProp(), Momentum()]
-    Random.seed!(42)
-    w′ = randn(10, 10)
-    b = false
-    loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
-    for t = 1: 10^5
-      θ = params([w′, b])
-      x = rand(10)
-      θ̄ = gradient(() -> loss(x), θ)
-      Optimise.update!(opt, θ, θ̄)
-    end
-    @test loss(rand(10, 10)) < 0.01
-  end
-end
-
-@testset "Optimiser" begin
-  Random.seed!(84)
-  w = randn(10, 10)
-  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
-    Random.seed!(42)
-    w′ = randn(10, 10)
-    loss(x) = Flux.Losses.mse(w*x, w′*x)
-    opt = Optimiser(Opt(), Adam(0.001))
-    for t = 1:10^5
-      θ = Params([w′])
-      x = rand(10)
-      θ̄ = gradient(() -> loss(x), θ)
-      Optimise.update!(opt, θ, θ̄)
-    end
-    @test loss(rand(10, 10)) < 0.01
-  end
-end
-
-@testset "Training Loop" begin
-  i = 0
-  l = 1
-  Flux.train!(
-    () -> (sleep(0.1); Flux.skip(); i+=1),
-    Params([]),
-    Iterators.repeated((), 10),
-    Descent()
-  )
-
-  @test i==0 #all skipped
-
-  Flux.train!(
-    () -> (sleep(0.1); i==8 && Flux.skip(); i+=1),
-    Params([]),
-    Iterators.repeated((), 10),
-    Descent()
-  )
-
-  @test i==8 #skip after i hit 8
-
-  i = 0
-  Flux.train!(() -> (sleep(0.1); i += 1; l),
-              Params([]),
-              Iterators.repeated((), 100),
-              Descent(),
-              cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
-
-  @test 3 < i < 50
-
-  # Test multiple callbacks
-  x = 0
-  fs = [() -> (), () -> x = 1]
-  cbs = runall(fs)
-  cbs()
-  @test x == 1
-
-  r = rand(3, 3)
-  loss(x) = sum(x .* x)
-  Flux.train!(loss, Flux.params(r), (r,), Descent())
-end
-
-@testset "ExpDecay" begin
-
-  @testset "Sanity Check" begin
-    o = ExpDecay(0.2, 0.5, 1, 1e-3)
-    p = [0.0]
-    steps = 1:8
-    eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip)
-    eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
-    @test eta_actual == eta_expected
-  end
-
-  @testset "starting step" begin
-    start = 4
-    o = ExpDecay(0.2, 0.5, 1, 1e-3, start)
-    p = [0.0]
-    steps = 1:8
-    eta_expected = @. max(o.eta * 0.5 ^ max(steps - start, 0), o.clip)
-    eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
-    @test eta_actual == eta_expected
-  end
-
-  w = randn(10, 10)
-  o = ExpDecay(0.1, 0.1, 1000, 1e-4)
-  w1 = randn(10,10)
-  loss(x) = Flux.Losses.mse(w*x, w1*x)
-  flag = 1
-  decay_steps = []
-  for t = 1:10^5
-    prev_eta = o.eta
-    θ = Params([w1])
-    x = rand(10)
-    θ̄ = gradient(() -> loss(x), θ)
-    prev_grad = collect(θ̄[w1])
-    delta = Optimise.apply!(o, w1, θ̄[w1])
-    w1 .-= delta
-    new_eta = o.eta
-    if new_eta != prev_eta
-      push!(decay_steps, t)
-    end
-    array = fill(o.eta, size(prev_grad))
-    if array .* prev_grad != delta
-      flag = 0
-    end
-  end
-  @test flag == 1
-  # Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1).
-  ground_truth = []
-  for i in 1:4
-    push!(ground_truth, 1000*i)  # Expected decay steps for this example.
-  end
-  @test decay_steps == ground_truth
-  @test o.eta == o.clip
-end
-
-@testset "Clipping" begin
-  w = randn(10, 10)
-  loss(x) = sum(w * x)
-  θ = Params([w])
-  x = 1000 * randn(10)
-  w̄ = gradient(() -> loss(x), θ)[w]
-  w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
-  @test all(w̄_value .<= 1)
-  w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
-  @test norm(w̄_norm) <= 1
-end
-
-@testset "update!: handle Fills from Zygote" begin
-  w = randn(10,10)
-  wold = copy(w)
-  g = FillArrays.Ones(size(w))
-  opt = Descent(0.1)
-  Flux.update!(opt, w, g)
-  @test w ≈ wold .- 0.1
-
-  w = randn(3)
-  wold = copy(w)
-  θ = Flux.params([w])
-  gs = gradient(() -> w[1], θ)
-  opt = Descent(0.1)
-  Flux.update!(opt, θ, gs)
-  @test w[1] ≈ wold[1] .- 0.1
-  @test w[2:3] ≈ wold[2:3]
-
-  ## Issue #1510
-  w = randn(10,10)
-  wold = copy(w)
-  θ = Flux.params([w])
-  gs = gradient(() -> sum(w), θ)
-  opt = Descent(0.1)
-  Flux.update!(opt, θ, gs)
-  @test w ≈ wold .- 0.1
-end
-
-@testset "update!: handle ComponentArrays" begin
-  w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
-  wold = deepcopy(w)
-  θ = Flux.params([w])
-  gs = gradient(() -> sum(w.a) + sum(w.c.b), θ)
-  opt = Descent(0.1)
-  Flux.update!(opt, θ, gs)
-  @test w.a ≈ wold.a .- 0.1
-  @test w.b ≈ wold.b
-  @test w.c.b ≈ wold.c.b .- 0.1
-  @test w.c.a ≈ wold.c.a
-
-  w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
-  wold = deepcopy(w)
-  θ = Flux.params([w])
-  gs = gradient(() -> sum(w), θ)
-  opt = Descent(0.1)
-  Flux.update!(opt, θ, gs)
-  @test w ≈ wold .- 0.1
-end
-
-# Flux PR #1776
-# We need to test that optimisers like Adam that maintain an internal momentum
-# estimate properly calculate the second-order statistics on the gradients as
-# the flow backward through the model. Previously, we would calculate second-
-# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
-# wreaks all sorts of havoc on our training loops. This test ensures that
-# a simple optimization is montonically decreasing (up to learning step effects)
-@testset "Momentum Optimisers and complex values" begin
-  # Test every optimizer that has momentum internally
-  for opt_ctor in [Adam, RMSProp, RAdam, OAdam, AdaGrad, AdaDelta, NAdam, AdaBelief]
-    # Our "model" is just a complex number
-    w = zeros(ComplexF32, 1)
-
-    # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
-    function loss()
-      # Deterministic training data is the best training data
-      x = ones(1, 1) + 1im*ones(1, 1)
-
-      # Manually implement `mse()` to allow demonstration of brokenness
-      # on older Flux builds that don't have a fixed `mse()`
-      return sum(abs2.(w * x .- conj(x)))
-    end
-
-    params = Flux.Params([w])
-    opt = opt_ctor(1e-2)
-
-    # Train for 10 iterations, enforcing that loss is monotonically decreasing
-    last_loss = Inf
-    for idx in 1:10
-      grads = Flux.gradient(loss, params)
-      @test loss() < last_loss
-      last_loss = loss()
-      Flux.update!(opt, params, grads)
-    end
-  end
-end
diff --git a/test/runtests.jl b/test/runtests.jl
index 706f126451..d9a5011879 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -20,8 +20,8 @@ Random.seed!(0)
     include("onehot.jl")
   end

-  @testset "Optimise" begin
-    include("optimise.jl")
+  @testset "Train" begin
+    include("train.jl")
   end

   @testset "Data" begin
diff --git a/test/train.jl b/test/train.jl
new file mode 100644
index 0000000000..8e15cd3e15
--- /dev/null
+++ b/test/train.jl
@@ -0,0 +1,165 @@
+using Flux.Train
+using Zygote: Params, gradient
+
+import Optimisers, FillArrays, ComponentArrays, Yota
+
+using Test
+using Random
+
+@testset "Implicit train!" begin  # These tests pass on Flux v0.13
+  Random.seed!(84)
+  w = randn(10, 10)
+  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
+  @testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
+                       NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
+                       Nesterov(), RMSProp(), Momentum()]
+    w′ = copy(w2)
+    b = zeros(10)
+    loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
+    @test loss(rand(10, 10)) > 1
+    Flux.train!(loss, Flux.params([w′, b]), (rand(10) for _ in 1:10^5), opt)
+    @test loss(rand(10, 10)) < 0.01
+  end
+end
+
+@testset "Explicit train! with Zygote" begin
+  Random.seed!(84)
+  w = randn(10, 10)
+  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
+  @testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
+                       NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
+                       Nesterov(), RMSProp(), Momentum()]
+    @test opt isa FluxState
+    @test opt.state isa Missing
+
+    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model, rand(10, 10)) > 1
+
+    train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
+    @test loss(model, rand(10, 10)) < 0.01
+    @test opt.state isa NamedTuple
+  end
+
+  # Test 3-arg `train!` method:
+  @testset for opt in [Descent(0.1), Adam(), AdamW()]
+    @test opt isa FluxState
+    @test opt.state isa Missing
+
+    loss(m) = let x = rand(10)
+      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    end
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model) > 1
+
+    for i in 1:10^5
+      train!(loss, model, opt)
+    end
+    @test loss(model) < 0.01
+    @test opt.state isa NamedTuple
+  end
+
+  # Test direct use of Optimisers.jl rule, only really OK for `Descent`:
+  @testset for opt in [Optimisers.Descent(0.1), Optimisers.Adam()]
+    @test opt isa Optimisers.AbstractRule
+    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model, rand(10, 10)) > 1
+    train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
+    @test loss(model, rand(10, 10)) < 0.01
+  end
+end
+
+using Yota
+using Flux: Descent, Adam, AdamW, FluxState
+Flux.@train_autodiff Yota
+
+@testset "Explicit train! with Yota" begin
+  Random.seed!(84)
+  w = randn(10, 10)
+  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
+  @testset for opt in [Descent(0.1), Adam(), AdamW()]
+    @test opt isa FluxState
+    @test opt.state isa Missing
+
+    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model, rand(10, 10)) > 1
+
+    train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
+    @test loss(model, rand(10, 10)) < 0.01
+    @test opt.state isa NamedTuple
+  end
+
+  # Test 3-arg `train!` method:
+  @testset for opt in [Descent(0.1), Adam(), AdamW()]
+    @test opt isa FluxState
+    @test opt.state isa Missing
+
+    loss(m) = let x = rand(10)
+      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    end
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model) > 1
+
+    for i in 1:10^5
+      train!(loss, model, opt)
+    end
+    @test loss(model) < 0.01
+    @test opt.state isa NamedTuple
+  end
+end
+
+Flux.@train_autodiff Zygote
+
+#=
+
+@testset "update!: handle Fills from Zygote" begin
+  w = randn(10,10)
+  wold = copy(w)
+  g = FillArrays.Ones(size(w))
+  opt = Descent(0.1)
+  Flux.update!(opt, w, g)
+  @test w ≈ wold .- 0.1
+
+  w = randn(3)
+  wold = copy(w)
+  θ = Flux.params([w])
+  gs = gradient(() -> w[1], θ)
+  opt = Descent(0.1)
+  Flux.update!(opt, θ, gs)
+  @test w[1] ≈ wold[1] .- 0.1
+  @test w[2:3] ≈ wold[2:3]
+
+  ## Issue #1510
+  w = randn(10,10)
+  wold = copy(w)
+  θ = Flux.params([w])
+  gs = gradient(() -> sum(w), θ)
+  opt = Descent(0.1)
+  Flux.update!(opt, θ, gs)
+  @test w ≈ wold .- 0.1
+end
+
+@testset "update!: handle ComponentArrays" begin
+  w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
+  wold = deepcopy(w)
+  θ = Flux.params([w])
+  gs = gradient(() -> sum(w.a) + sum(w.c.b), θ)
+  opt = Descent(0.1)
+  Flux.update!(opt, θ, gs)
+  @test w.a ≈ wold.a .- 0.1
+  @test w.b ≈ wold.b
+  @test w.c.b ≈ wold.c.b .- 0.1
+  @test w.c.a ≈ wold.c.a
+
+  w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
+  wold = deepcopy(w)
+  θ = Flux.params([w])
+  gs = gradient(() -> sum(w), θ)
+  opt = Descent(0.1)
+  Flux.update!(opt, θ, gs)
+  @test w ≈ wold .- 0.1
+end
+
+=#
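Finally, as an end-to-end illustration of the `FluxState` lifecycle that the new tests exercise: a hedged sketch, in which the tuple-of-arrays "model" and the loss are placeholders, and the `train!` / `adjust!` behaviour assumed is exactly what `src/train/Train.jl` above defines.

```julia
using Flux

model = (weight = randn(10, 10), bias = zeros(10))
loss(m, x) = sum(abs2, m.weight * x .+ m.bias)

opt = Flux.Adam()               # a FluxState; opt.state === missing until first use
data = ((rand(10),) for _ in 1:10^3)

losses = Flux.train!(loss, model, data, opt)  # sets up the state tree on the first step

opt.state isa NamedTuple        # true: explicit-mode state tree, not an implicit IdDict
Flux.adjust!(opt, 1e-4)         # lower the learning rate, keeping momenta intact
```

Note the design choice this illustrates: because `FluxState` starts out with `state = missing` and is initialised lazily by `train!`, the same wrapper object serves both the explicit and the (deprecated) implicit mode, but is locked into whichever mode touches it first.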