From 7ca3449bf2a4768d6313f1d714380e6ff50e57c9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 6 Apr 2022 19:46:52 -0400 Subject: [PATCH 1/7] add big_show macro --- src/layers/show.jl | 63 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 0ae14dd9ee..fe7fb537a1 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,18 +1,59 @@ +""" + @big_show MyContainer -for T in [ - :Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL - _big_show(io, x) - elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix - _layer_show(io, x) - else - show(io, x) +This macro lets you opt-in to Flux's fancy printing. + +When `model::MyContainer` is returned at the REPL it will be treated like `Chain`, +and the printing routine will recursively unfold its children. +This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`. + +Custom layers which do not contain other layers (more like `Dense` than like `Chain`) +need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`. + +# Example +```jldoctest +julia> struct Trio{A,B,C}; a::A; b::B; c::C end + +julia> Flux.@functor Trio + +julia> Flux.@big_show Trio + +julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax) +Trio( + Dense(10 => 5, tanh), # 55 parameters + Dense(5 => 2), # 12 parameters + NNlib.softmax, +) # Total: 4 arrays, 67 parameters, 492 bytes. +``` + +Note that there is no automatic method for 2-arg `show`, and thus +something like `(tri, tri)` will print all the type parameters. + +However, `Chain(tri, tri)` will always use Flux's recursive printing, +even without using this macro: `Chain` is the entry point. +""" +macro big_show(ex) + ex isa Symbol || error("usage is `Flux.@big_show Chain`") + eex = esc(ex) + quote + function Base.show(io::IO, m::MIME"text/plain", x::$eex) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end + end end - end end +@big_show Chain +@big_show Parallel +@big_show SkipConnection +@big_show Recur +@big_show Maxout + function _big_show(io::IO, obj, indent::Int=0, name=nothing) pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") children = _show_children(obj) From 3ab088b4fb55fb683246785fa87b99f60cef6bbb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 23 Aug 2022 16:49:06 -0400 Subject: [PATCH 2/7] upgrade to at-layer macro, replaces at-functor fixup tidy up, add NEWS review suggestions macro docstring, incl. hcat(3.3) --- NEWS.md | 5 ++ docs/src/models/advanced.md | 32 +++++-- docs/src/models/basics.md | 9 +- src/Flux.jl | 5 +- src/functor.jl | 1 + src/layers/basic.jl | 18 ++-- src/layers/conv.jl | 6 +- src/layers/macro.jl | 173 ++++++++++++++++++++++++++++++++++++ src/layers/normalise.jl | 19 ++-- src/layers/recurrent.jl | 11 ++- src/layers/show.jl | 115 ++++++++++-------------- test/layers/macro.jl | 45 ++++++++++ test/runtests.jl | 1 + test/utils.jl | 2 +- 14 files changed, 331 insertions(+), 111 deletions(-) create mode 100644 src/layers/macro.jl create mode 100644 test/layers/macro.jl diff --git a/NEWS.md b/NEWS.md index ac8883a091..9eb544a6f6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,10 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.13 +* New macro `Flux.@layer` which should be used in place of `@functor`. + This also adds `show` methods for pretty printing. + ## v0.14.0 (July 2023) * Flux now requires julia v1.9 or later. * CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`. @@ -51,6 +55,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl ## v0.13.6 * Use the package [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl) instead of having the same code here. +* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index b7161b8c59..531d176cb5 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -18,8 +18,8 @@ function (m::CustomModel)(x) return m.chain(x) + x end -# Call @functor to allow for training. Described below in more detail. -Flux.@functor CustomModel +# Call @layer to allow for training. Described below in more detail. +Flux.@layer CustomModel ``` You can then use the model like: @@ -39,7 +39,7 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function: ```julia-repl -julia> Flux.@functor Affine +julia> @layer Affine julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]) Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) @@ -47,7 +47,7 @@ Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) julia> Flux.params(a) # default behavior Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]]) -julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name +julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name julia> Flux.params(a) Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]]) @@ -67,7 +67,21 @@ julia> Flux.params(Affine(true, [10, 11, 12.0])) Params([]) ``` -It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired). +The exact same method of `trainable` can also be defined using the macro, for convenience: + +```julia +Flux.@layer Affine trainable=(W,) +``` + +There is a second, more severe, kind of restriction possible: + +``` +Flux.@layer Affine children=(W,) +``` + +This is equivalent to `Functors.@functor Affine (W,)`. It means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This is not usually recommended. + This is generally not recommended. It requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. + ## Freezing Layer Parameters @@ -135,9 +149,9 @@ Join(combine, paths...) = Join(combine, paths) ``` Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field. -The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. +The next step is to use [`Functors.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. ```julia -Flux.@functor Join +Flux.@layer Join ``` Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results. @@ -194,7 +208,7 @@ model(xs) Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs. -We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass. +We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass. ```julia using Flux using CUDA @@ -206,7 +220,7 @@ end Split(paths...) = Split(paths) -Flux.@functor Split +Flux.@layer Split (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths) ``` diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index ca95dc747d..fb0f2d5488 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -257,8 +257,8 @@ m(5) # => 26 There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@functor`](@ref Functors.@functor) macro: -``` -Flux.@functor Affine +```julia +Flux.@layer Affine ``` Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias): @@ -272,3 +272,8 @@ end Affine(3 => 1, bias=false, init=ones) |> gpu ``` + +```@docs +Flux.@layer +Flux.create_bias +``` diff --git a/src/Flux.jl b/src/Flux.jl index d3ca611dbd..5675f7c10f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,6 +9,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils +const stack = MLUtils.stack # now exported by Base import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions using Optimisers: freeze!, thaw!, adjust! using Random: default_rng @@ -69,6 +70,9 @@ include("functor.jl") # Pirate error to catch a common mistake. Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.") +include("layers/show.jl") +include("layers/macro.jl") + include("layers/stateless.jl") include("layers/basic.jl") include("layers/conv.jl") @@ -76,7 +80,6 @@ include("layers/recurrent.jl") include("layers/normalise.jl") include("layers/upsample.jl") include("layers/attention.jl") -include("layers/show.jl") include("loading.jl") diff --git a/src/functor.jl b/src/functor.jl index 2c8f3360db..e0329ae313 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -81,6 +81,7 @@ function params!(p::Params, x, seen = IdSet()) elseif x in seen nothing else + _check_new_macro(x) # complains if you used @functor not @layer push!(seen, x) for child in trainable(x) params!(p, child, seen) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index b7027f5007..018b19b31d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -46,7 +46,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys, Base.firstindex -@functor Chain +@layer :expand Chain # the + opts-in to container-style pretty-printing (c::Chain)(x) = _applychain(c.layers, x) @@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity; Dense(init(out, in), bias, σ) end -@functor Dense +@layer Dense function (a::Dense)(x::AbstractVecOrMat) _size_check(a, x, 1 => size(a.weight, 2)) @@ -251,7 +251,7 @@ end Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act) Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end]) -@functor Scale +@layer Scale function (a::Scale)(x::AbstractArray) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc @@ -306,7 +306,7 @@ end Maxout(layers...) = Maxout(layers) Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...) -@functor Maxout +@layer :expand Maxout function (mo::Maxout)(input::AbstractArray) # Perhaps surprisingly, pairwise max broadcast is often faster, @@ -353,7 +353,7 @@ struct SkipConnection{T,F} connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b end -@functor SkipConnection +@layer :expand SkipConnection function (skip::SkipConnection)(input) skip.connection(skip.layers(input), input) @@ -423,7 +423,7 @@ struct Bilinear{F,A,B} end end -@functor Bilinear +@layer Bilinear function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity; bias = true, init = glorot_uniform) @@ -522,7 +522,7 @@ function Parallel(connection; kw...) Parallel(connection, layers) end -@functor Parallel +@layer :expand Parallel (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) (m::Parallel)(xs::Tuple) = m(xs...) @@ -643,7 +643,7 @@ end end applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x) -@functor PairwiseFusion +@layer :expand PairwiseFusion Base.getindex(m::PairwiseFusion, i) = m.layers[i] Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i]) @@ -701,7 +701,7 @@ struct Embedding{W<:AbstractMatrix} weight::W end -@functor Embedding +@layer Embedding Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ca275d4a16..4e6044dcfb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init(filter..., cin÷groups, cout) end -@functor Conv +@layer Conv conv_dims(c::Conv, x::AbstractArray) = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) @@ -309,7 +309,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = ConvTranspose(weight, bias, σ; stride, pad, dilation, groups) end -@functor ConvTranspose +@layer ConvTranspose function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) # Calculate size of "input", from ∇conv_data()'s perspective... @@ -460,7 +460,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden return CrossCor(weight, bias, σ; stride, pad, dilation) end -@functor CrossCor +@layer CrossCor function crosscor(x, w, ddims::DenseConvDims) ddims = DenseConvDims(ddims, F=true) diff --git a/src/layers/macro.jl b/src/layers/macro.jl new file mode 100644 index 0000000000..5bf33f137d --- /dev/null +++ b/src/layers/macro.jl @@ -0,0 +1,173 @@ + +""" + @layer Dense + @layer :expand Chain + @layer BatchNorm trainable=(β,γ) + @layer Struct children=(α,β) trainable=(β,) + +This macro replaces most uses of `@functor`. Its basic purpose is the same: +When you define a new layer, this tells Flux to explore inside it +to see the parameters it trains, and also to move them to the GPU, change precision, etc. +Like `@functor`, this assumes your struct has the default constructor, to enable re-building. + +Some keywords allow you to limit this exploration, instead of visiting all `fieldnames(T)`. +Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. +* If some fields look like parameters but should not be trained, + then `trainable` lets you specify which fields to include, while the rest are ignored. +* You can likewise add restrictions to Functors's `children` (although this is seldom a good idea), + equivalent to `@functor Struct (α,β)`. Any `trainable` limitation must then be a subset of `children`. + +The macro also handles overloads of `show` for pretty printing. +* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. +* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents. +* To disable all `show` overloads, there is an `:ignore` option too. + +(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.) + +Note that re-running the macro with different options may not overwrite all methods, you will need to restart. + +# Example +```jldoctest +julia> struct Trio; a; b; c end + +julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4)) +Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) + +julia> Flux.destructure(tri) # parameters are not yet visible to Flux +(Bool[], Restructure(Trio, ..., 0)) + +julia> Flux.@layer :expand Trio + +julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too +([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4)) + +julia> tri # and layer is printed like Chain +Trio( + Dense(2 => 1, tanh), # 3 parameters + Dense(1 => 1; bias=false), # 1 parameters + Dropout(0.4), +) # Total: 3 arrays, 4 parameters, 224 bytes. +``` + +""" +macro layer(exs...) + out = quote end + + # These functions are defined in show.jl, and each return an expression overloading Base.show + type, rest... = if exs[1] == QuoteNode(:expand) + push!(out.args, _macro_big_show(esc(exs[2]))) + exs[2:end] + elseif exs[1] == QuoteNode(:ignore) + exs[2:end] + elseif exs[1] isa QuoteNode + error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)") + else + push!(out.args, _macro_layer_show(esc(exs[1]))) + exs + end + + # This function exists only for depwarns when you use @functor directly + push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) + + i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest) + if isnothing(i) # then default like @functor Layer + push!(out.args, _macro_functor(esc(type))) + else + push!(out.args, _macro_functor(esc(type), rest[i].args[2])) + end + for j in 1:length(rest) + j == i && continue + ex = rest[j] + Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects here `keyword = (fields...,)`, got $ex") + + name = if ex.args[1] == :trainable + :(Optimisers.trainable) + elseif ex.args[1] == :functor + error("Can't use `functor=(...)` as a keyword to `@layer`. Use `childen=(...)` to define a method for `functor`.") + else + error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.") + # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1 + # esc(ex.args[1]) + end + push!(out.args, _macro_trainable(esc(type), name, ex.args[2])) + end + + out +end + +# Temporary depwarn function, called within `params`, is also called by `show`. + +function _check_new_macro(x::T) where T + Functors.isleaf(x) && return + Base.depwarn("This type should probably now use `Flux.@layer` instead of `@functor`: $T", Symbol("@functor")) +end +_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users +_check_new_macro(::NamedTuple) = nothing +_check_new_macro(::AbstractArray) = nothing +_check_new_macro(::Ref) = nothing + +# @layer's code for Functors & Adapt +# Unlike @functor, _default_functor doesn't need to eval anything + +function _macro_functor(type) + quote + Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x) + Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) + end +end + +function _macro_functor(type, fields) + Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") + symbols = Tuple(map(_noquotenode, fields.args)) + quote + Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols)) + Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) + end +end +_macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,))) # lets you forget a comma + +function _default_functor(::Type{T}, x) where {T} + if @generated + F = fieldnames(T) + args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) + C = Base.typename(T).wrapper # constructor + recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) + :((NamedTuple{$F}(($(args...),)), $recon)) + else + # Getting this parameterless type takes about 2μs, every time: + spl = VERSION > v"1.9-" ? Splat : Base.splat + namedtuple(x), spl(Base.typename(T).wrapper) + end +end + +function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which} + if false + # TODO write the @generated version + else + remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...) + NamedTuple{which}(map(s -> getfield(x, s), which)), remake + end +end + +function namedtuple(x::T) where T + F = fieldnames(T) + NamedTuple{F}(map(sy -> getfield(x, sy), F)) +end + +# @layer's code for Optimisers.trainable, and perhaps anything else, +# with the pattern that keywords mean function names & what fields they pick. + +function _macro_trainable(type, fun, fields) + Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") + symbols = Tuple(map(_noquotenode, fields.args)) + quoted = map(QuoteNode, symbols) + gets = [:(getfield(x, $f)) for f in quoted] + quote + $fun(x::$type) = NamedTuple{$symbols}(($(gets...),)) + end +end +_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,))) # lets you forget a comma + +_noquotenode(s::Symbol) = s +_noquotenode(q::QuoteNode) = q.value # lets you write trainable=(:x,:y) instead of (x,y) +_noquotenode(ex) = error("expected a symbol here, as a field name, but got $ex") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1c8fbff5a1..c0a86c8796 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -78,8 +78,7 @@ function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = d Dropout(p, dims, active, rng) end -@functor Dropout -trainable(a::Dropout) = (;) +@layer Dropout trainable=() (a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims) @@ -131,8 +130,7 @@ function AlphaDropout(p; rng = default_rng(), active::Union{Bool,Nothing} = noth AlphaDropout(p, active, rng) end -@functor AlphaDropout -trainable(a::AlphaDropout) = (;) +@layer AlphaDropout trainable=() function (a::AlphaDropout)(x::AbstractArray{T}) where T _isactive(a, x) || return x @@ -151,6 +149,8 @@ end testmode!(m::AlphaDropout, mode=true) = (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m) +Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")") + """ LayerNorm(size..., λ=identity; affine=true, eps=1f-5) @@ -199,7 +199,7 @@ end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) -@functor LayerNorm +@layer LayerNorm function (a::LayerNorm)(x::AbstractArray) ChainRulesCore.@ignore_derivatives if a.diag isa Scale @@ -343,8 +343,7 @@ function BatchNorm(chs::Int, λ=identity; active, chs) end -@functor BatchNorm -trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) +@layer BatchNorm trainable=(β,γ) function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(BN, x, N-1 => BN.chs) @@ -437,8 +436,7 @@ function InstanceNorm(chs::Int, λ=identity; active, chs) end -@functor InstanceNorm -trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;) +@layer InstanceNorm trainable=(β,γ) function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(l, x, N-1 => l.chs) @@ -517,8 +515,7 @@ mutable struct GroupNorm{F,V,N} chs::Int # number of channels end -@functor GroupNorm -trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;) +@layer GroupNorm trainable=(β,γ) function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 375ff43d52..f55ebb1741 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -135,8 +135,7 @@ function (m::Recur)(x) return y end -@functor Recur -trainable(a::Recur) = (; cell = a.cell) +@layer :expand Recur trainable=(cell,) Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -209,7 +208,7 @@ function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where return h, reshape_cell_output(h, x) end -@functor RNNCell +@layer RNNCell # state0 is trainable, see issue 807 about this. function Base.show(io::IO, l::RNNCell) print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)) @@ -318,7 +317,7 @@ function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractV return (h′, c′), reshape_cell_output(h′, x) end -@functor LSTMCell +@layer LSTMCell Base.show(io::IO, l::LSTMCell) = print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")") @@ -391,7 +390,7 @@ function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where { return h′, reshape_cell_output(h′, x) end -@functor GRUCell +@layer GRUCell Base.show(io::IO, l::GRUCell) = print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") @@ -461,7 +460,7 @@ function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) wh return h′, reshape_cell_output(h′, x) end -@functor GRUv3Cell +@layer GRUv3Cell Base.show(io::IO, l::GRUv3Cell) = print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") diff --git a/src/layers/show.jl b/src/layers/show.jl index fe7fb537a1..09190ec780 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,63 +1,30 @@ -""" - @big_show MyContainer - -This macro lets you opt-in to Flux's fancy printing. - -When `model::MyContainer` is returned at the REPL it will be treated like `Chain`, -and the printing routine will recursively unfold its children. -This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`. - -Custom layers which do not contain other layers (more like `Dense` than like `Chain`) -need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`. - -# Example -```jldoctest -julia> struct Trio{A,B,C}; a::A; b::B; c::C end - -julia> Flux.@functor Trio - -julia> Flux.@big_show Trio - -julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax) -Trio( - Dense(10 => 5, tanh), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 492 bytes. -``` - -Note that there is no automatic method for 2-arg `show`, and thus -something like `(tri, tri)` will print all the type parameters. - -However, `Chain(tri, tri)` will always use Flux's recursive printing, -even without using this macro: `Chain` is the entry point. -""" -macro big_show(ex) - ex isa Symbol || error("usage is `Flux.@big_show Chain`") - eex = esc(ex) - quote - function Base.show(io::IO, m::MIME"text/plain", x::$eex) - if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL - _big_show(io, x) - elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix - _layer_show(io, x) - else - show(io, x) - end - end +@nospecialize # just for this file, for startup time + +# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression: +function _macro_big_show(ex) + quote + # Entry point: + function Base.show(io::IO, m::MIME"text/plain", x::$ex) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end end -end -@big_show Chain -@big_show Parallel -@big_show SkipConnection -@big_show Recur -@big_show Maxout + # Don't show Chain(Tuple(...)), always splat that. And ignore Recur's non-trainable state: + Flux._show_children(x::$ex) = _flat_children(trainable(x)) + end +end function _big_show(io::IO, obj, indent::Int=0, name=nothing) pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") children = _show_children(obj) if all(_show_leaflike, children) + # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids, + # but once all layers use @layer, they stop the recursion by defining a method for _big_show. _layer_show(io, obj, indent, name) else println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre) @@ -90,25 +57,33 @@ _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: # note the covariance of tuple, using <:T causes warning or error _show_leaflike(::Tuple{Vararg{Number}}) = true # e.g. stride of Conv _show_leaflike(::Tuple{Vararg{AbstractArray}}) = true # e.g. parameters of LSTMcell -_show_leaflike(::Scale) = true # appears inside LayerNorm _show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays +_show_leaflike(::Scale) = true # appears inside LayerNorm -_show_children(x) = trainable(x) # except for layers which hide their Tuple: -_show_children(c::Chain) = c.layers -_show_children(m::Maxout) = m.layers -_show_children(p::Parallel) = (p.connection, p.layers...) -_show_children(f::PairwiseFusion) = (f.connection, f.layers...) - -for T in [ - :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag, - :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if !get(io, :compact, false) - _layer_show(io, x) - else - show(io, x) +_show_children(x) = trainable(x) +# This used to have methods for Chain, Maxout, Parallel, PairwiseFusion. Now @layer instead +# writes a method to use this function. It flattens the Tuple within Chain etc. +# (The remaining special cases are for printing of layer names when a NamedTuple, above.) +function _flat_children(x) + alpha = map(f -> getfield(x, f), fieldnames(typeof(x))) + beta = map(y -> y isa Union{Tuple, NamedTuple} ? y : (y,), alpha) + gamma = ((beta...)...,) +end + +# This is called by @layer, on layers which should be treated like Dense, and returns an expression: +function _macro_layer_show(ex) + quote + # Entry point: + function Base.show(io::IO, m::MIME"text/plain", x::$ex) + if !get(io, :compact, false) + _layer_show(io, x) + else + show(io, x) + end end + + # Exit from _big_show recursion: + Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name) end end @@ -167,6 +142,8 @@ function _nan_show(io::IO, x) end end +@specialize # un-does @nospecialze at the top of this file + _any(f, xs::AbstractArray{<:Number}) = any(f, xs) # _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) _any(f, xs) = any(x -> _any(f, x), xs) diff --git a/test/layers/macro.jl b/test/layers/macro.jl new file mode 100644 index 0000000000..1361a895f4 --- /dev/null +++ b/test/layers/macro.jl @@ -0,0 +1,45 @@ +using Flux, Functors, Optimisers + +module MacroTest + using Flux: @layer + + struct Duo{T,S}; x::T; y::S; end + @layer :expand Duo + + struct Trio; a; b; c end + # @layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget + @layer Trio trainable=(a,b) # defining a method for test is made an error, for now + + struct TwoThirds; a; b; c; end +end + +@testset "@layer macro" begin + @test !isdefined(MacroTest, :Flux) # That's why the module, to check scope + + m2 = MacroTest.Duo(Dense(2=>2), Chain(Flux.Scale(2), Dropout(0.2))) + + @test Functors.children(m2) isa NamedTuple{(:x, :y)} + @test length(Optimisers.destructure(m2)[1]) == 10 + + m3 = MacroTest.Trio([1.0], [2.0], [3.0]) + + @test Functors.children(m3) isa NamedTuple{(:a, :b, :c)} + @test fmap(zero, m3) isa MacroTest.Trio + + @test Optimisers.trainable(m3) isa NamedTuple{(:a, :b)} + @test Optimisers.destructure(m3)[1] == [1, 2] + + # @test MacroTest.test(m3) == (c = [3.0],) # removed, for now + + m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6]) + # Check that we can use the macro with a qualified type name, outside the defining module: + Flux.@layer :expand MacroTest.TwoThirds children=(:a,:c) trainable=(:a) # documented as (a,c) but allow quotes + + @test Functors.children(m23) == (a = [1 2], c = [5 6]) + m23re = Functors.functor(m23)[2]((a = [10 20], c = [50 60])) + @test m23re isa MacroTest.TwoThirds + @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) + + @test Optimisers.trainable(m23) == (a = [1 2],) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 94e0c466e6..8dca6becdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,7 @@ Random.seed!(0) include("layers/conv.jl") include("layers/upsample.jl") include("layers/show.jl") + include("layers/macro.jl") end @testset "outputsize" begin diff --git a/test/utils.jl b/test/utils.jl index 620a4d40b4..e175eb1f5b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -616,7 +616,7 @@ end a::A b::A end - Flux.@functor Model + Flux.@layer Model (m::Model)(x) = m.a(x) .+ m.b(x) d = Dense(1, 1) From 35ade0ac913dfc8204855d1670d08e93e2eee3ed Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Feb 2024 10:14:55 -0500 Subject: [PATCH 3/7] minor fixes after rebase --- NEWS.md | 3 +++ src/layers/macro.jl | 8 +++++--- src/layers/show.jl | 1 - 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/NEWS.md b/NEWS.md index 9eb544a6f6..68d36fdc34 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,9 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl * New macro `Flux.@layer` which should be used in place of `@functor`. This also adds `show` methods for pretty printing. +## v0.14.12 +* New `SignDecay` optimiser, like `` WeightNorm` but for L1 norm. + ## v0.14.0 (July 2023) * Flux now requires julia v1.9 or later. * CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`. diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 5bf33f137d..02e5ef4540 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -131,18 +131,20 @@ function _default_functor(::Type{T}, x) where {T} F = fieldnames(T) args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) C = Base.typename(T).wrapper # constructor - recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) + # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) + recon = :(Base.splat($C)) :((NamedTuple{$F}(($(args...),)), $recon)) else # Getting this parameterless type takes about 2μs, every time: - spl = VERSION > v"1.9-" ? Splat : Base.splat + # spl = VERSION > v"1.9-" ? Splat : Base.splat + spl = Base.splat namedtuple(x), spl(Base.typename(T).wrapper) end end function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which} if false - # TODO write the @generated version + # TODO write the @generated version. Or decide we don't care, or should forbid this? else remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...) NamedTuple{which}(map(s -> getfield(x, s), which)), remake diff --git a/src/layers/show.jl b/src/layers/show.jl index 09190ec780..a03ddf3754 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -58,7 +58,6 @@ _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: _show_leaflike(::Tuple{Vararg{Number}}) = true # e.g. stride of Conv _show_leaflike(::Tuple{Vararg{AbstractArray}}) = true # e.g. parameters of LSTMcell _show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays -_show_leaflike(::Scale) = true # appears inside LayerNorm _show_children(x) = trainable(x) # This used to have methods for Chain, Maxout, Parallel, PairwiseFusion. Now @layer instead From 29e0d689b5da38c37b8e92a055f0d34d835c3819 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 1 Mar 2024 20:45:59 -0500 Subject: [PATCH 4/7] layer MultiHeadAttention, and show methods for this --- src/layers/attention.jl | 45 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 3701be2bb0..e058088156 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2} out_proj::P2 end -@functor MultiHeadAttention +@layer MultiHeadAttention function MultiHeadAttention(dims; nheads::Int = 8, @@ -83,8 +83,8 @@ function MultiHeadAttention(dims; dropout_prob = 0.0) dims = normalize_mha_dims(dims) - @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads" - @assert dims.v % nheads == 0 "v_dim should be divisible by nheads" + dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)") + dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)") q_proj = Dense(dims.q_in => dims.qk; bias, init) k_proj = Dense(dims.k_in => dims.qk; bias, init) v_proj = Dense(dims.v_in => dims.v; bias, init) @@ -131,3 +131,42 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, # [α] = [kv_len, q_len, nheads, batch_size] return x, α end + +function Base.show(io::IO, mha::MultiHeadAttention) + qk, q_in = size(mha.q_proj.weight) + qk, k_in = size(mha.k_proj.weight) + v, v_in = size(mha.v_proj.weight) + out, v = size(mha.out_proj.weight) + # @show q_in, k_in, v_in, qk, v, out + print(io, "MultiHeadAttention(") + if q_in == k_in == v_in == qk == v == out + print(io, q_in) + elseif q_in == k_in == v_in && qk == v + print(io, q_in, " => ", qk, " => ", out) + elseif q_in == k_in == v_in + print(io, q_in, " => (", qk, ", ", v,") => ", out) + else + print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out) + end + print(io, "; nheads=", mha.nheads) + if mha.q_proj.bias === true + print(io, ", bias=true") + end + if mha.attn_drop.p != 0 + print(io, ", dropout_prob=", mha.attn_drop.p) # can't we rename this? + end + print(io, ")") +end + +Base.show(io::IO, ::MIME"text/plain", mha::MultiHeadAttention) = show(io, mha) + +#= + +# Test cases: + +MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => 6 => 8; nheads=1) +MultiHeadAttention(8; bias=true) + +=# From 2be90998d33b1a582371105afc837d1c8840c400 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 1 Mar 2024 21:44:43 -0500 Subject: [PATCH 5/7] =?UTF-8?q?remove=20children=3D(=CE=B1,=CE=B2)=20keywo?= =?UTF-8?q?rd?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layers/macro.jl | 25 +++---------------------- test/layers/macro.jl | 8 +++++--- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 02e5ef4540..9e770add87 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -3,19 +3,16 @@ @layer Dense @layer :expand Chain @layer BatchNorm trainable=(β,γ) - @layer Struct children=(α,β) trainable=(β,) This macro replaces most uses of `@functor`. Its basic purpose is the same: When you define a new layer, this tells Flux to explore inside it to see the parameters it trains, and also to move them to the GPU, change precision, etc. Like `@functor`, this assumes your struct has the default constructor, to enable re-building. -Some keywords allow you to limit this exploration, instead of visiting all `fieldnames(T)`. +The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. * If some fields look like parameters but should not be trained, then `trainable` lets you specify which fields to include, while the rest are ignored. -* You can likewise add restrictions to Functors's `children` (although this is seldom a good idea), - equivalent to `@functor Struct (α,β)`. Any `trainable` limitation must then be a subset of `children`. The macro also handles overloads of `show` for pretty printing. * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. @@ -69,21 +66,14 @@ macro layer(exs...) # This function exists only for depwarns when you use @functor directly push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) - i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest) - if isnothing(i) # then default like @functor Layer - push!(out.args, _macro_functor(esc(type))) - else - push!(out.args, _macro_functor(esc(type), rest[i].args[2])) - end + push!(out.args, _macro_functor(esc(type))) + for j in 1:length(rest) - j == i && continue ex = rest[j] Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects here `keyword = (fields...,)`, got $ex") name = if ex.args[1] == :trainable :(Optimisers.trainable) - elseif ex.args[1] == :functor - error("Can't use `functor=(...)` as a keyword to `@layer`. Use `childen=(...)` to define a method for `functor`.") else error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.") # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1 @@ -141,15 +131,6 @@ function _default_functor(::Type{T}, x) where {T} namedtuple(x), spl(Base.typename(T).wrapper) end end - -function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which} - if false - # TODO write the @generated version. Or decide we don't care, or should forbid this? - else - remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...) - NamedTuple{which}(map(s -> getfield(x, s), which)), remake - end -end function namedtuple(x::T) where T F = fieldnames(T) diff --git a/test/layers/macro.jl b/test/layers/macro.jl index 1361a895f4..e41d5a2240 100644 --- a/test/layers/macro.jl +++ b/test/layers/macro.jl @@ -33,13 +33,15 @@ end m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6]) # Check that we can use the macro with a qualified type name, outside the defining module: - Flux.@layer :expand MacroTest.TwoThirds children=(:a,:c) trainable=(:a) # documented as (a,c) but allow quotes + Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes - @test Functors.children(m23) == (a = [1 2], c = [5 6]) - m23re = Functors.functor(m23)[2]((a = [10 20], c = [50 60])) + m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60])) @test m23re isa MacroTest.TwoThirds @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) @test Optimisers.trainable(m23) == (a = [1 2],) + + @test_throws LoadError @eval Flux.@layer :zzz MacroTest.TwoThirds + @test_throws LoadError @eval Flux.@layer MacroTest.TwoThirds chidren=(a, b) end From fe7208168e47faeabfada4991103b62279e33f7d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 2 Mar 2024 14:39:36 -0500 Subject: [PATCH 6/7] fixup attention --- src/layers/attention.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index e058088156..d4a33283d9 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -83,8 +83,8 @@ function MultiHeadAttention(dims; dropout_prob = 0.0) dims = normalize_mha_dims(dims) - dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)") - dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)") + dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)")) + dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)")) q_proj = Dense(dims.q_in => dims.qk; bias, init) k_proj = Dense(dims.k_in => dims.qk; bias, init) v_proj = Dense(dims.v_in => dims.v; bias, init) @@ -149,7 +149,7 @@ function Base.show(io::IO, mha::MultiHeadAttention) print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out) end print(io, "; nheads=", mha.nheads) - if mha.q_proj.bias === true + if mha.q_proj.bias !== false print(io, ", bias=true") end if mha.attn_drop.p != 0 @@ -158,11 +158,10 @@ function Base.show(io::IO, mha::MultiHeadAttention) print(io, ")") end -Base.show(io::IO, ::MIME"text/plain", mha::MultiHeadAttention) = show(io, mha) #= -# Test cases: +# Test cases for printing: MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1) MultiHeadAttention(3 => (6, 7) => 8; nheads=1) From 63e10f6074bd793f4d5321d45bc08125a484b911 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 2 Mar 2024 15:09:31 -0500 Subject: [PATCH 7/7] rm note --- docs/src/models/advanced.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 531d176cb5..ab045d96be 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -73,14 +73,7 @@ The exact same method of `trainable` can also be defined using the macro, for co Flux.@layer Affine trainable=(W,) ``` -There is a second, more severe, kind of restriction possible: - -``` -Flux.@layer Affine children=(W,) -``` - -This is equivalent to `Functors.@functor Affine (W,)`. It means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This is not usually recommended. - This is generally not recommended. It requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. +There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. ## Freezing Layer Parameters