Commit 6370374

upgrade to at-layer macro, replaces at-functor
1 parent 9f9051f commit 6370374

8 files changed: +259 −88 lines

src/Flux.jl

Lines changed: 3 additions & 1 deletion
```diff
@@ -44,13 +44,15 @@ include("functor.jl")
 # Pirate error to catch a common mistake.
 Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")
 
+include("layers/show.jl")
+include("layers/macro.jl")
+
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
-include("layers/show.jl")
 
 include("loading.jl")
 
```

src/functor.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -42,6 +42,7 @@ function params!(p::Params, x, seen = IdSet())
   elseif x in seen
     nothing
   else
+    _check_new_macro(x)  # complains if you used @functor not @layer
     push!(seen, x)
     for child in trainable(x)
       params!(p, child, seen)
```
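As a side note on what this hook does in practice: `_check_new_macro` is defined in src/layers/macro.jl below and only warns. A rough sketch of the behaviour, using a hypothetical `MyOldLayer` that still uses the old macro (not part of this commit):

```julia
using Flux

struct MyOldLayer{W}
    weight::W
end

Flux.@functor MyOldLayer        # old-style declaration, still functional

m = MyOldLayer(rand(Float32, 3, 3))

# Collecting parameters walks the model via params!, which now calls
# _check_new_macro on each non-leaf node it visits, printing once per type:
#   ┌ Warning: you used @functor for this type, but should now use @layer
ps = Flux.params(m)
```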

src/layers/basic.jl

Lines changed: 9 additions & 9 deletions
```diff
@@ -46,7 +46,7 @@ end
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys, Base.firstindex
 
-@functor Chain
+@layer :expand Chain  # the :expand option opts-in to container-style pretty-printing
 
 (c::Chain)(x) = _applychain(c.layers, x)
 
@@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
   Dense(init(out, in), bias, σ)
 end
 
-@functor Dense
+@layer Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
@@ -236,7 +236,7 @@ end
 Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
 Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])
 
-@functor Scale
+@layer Scale
 
 function (a::Scale)(x::AbstractArray)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
@@ -291,7 +291,7 @@ end
 Maxout(layers...) = Maxout(layers)
 Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
-@functor Maxout
+@layer :expand Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -338,7 +338,7 @@ struct SkipConnection{T,F}
   connection::F  #user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@functor SkipConnection
+@layer SkipConnection  # should this be expand?
 
 function (skip::SkipConnection)(input)
   skip.connection(skip.layers(input), input)
@@ -408,7 +408,7 @@ struct Bilinear{F,A,B}
   end
 end
 
-@functor Bilinear
+@layer Bilinear
 
 function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
                   bias = true, init = glorot_uniform)
@@ -507,7 +507,7 @@ function Parallel(connection; kw...)
   Parallel(connection, layers)
 end
 
-@functor Parallel
+@layer :expand Parallel
 
 (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
 (m::Parallel)(xs::Tuple) = m(xs...)
@@ -628,7 +628,7 @@ end
 end
 applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
 
-@functor PairwiseFusion
+@layer :expand PairwiseFusion
 
 Base.getindex(m::PairwiseFusion, i) = m.layers[i]
 Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
@@ -676,7 +676,7 @@ struct Embedding{W}
   weight::W
 end
 
-@functor Embedding
+@layer Embedding
 
 Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
 
```
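To make the `:expand` choice above concrete: it only changes how a layer prints at the REPL, not how its parameters are collected or moved (see the docstrings added in src/layers/macro.jl below). A minimal sketch with a hypothetical container type `Duo` (not part of this commit):

```julia
using Flux

# A two-layer container written by a user rather than by Flux:
struct Duo{A,B}
    first::A
    second::B
end
(d::Duo)(x) = d.second(d.first(x))

Flux.@layer :expand Duo   # container-style: show unfolds the children, like Chain

# Plain `Flux.@layer Duo` would instead keep the compact, Dense-style one-line printing;
# parameter collection, gpu/cpu movement, etc. are identical either way.
```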
src/layers/conv.jl

Lines changed: 3 additions & 3 deletions
```diff
@@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
   init(filter..., cin÷groups, cout)
 end
 
-@functor Conv
+@layer Conv
 
 conv_dims(c::Conv, x::AbstractArray) =
   DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
@@ -307,7 +307,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
   ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
 end
 
-@functor ConvTranspose
+@layer ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   # Calculate size of "input", from ∇conv_data()'s perspective...
@@ -453,7 +453,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden
   return CrossCor(weight, bias, σ; stride, pad, dilation)
 end
 
-@functor CrossCor
+@layer CrossCor
 
 function crosscor(x, w, ddims::DenseConvDims)
   ddims = DenseConvDims(ddims, F=true)
```

src/layers/macro.jl

Lines changed: 177 additions & 0 deletions
````diff
@@ -0,0 +1,177 @@
+
+"""
+    @layer Dense
+    @layer :expand Chain
+    @layer BatchNorm trainable=(β,γ)
+    @layer Struct functor=(α,β) trainable=(β,)
+
+This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
+when you define a new layer, this tells Flux to explore inside it
+to see the parameters it trains, and also to move them to the GPU, change precision, etc.
+
+Some "keywords" allow control of the recursion:
+* If some fields look like parameters but should not be trained,
+  then `Optimisers.trainable` lets you specify fields to include, ignoring the rest.
+* We can likewise add restrictions to `Functors.functor`, but this is not yet written.
+* In fact you can provide an arbitrary keyword with this syntax, and it will
+  overload a function of that name, à la `trainable`... that might be a terrible idea.
+
+It also handles overloads of `show` for pretty printing.
+* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
+* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
+* To disable all `show` overloads, maybe we want a `:ignore` option too.
+
+(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
+"""
+macro layer(exs...)
+  out = quote end
+
+  # These functions are defined in show.jl, and each returns an expression overloading Base.show
+  type, rest... = if exs[1] == QuoteNode(:expand)
+    push!(out.args, _macro_big_show(esc(exs[2])))
+    exs[2:end]
+  elseif exs[1] == QuoteNode(:ignore)
+    exs[2:end]
+  elseif exs[1] isa QuoteNode
+    error("before the type, only accepted options are `:expand` and `:ignore`")
+  else
+    push!(out.args, _macro_layer_show(esc(exs[1])))
+    exs
+  end
+
+  # This function exists only for depwarns when you use @functor directly
+  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?
+
+  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
+  if isnothing(i)
+    push!(out.args, _macro_functor(esc(type)))
+  else
+    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
+  end
+  for j in 1:length(rest)
+    j == i && continue
+    ex = rest[j]
+    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
+    if ex.args[1] == :trainable
+      push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
+    else
+      error()
+      # @warn "defining a method for $(ex.args[1]) in your scope"  # ??
+      # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+    end
+  end
+
+  out
+end
+
+# Temporary depwarn function:
+
+function _check_new_macro(x::T) where T
+  Functors.isleaf(x) && return
+  @warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
+end
+_check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
+_check_new_macro(::NamedTuple) = nothing
+_check_new_macro(::Transpose) = nothing
+_check_new_macro(::Adjoint) = nothing
+_check_new_macro(::Ref) = nothing
+
+# @layer's code for Functors & Adapt
+# Unlike @functor, _default_functor doesn't need to eval anything
+
+function _macro_functor(type)
+  quote
+    Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
+    Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+  end
+end
+
+function _macro_functor(type, fields)
+  error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
+end
+
+function _default_functor(::Type{T}, x) where {T}
+  if @generated
+    F = fieldnames(T)
+    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
+    C = Base.typename(T).name  # constructor
+    recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
+    :((NamedTuple{$F}(($(args...),)), $recon))
+  else
+    # Getting this parameterless type takes about 2μs, every time:
+    namedtuple(x), Base.splat(Base.typename(T).wrapper)
+  end
+end
+
+function namedtuple(x::T) where T
+  F = fieldnames(T)
+  NamedTuple{F}(map(sy -> getfield(x, sy), F))
+end
+
+# @layer's code for Optimisers.trainable, and perhaps anything else,
+# with the pattern that keywords mean function names & what fields they pick.
+
+function _macro_trainable(type, fun, fields)
+  Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
+  symbols = Tuple(map(_noquotenode, fields.args))
+  quoted = map(QuoteNode, symbols)
+  gets = [:(getfield(x, $f)) for f in quoted]
+  quote
+    # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
+    Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
+  end
+end
+_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma
+
+_noquotenode(s::Symbol) = s
+_noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
+_noquotenode(ex) = error("expected a symbol, got $ex")
+
+
+
+
+
+
+# @big_show Chain
+# @big_show Parallel
+# @big_show SkipConnection
+# @big_show Recur
+# @big_show Maxout
+
+
+
+
+"""
+    @big_show MyContainer
+
+This macro lets you opt-in to Flux's fancy printing.
+
+When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
+and the printing routine will recursively unfold its children.
+This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
+
+Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
+need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
+
+# Example
+```jldoctest
+julia> struct Trio{A,B,C}; a::A; b::B; c::C end
+
+julia> Flux.@functor Trio
+
+julia> Flux.@big_show Trio
+
+julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
+Trio(
+  Dense(10 => 5, tanh),                 # 55 parameters
+  Dense(5 => 2),                        # 12 parameters
+  NNlib.softmax,
+) # Total: 4 arrays, 67 parameters, 492 bytes.
+```
+
+Note that there is no automatic method for 2-arg `show`, and thus
+something like `(tri, tri)` will print all the type parameters.
+
+However, `Chain(tri, tri)` will always use Flux's recursive printing,
+even without using this macro: `Chain` is the entry point.
+"""
````

src/layers/normalise.jl

Lines changed: 11 additions & 11 deletions
```diff
@@ -103,8 +103,8 @@ function Dropout(p; dims=:, rng = default_rng_value())
   Dropout(p, dims, nothing, rng)
 end
 
-@functor Dropout
-trainable(a::Dropout) = (;)
+@layer Dropout trainable=()
+# trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
   _isactive(a) || return x
@@ -158,8 +158,8 @@ end
 AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
 AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
 
-@functor AlphaDropout
-trainable(a::AlphaDropout) = (;)
+@layer AlphaDropout trainable=()
+# trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x
@@ -224,7 +224,7 @@ end
 LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
 LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
 
-@functor LayerNorm
+@layer LayerNorm
 
 (a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
 
@@ -352,8 +352,8 @@ function BatchNorm(chs::Int, λ=identity;
             nothing, chs)
 end
 
-@functor BatchNorm
-trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
+@layer BatchNorm trainable=(β,γ)
+# trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
 
 function (BN::BatchNorm)(x)
   @assert size(x, ndims(x)-1) == BN.chs
@@ -442,8 +442,8 @@ function InstanceNorm(chs::Int, λ=identity;
             nothing, chs)
 end
 
-@functor InstanceNorm
-trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
+@layer InstanceNorm trainable=(β,γ)
+# trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
 
 function (l::InstanceNorm)(x)
   @assert ndims(x) > 2
@@ -521,8 +521,8 @@ mutable struct GroupNorm{F,V,N,W}
   chs::Int  # number of channels
 end
 
-@functor GroupNorm
-trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
+@layer GroupNorm trainable=(β,γ)
+# trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
```
0 commit comments

Comments
 (0)