RFC: add a supertype to layers #2028

Closed · wants to merge 4 commits
src/Flux.jl (1 change: 1 addition & 0 deletions)
@@ -44,6 +44,7 @@ include("functor.jl")
# Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")

+include("layers/types.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
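The new `src/layers/types.jl` is included above, but its contents are not part of this diff. Purely as a hypothetical sketch, inferred from the supertype names used in the files below (`ContainerLayer`, `SimpleLayer`, `NoTrainLayer`, `PartialTrainLayer`), the file might declare a hierarchy along these lines; the PR's actual definitions may differ:

```julia
# Hypothetical sketch of layers/types.jl, inferred from the supertypes used in
# the rest of this diff; not copied from the PR.
abstract type AbstractLayer end

# Layers that wrap other layers (Chain, Parallel, SkipConnection, ...).
abstract type ContainerLayer <: AbstractLayer end

# Plain layers whose array fields are all trainable (Dense, Conv, ...).
abstract type SimpleLayer <: AbstractLayer end

# Layers with no trainable parameters at all (Dropout, AlphaDropout, ...).
abstract type NoTrainLayer <: AbstractLayer end

# Layers where only the fields named in the tuple type parameter are trainable,
# e.g. BatchNorm <: PartialTrainLayer{(:β, :γ)} below.
abstract type PartialTrainLayer{which} <: AbstractLayer end
```

Whatever the exact definitions, the changes below amount to replacing each per-layer `@functor` call (and, in normalise.jl, each hand-written `trainable` method) with one of these supertypes on the struct definition.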
src/layers/basic.jl (37 changes: 10 additions & 27 deletions)
@@ -1,3 +1,4 @@

"""
Chain(layers...)
Chain(name = layer, ...)
@@ -32,7 +33,7 @@ For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
"""
-struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
+struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: ContainerLayer
layers::T
end

@@ -46,8 +47,6 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

-@functor Chain

(c::Chain)(x) = _applychain(c.layers, x)

@generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
@@ -150,7 +149,7 @@ julia> Flux.params(d1) # no trainable bias
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
```
"""
-struct Dense{F, M<:AbstractMatrix, B}
+struct Dense{F, M<:AbstractMatrix, B} <: SimpleLayer
weight::M
bias::B
σ::F
@@ -165,8 +164,6 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
Dense(init(out, in), bias, σ)
end

-@functor Dense

function (a::Dense)(x::AbstractVecOrMat)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
return σ.(a.weight * x .+ a.bias)
@@ -223,7 +220,7 @@ julia> Flux.params(b)
Params([[1 2 3 4]])
```
"""
-struct Scale{F, A<:AbstractArray, B}
+struct Scale{F, A<:AbstractArray, B} <: SimpleLayer
scale::A
bias::B
σ::F
@@ -236,8 +233,6 @@ end
Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

-@functor Scale

function (a::Scale)(x::AbstractArray)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
σ.(a.scale .* x .+ a.bias)
@@ -285,14 +280,12 @@ julia> Flux.outputsize(m3, (5, 11))
(7, 11)
```
"""
-struct Maxout{T<:Tuple}
+struct Maxout{T<:Tuple} <: ContainerLayer
layers::T
end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

-@functor Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
# even with Zygote. See #698 and #1794
@@ -333,13 +326,11 @@ true

See also [`Parallel`](@ref), [`Maxout`](@ref).
"""
-struct SkipConnection{T,F}
+struct SkipConnection{T,F} <: ContainerLayer
layers::T
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

-@functor SkipConnection

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
end
@@ -397,7 +388,7 @@ julia> Flux.Bilinear(rand(4,8,16), false, tanh) # first dim of weight is the ou
Bilinear((8, 16) => 4, tanh; bias=false) # 512 parameters
```
"""
-struct Bilinear{F,A,B}
+struct Bilinear{F,A,B} <: SimpleLayer
weight::A
bias::B
σ::F
@@ -408,8 +399,6 @@ struct Bilinear{F,A,B}
end
end

-@functor Bilinear

function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
bias = true, init = glorot_uniform)
Bilinear(init(out, in1, in2), bias, σ)
@@ -492,7 +481,7 @@ julia> model2[:β] == model2[2]
true
```
"""
-struct Parallel{F, T<:Union{Tuple, NamedTuple}}
+struct Parallel{F, T<:Union{Tuple, NamedTuple}} <: ContainerLayer
connection::F
layers::T
end
@@ -507,8 +496,6 @@ function Parallel(connection; kw...)
Parallel(connection, layers)
end

-@functor Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)

@@ -582,7 +569,7 @@ end

A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
"""
-struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
+struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}} <: ContainerLayer
connection::F
layers::T
end
@@ -628,8 +615,6 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

-@functor PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
@@ -672,12 +657,10 @@ julia> model(vocab_idxs) == model(x)
true
```
"""
-struct Embedding{W}
+struct Embedding{W} <: SimpleLayer
weight::W
end

-@functor Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
src/layers/conv.jl (12 changes: 3 additions & 9 deletions)
@@ -117,7 +117,7 @@ julia> Conv((5,5), 3 => 7; stride = 2, dilation = 4)(xs) |> size
(42, 42, 7, 50)
```
"""
-struct Conv{N,M,F,A,V}
+struct Conv{N,M,F,A,V} <: SimpleLayer
σ::F
weight::A
bias::V
@@ -187,8 +187,6 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init(filter..., cin÷groups, cout)
end

-@functor Conv

conv_dims(c::Conv, x::AbstractArray) =
DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)

@@ -252,7 +250,7 @@ julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size
(300, 300, 7, 50)
```
"""
-struct ConvTranspose{N,M,F,A,V}
+struct ConvTranspose{N,M,F,A,V} <: SimpleLayer
σ::F
weight::A
bias::V
@@ -307,8 +305,6 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
end

-@functor ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
@@ -407,7 +403,7 @@ julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size
(34, 32, 7, 50)
```
"""
-struct CrossCor{N,M,F,A,V}
+struct CrossCor{N,M,F,A,V} <: SimpleLayer
σ::F
weight::A
bias::V
@@ -453,8 +449,6 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden
return CrossCor(weight, bias, σ; stride, pad, dilation)
end

-@functor CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
return conv(x, w, ddims)
src/layers/normalise.jl (29 changes: 6 additions & 23 deletions)
@@ -90,7 +90,7 @@ julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
true
```
"""
-mutable struct Dropout{F,D,R<:AbstractRNG}
+mutable struct Dropout{F,D,R<:AbstractRNG} <: NoTrainLayer
p::F
dims::D
active::Union{Bool, Nothing}
@@ -103,9 +103,6 @@ function Dropout(p; dims=:, rng = rng_from_array())
Dropout(p, dims, nothing, rng)
end

-@functor Dropout
-trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
return dropout(a.rng, x, a.p; dims=a.dims, active=true)
@@ -146,7 +143,7 @@ julia> isapprox(std(x), std(y), atol=0.2)
true
```
"""
-mutable struct AlphaDropout{F,R<:AbstractRNG}
+mutable struct AlphaDropout{F,R<:AbstractRNG} <: NoTrainLayer
p::F
active::Union{Bool, Nothing}
rng::R
@@ -158,9 +155,6 @@ end
AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)

-@functor AlphaDropout
-trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
p = a.p
@@ -209,7 +203,7 @@ julia> isapprox(std(y, dims=1:3), ones(1, 1, 1, 2), atol=0.1) && std(y, dims=1:3
true
```
"""
-struct LayerNorm{F,D,T,N}
+struct LayerNorm{F,D,T,N} <: PartialTrainLayer{(:diag,)}
λ::F
diag::D
ϵ::T
@@ -224,8 +218,6 @@ end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

-@functor LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

function Base.show(io::IO, l::LayerNorm)
@@ -322,7 +314,7 @@ julia> isapprox(std(m(xs)), 1, atol=0.1) && std(xs) != std(m(xs))
true
```
"""
-mutable struct BatchNorm{F,V,N,W}
+mutable struct BatchNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)}
λ::F # activation function
β::V # bias
γ::V # scale
@@ -352,9 +344,6 @@ function BatchNorm(chs::Int, λ=identity;
nothing, chs)
end

-@functor BatchNorm
-trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
Review comment from the PR author (Member), on the two removed lines above (lines -355 to -356):

Here PartialTrainLayer{(:β, :γ)} fixes which fields are trainable, permanently. That's OK, since β = nothing when it's not trainable, so it'll be ignored. It improves type-stability although I doubt this matters at all.
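To make that concrete, here is a minimal hypothetical sketch of how `trainable` could be derived once from the `PartialTrainLayer` type parameter; the real mechanism lives in the unshown `layers/types.jl` and may differ:

```julia
# Hypothetical sketch, written as if inside Flux where `trainable` and
# `PartialTrainLayer` are in scope. `which` is assumed to be a tuple of field
# names, e.g. (:β, :γ). With affine=false those fields are `nothing`, so they
# are listed but contribute no parameters, as the comment above notes.
trainable(layer::PartialTrainLayer{which}) where {which} =
    NamedTuple{which}(map(name -> getfield(layer, name), which))
```

For `BatchNorm <: PartialTrainLayer{(:β, :γ)}` this always returns `(β = bn.β, γ = bn.γ)`, whereas the removed method above returned that only when `hasaffine(bn)` was true, and `(;)` otherwise.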


function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
N = ndims(x)
@@ -412,7 +401,7 @@ julia> isapprox(std(y, dims=1:2), ones(1, 1, 3, 2), atol=0.2) && std(y, dims=1:2
true
```
"""
-mutable struct InstanceNorm{F,V,N,W}
+mutable struct InstanceNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)}
λ::F # activation function
β::V # bias
γ::V # scale
@@ -442,9 +431,6 @@ function InstanceNorm(chs::Int, λ=identity;
nothing, chs)
end

-@functor InstanceNorm
-trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
@assert size(x, ndims(x)-1) == l.chs
@@ -506,7 +492,7 @@ julia> isapprox(std(y[:, :, 3:4, 2]), 1, atol=0.1) && std(xs[:, :, 3:4, 2]) != s
true
```
"""
-mutable struct GroupNorm{F,V,N,W}
+mutable struct GroupNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)}
G::Int # number of groups
λ::F # activation function
β::V # bias
@@ -521,9 +507,6 @@ mutable struct GroupNorm{F,V,N,W}
chs::Int # number of channels
end

-@functor GroupNorm
-trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,