diff --git a/src/Flux.jl b/src/Flux.jl index 0cacbd419a..2f8b13818f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -44,6 +44,7 @@ include("functor.jl") # Pirate error to catch a common mistake. Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.") +include("layers/types.jl") include("layers/stateless.jl") include("layers/basic.jl") include("layers/conv.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index dde49bd8fb..deb465d262 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -1,3 +1,4 @@ + """ Chain(layers...) Chain(name = layer, ...) @@ -32,7 +33,7 @@ For large models, there is a special type-unstable path which can reduce compila times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`. This feature is somewhat experimental, beware! """ -struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}} +struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: ContainerLayer layers::T end @@ -46,8 +47,6 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys, Base.firstindex -@functor Chain - (c::Chain)(x) = _applychain(c.layers, x) @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N} @@ -150,7 +149,7 @@ julia> Flux.params(d1) # no trainable bias Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]]) ``` """ -struct Dense{F, M<:AbstractMatrix, B} +struct Dense{F, M<:AbstractMatrix, B} <: SimpleLayer weight::M bias::B σ::F @@ -165,8 +164,6 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity; Dense(init(out, in), bias, σ) end -@functor Dense - function (a::Dense)(x::AbstractVecOrMat) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc return σ.(a.weight * x .+ a.bias) @@ -223,7 +220,7 @@ julia> Flux.params(b) Params([[1 2 3 4]]) ``` """ -struct Scale{F, A<:AbstractArray, B} +struct Scale{F, A<:AbstractArray, B} <: SimpleLayer scale::A bias::B σ::F @@ -236,8 +233,6 @@ end Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act) Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end]) -@functor Scale - function (a::Scale)(x::AbstractArray) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc σ.(a.scale .* x .+ a.bias) @@ -285,14 +280,12 @@ julia> Flux.outputsize(m3, (5, 11)) (7, 11) ``` """ -struct Maxout{T<:Tuple} +struct Maxout{T<:Tuple} <: ContainerLayer layers::T end Maxout(layers...) = Maxout(layers) Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...) -@functor Maxout - function (mo::Maxout)(input::AbstractArray) # Perhaps surprisingly, pairwise max broadcast is often faster, # even with Zygote. See #698 and #1794 @@ -333,13 +326,11 @@ true See also [`Parallel`](@ref), [`Maxout`](@ref). 
""" -struct SkipConnection{T,F} +struct SkipConnection{T,F} <: ContainerLayer layers::T connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b end -@functor SkipConnection - function (skip::SkipConnection)(input) skip.connection(skip.layers(input), input) end @@ -397,7 +388,7 @@ julia> Flux.Bilinear(rand(4,8,16), false, tanh) # first dim of weight is the ou Bilinear((8, 16) => 4, tanh; bias=false) # 512 parameters ``` """ -struct Bilinear{F,A,B} +struct Bilinear{F,A,B} <: SimpleLayer weight::A bias::B σ::F @@ -408,8 +399,6 @@ struct Bilinear{F,A,B} end end -@functor Bilinear - function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity; bias = true, init = glorot_uniform) Bilinear(init(out, in1, in2), bias, σ) @@ -492,7 +481,7 @@ julia> model2[:β] == model2[2] true ``` """ -struct Parallel{F, T<:Union{Tuple, NamedTuple}} +struct Parallel{F, T<:Union{Tuple, NamedTuple}} <: ContainerLayer connection::F layers::T end @@ -507,8 +496,6 @@ function Parallel(connection; kw...) Parallel(connection, layers) end -@functor Parallel - (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) (m::Parallel)(xs::Tuple) = m(xs...) @@ -582,7 +569,7 @@ end A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above). """ -struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}} +struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}} <: ContainerLayer connection::F layers::T end @@ -628,8 +615,6 @@ end end applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x) -@functor PairwiseFusion - Base.getindex(m::PairwiseFusion, i) = m.layers[i] Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i]) Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) = @@ -672,12 +657,10 @@ julia> model(vocab_idxs) == model(x) true ``` """ -struct Embedding{W} +struct Embedding{W} <: SimpleLayer weight::W end -@functor Embedding - Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) (m::Embedding)(x::Integer) = m.weight[:, x] diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 003395c15d..07e1327a83 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -117,7 +117,7 @@ julia> Conv((5,5), 3 => 7; stride = 2, dilation = 4)(xs) |> size (42, 42, 7, 50) ``` """ -struct Conv{N,M,F,A,V} +struct Conv{N,M,F,A,V} <: SimpleLayer σ::F weight::A bias::V @@ -187,8 +187,6 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init(filter..., cin÷groups, cout) end -@functor Conv - conv_dims(c::Conv, x::AbstractArray) = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) @@ -252,7 +250,7 @@ julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size (300, 300, 7, 50) ``` """ -struct ConvTranspose{N,M,F,A,V} +struct ConvTranspose{N,M,F,A,V} <: SimpleLayer σ::F weight::A bias::V @@ -307,8 +305,6 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = ConvTranspose(weight, bias, σ; stride, pad, dilation, groups) end -@functor ConvTranspose - function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) # Calculate size of "input", from ∇conv_data()'s perspective... 
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end]) @@ -407,7 +403,7 @@ julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size (34, 32, 7, 50) ``` """ -struct CrossCor{N,M,F,A,V} +struct CrossCor{N,M,F,A,V} <: SimpleLayer σ::F weight::A bias::V @@ -453,8 +449,6 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden return CrossCor(weight, bias, σ; stride, pad, dilation) end -@functor CrossCor - function crosscor(x, w, ddims::DenseConvDims) ddims = DenseConvDims(ddims, F=true) return conv(x, w, ddims) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index f1f6c22033..7a9298a375 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -90,7 +90,7 @@ julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1) true ``` """ -mutable struct Dropout{F,D,R<:AbstractRNG} +mutable struct Dropout{F,D,R<:AbstractRNG} <: NoTrainLayer p::F dims::D active::Union{Bool, Nothing} @@ -103,9 +103,6 @@ function Dropout(p; dims=:, rng = rng_from_array()) Dropout(p, dims, nothing, rng) end -@functor Dropout -trainable(a::Dropout) = (;) - function (a::Dropout)(x) _isactive(a) || return x return dropout(a.rng, x, a.p; dims=a.dims, active=true) @@ -146,7 +143,7 @@ julia> isapprox(std(x), std(y), atol=0.2) true ``` """ -mutable struct AlphaDropout{F,R<:AbstractRNG} +mutable struct AlphaDropout{F,R<:AbstractRNG} <: NoTrainLayer p::F active::Union{Bool, Nothing} rng::R @@ -158,9 +155,6 @@ end AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array()) AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng) -@functor AlphaDropout -trainable(a::AlphaDropout) = (;) - function (a::AlphaDropout)(x::AbstractArray{T}) where T _isactive(a) || return x p = a.p @@ -209,7 +203,7 @@ julia> isapprox(std(y, dims=1:3), ones(1, 1, 1, 2), atol=0.1) && std(y, dims=1:3 true ``` """ -struct LayerNorm{F,D,T,N} +struct LayerNorm{F,D,T,N} <: PartialTrainLayer{(:diag,)} λ::F diag::D ϵ::T @@ -224,8 +218,6 @@ end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) -@functor LayerNorm - (a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ)) function Base.show(io::IO, l::LayerNorm) @@ -322,7 +314,7 @@ julia> isapprox(std(m(xs)), 1, atol=0.1) && std(xs) != std(m(xs)) true ``` """ -mutable struct BatchNorm{F,V,N,W} +mutable struct BatchNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)} λ::F # activation function β::V # bias γ::V # scale @@ -352,9 +344,6 @@ function BatchNorm(chs::Int, λ=identity; nothing, chs) end -@functor BatchNorm -trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) - function (BN::BatchNorm)(x) @assert size(x, ndims(x)-1) == BN.chs N = ndims(x) @@ -412,7 +401,7 @@ julia> isapprox(std(y, dims=1:2), ones(1, 1, 3, 2), atol=0.2) && std(y, dims=1:2 true ``` """ -mutable struct InstanceNorm{F,V,N,W} +mutable struct InstanceNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)} λ::F # activation function β::V # bias γ::V # scale @@ -442,9 +431,6 @@ function InstanceNorm(chs::Int, λ=identity; nothing, chs) end -@functor InstanceNorm -trainable(in::InstanceNorm) = hasaffine(in) ? 
(β = in.β, γ = in.γ) : (;) - function (l::InstanceNorm)(x) @assert ndims(x) > 2 @assert size(x, ndims(x)-1) == l.chs @@ -506,7 +492,7 @@ julia> isapprox(std(y[:, :, 3:4, 2]), 1, atol=0.1) && std(xs[:, :, 3:4, 2]) != s true ``` """ -mutable struct GroupNorm{F,V,N,W} +mutable struct GroupNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)} G::Int # number of groups λ::F # activation function β::V # bias @@ -521,9 +507,6 @@ mutable struct GroupNorm{F,V,N,W} chs::Int # number of channels end -@functor GroupNorm -trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;) - function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, affine=true, track_stats=false, diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 760933bb96..d0150a862a 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -125,7 +125,7 @@ julia> rnn.state 60 ``` """ -mutable struct Recur{T,S} +mutable struct Recur{T,S} <: ContainerLayer cell::T state::S end @@ -135,8 +135,7 @@ function (m::Recur)(x) return y end -@functor Recur -trainable(a::Recur) = (; cell = a.cell) +trainable(a::Recur) = (; cell = a.cell) # can't use <: PartialTrainLayer Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -189,7 +188,7 @@ end # Vanilla RNN -struct RNNCell{F,A,V,S} +struct RNNCell{F,A,V,S} <: SimpleLayer # or should it be PartialTrainLayer{(:Wi, :Wh, :b)}? σ::F Wi::A Wh::A @@ -207,8 +206,6 @@ function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T} return h, reshape_cell_output(h, x) end -@functor RNNCell - function Base.show(io::IO, l::RNNCell) print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)) l.σ == identity || print(io, ", ", l.σ) @@ -277,7 +274,7 @@ Recur(m::RNNCell) = Recur(m, m.state0) # LSTM -struct LSTMCell{A,V,S} +struct LSTMCell{A,V,S} <: SimpleLayer # or should it be PartialTrainLayer{(:Wi, :Wh, :b)}? Wi::A Wh::A b::V @@ -302,8 +299,6 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr return (h′, c′), reshape_cell_output(h′, x) end -@functor LSTMCell - Base.show(io::IO, l::LSTMCell) = print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")") @@ -351,7 +346,7 @@ function _gru_output(gxs, ghs, bs) return r, z end -struct GRUCell{A,V,S} +struct GRUCell{A,V,S} <: SimpleLayer # or should it be PartialTrainLayer{(:Wi, :Wh, :b)}? Wi::A Wh::A b::V @@ -370,8 +365,6 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O return h′, reshape_cell_output(h′, x) end -@functor GRUCell - Base.show(io::IO, l::GRUCell) = print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") @@ -435,8 +428,6 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T} return h′, reshape_cell_output(h′, x) end -@functor GRUv3Cell - Base.show(io::IO, l::GRUv3Cell) = print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") diff --git a/src/layers/show.jl b/src/layers/show.jl index 5fb9991504..af6e0ef4f9 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,22 +1,20 @@ -for T in [ - :Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL - _big_show(io, x) - elseif !get(io, :compact, false) # e.g. 
printed inside a Vector, but not a Matrix
-      _layer_show(io, x)
-    else
-      show(io, x)
-    end
+@nospecialize
+
+function Base.show(io::IO, m::MIME"text/plain", x::ContainerLayer)
+  if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
+    _big_show(io, x)
+  elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
+    _layer_show(io, x)
+  else
+    show(io, x)
  end
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
  pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
  children = _show_children(obj)
-  if all(_show_leaflike, children)
+  if all(_show_leaflike, children)  # or else if obj isa SimpleLayer, via dispatch below
    _layer_show(io, obj, indent, name)
  else
    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
@@ -43,6 +41,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
    end
  end
end
+_big_show(io::IO, obj::SimpleLayer, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)

_show_leaflike(x) = isleaf(x)  # mostly follow Functors, except for:
_show_leaflike(::Tuple{Vararg{<:Number}}) = true  # e.g. stride of Conv
@@ -55,16 +54,11 @@ _show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

-for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
-  :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
-  ]
-  @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
-    if !get(io, :compact, false)
-      _layer_show(io, x)
-    else
-      show(io, x)
-    end
+function Base.show(io::IO, m::MIME"text/plain", x::SimpleLayer)
+  if !get(io, :compact, false)
+    _layer_show(io, x)
+  else
+    show(io, x)
  end
end

@@ -128,3 +122,5 @@ _any(f, x::Number) = f(x)
# _any(f, x) = false

_all(f, xs) = !_any(!f, xs)
+
+@specialize
diff --git a/src/layers/types.jl b/src/layers/types.jl
new file mode 100644
index 0000000000..97eb549be0
--- /dev/null
+++ b/src/layers/types.jl
@@ -0,0 +1,91 @@
+import Adapt, Functors, Optimisers
+
+"""
+    Flux.AbstractLayer
+
+Supertype for all of Flux's built-in layers.
+
+These layer types are not essential: you can use your own `struct` with Flux.
+But they do simplify some common interactions:
+
+* Any `l::AbstractLayer` has a method for `Functors.functor`,
+  thus you need not invoke `@functor`. Note that your `struct` should
+  have the default constructor (or one which similarly accepts all of
+  its fields as arguments; it is simplest not to write an inner constructor).
+
+* Calling `Adapt.adapt` on any `l::AbstractLayer` will recurse using `Functors.fmap`,
+  ensuring that the identification between tied weights is preserved.
+
+* Subtyping `PartialTrainLayer` marks only some fields as trainable,
+  by overloading `Optimisers.trainable`.
+
+* Some subtypes tell fancy `show` whether to unfold their contents:
+  `l::ContainerLayer` behaves like `Chain`, while `l::SimpleLayer` behaves like `Dense`.
+"""
+abstract type AbstractLayer end
+
+function Functors.functor(::Type{T}, x) where {T<:AbstractLayer}
+  if @generated
+    F = fieldnames(T)
+    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
+    C = Base.typename(T).name  # constructor
+    recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
+    :((NamedTuple{$F}(($(args...),)), $recon))
+  else
+    # Getting this parameterless type takes about 2μs, every time:
+    namedtuple(x), Base.splat(Base.typename(T).wrapper)
+  end
+end
+
+function namedtuple(x::T) where T
+  F = fieldnames(T)
+  NamedTuple{F}(map(sy -> getfield(x, sy), F))
+end
+
+Adapt.adapt_structure(to, layer::AbstractLayer) = fmap(x -> adapt(to, x), layer)
+
+"""
+    Flux.ContainerLayer <: AbstractLayer
+
+Supertype for layers such as `Chain` & `Parallel`. Not essential to Flux's functioning,
+but tells `show` to unfold the contents when this is the outermost struct.
+Subtyping it (like any `AbstractLayer`) also removes the need for `@functor`.
+"""
+abstract type ContainerLayer <: AbstractLayer end
+"""
+    Flux.SimpleLayer <: AbstractLayer
+
+Supertype for layers such as `Dense` & `Conv`. Not essential to Flux's functioning,
+but tells `show` to print this layer compactly, on one line.
+Subtyping it (like any `AbstractLayer`) also removes the need for `@functor`.
+"""
+abstract type SimpleLayer <: AbstractLayer end
+
+"""
+    Flux.PartialTrainLayer{which} <: SimpleLayer <: AbstractLayer
+
+Supertype for layers such as `BatchNorm`, which contain arrays of numbers
+that are not to be optimised during training.
+
+`which` is a tuple of `Symbol`s, indicating the fields of the struct
+that *do* contain trainable parameters. This is used by a method of
+`Optimisers.trainable`, so you need not write that method yourself.
+
+Note that some fields (such as functions, integers, `nothing`) are never
+trainable, and do not need special attention. `Optimisers.trainable` is needed
+only to shield types which would otherwise be trainable, such as arrays of floats.
+
+Subtyping this also (like any `AbstractLayer`) removes the need for `@functor`,
+and (like `SimpleLayer`) tells `show` not to unfold further.
+"""
+abstract type PartialTrainLayer{which} <: SimpleLayer end
+
+function Optimisers.trainable(layer::PartialTrainLayer{which}) where {which}
+  NamedTuple{which}(map(sy -> getfield(layer, sy), which))
+end
+
+"""
+    Flux.NoTrainLayer <: SimpleLayer <: AbstractLayer
+
+Supertype for layers which contain no trainable parameters.
+"""
+const NoTrainLayer = PartialTrainLayer{()}
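
Below is a minimal sketch of how a downstream layer could opt in to these supertypes, for reviewers who want to see the intended usage. It is not part of the diff; `MyAffine` and `MyShift` are hypothetical names, and the behaviour assumed here relies only on the `functor`, `trainable` and `show` methods added in `src/layers/types.jl` above.

```julia
using Flux, Optimisers

# Hypothetical layer whose fields are all trainable: subtype SimpleLayer and
# keep the default constructor, so the AbstractLayer `functor` method applies.
struct MyAffine{M<:AbstractMatrix, V<:AbstractVector} <: Flux.SimpleLayer
  W::M
  b::V
end
MyAffine(in::Integer, out::Integer) = MyAffine(randn(Float32, out, in), zeros(Float32, out))
(a::MyAffine)(x) = a.W * x .+ a.b

# No `@functor MyAffine` call is needed: `Flux.params`, `gpu`, `fmap`, etc.
# should recurse into `W` and `b` via the `Functors.functor` method defined
# for `AbstractLayer`, and `show` should print the layer on one line.

# Hypothetical layer with a non-trainable buffer: name the trainable fields
# in the `PartialTrainLayer` parameter instead of overloading `trainable`.
struct MyShift{V<:AbstractVector} <: Flux.PartialTrainLayer{(:shift,)}
  shift::V
  count::Vector{Int}   # updated manually, hidden from the optimiser
end
(m::MyShift)(x) = x .+ m.shift

# Optimisers.trainable(m::MyShift) should return (; shift = m.shift),
# supplied by the fallback method for PartialTrainLayer{(:shift,)}.
```

This mirrors what the diff does to `BatchNorm` and friends, where the hand-written `Optimisers.trainable` methods are deleted in favour of subtyping `PartialTrainLayer{(:β, :γ)}`.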