Commit 85252bf

fixup
1 parent: 6370374

9 files changed, +98 −98 lines changed

docs/src/models/basics.md (1 addition, 1 deletion)

@@ -216,7 +216,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling
 
 ```julia
-Flux.@functor Affine
+Flux.@layer Affine
 ```
 
 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
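For context, a minimal sketch of what that one-line change affects, assuming an `Affine` struct roughly like the one defined earlier on that docs page (the field names `W`, `b` and the constructor here are illustrative, not quoted from the page):

```julia
using Flux

struct Affine
    W
    b
end

Affine(in::Integer, out::Integer) = Affine(randn(Float32, out, in), zeros(Float32, out))

(m::Affine)(x) = m.W * x .+ m.b

Flux.@layer Affine    # previously spelled Flux.@functor Affine

a = Affine(3, 2)
flat, re = Flux.destructure(a)   # parameters are now visible to Flux
length(flat)                     # 8 = 6 weights + 2 biases
```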

src/Flux.jl (1 addition, 0 deletions)

@@ -7,6 +7,7 @@ using MacroTools: @forward
 
 @reexport using NNlib
 using MLUtils
+const stack = MLUtils.stack # now exported by Base
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
 
 using Zygote, ChainRulesCore
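The new `const` pins the name because Julia 1.9 also exports `Base.stack`; a small illustration of the behaviour being kept (qualified call used here so it stays unambiguous in any session):

```julia
using MLUtils

# In Julia >= 1.9 both Base and MLUtils export `stack`, so Flux binds the name
# to the MLUtils version; a qualified call shows the semantics it keeps:
MLUtils.stack([[1, 2], [3, 4]]; dims=1)   # 2×2 Matrix: [1 2; 3 4]
```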

src/layers/basic.jl (1 addition, 1 deletion)

@@ -338,7 +338,7 @@ struct SkipConnection{T,F}
   connection::F  #user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@layer SkipConnection # should this be expand?
+@layer :expand SkipConnection
 
 function (skip::SkipConnection)(input)
   skip.connection(skip.layers(input), input)
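Switching `SkipConnection` to `:expand` only changes how it prints; a rough sketch of the difference (output formatting is approximate):

```julia
using Flux

sc = SkipConnection(Dense(2 => 2), +)

# With plain `@layer SkipConnection`, the REPL display stays on one line, roughly:
#   SkipConnection(Dense(2 => 2), +)
# With `@layer :expand SkipConnection`, the 3-arg `show` unfolds the contents, roughly:
#   SkipConnection(
#     Dense(2 => 2),                      # 6 parameters
#     +,
#   )
display(sc)
```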

src/layers/macro.jl (48 additions, 71 deletions)

@@ -3,25 +3,50 @@
     @layer Dense
     @layer :expand Chain
     @layer BatchNorm trainable=(β,γ)
-    @layer Struct functor=(α,β) trainable=(β,)
+    @layer Struct children=(α,β) trainable=(β,)
 
 This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
 When you define a new layer, this tells Flux to explore inside it
 to see the parameters it trains, and also to move them to the GPU, change precision, etc.
 
 Some "keywords" allow control of the recursion:
 * If some fields look like parameters but should not be trained,
-  then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
-* We can likewise add restructions to `Functors.functor`, but not yet written.
-* In fact you can provide an arbitrary keyword with this syntax, and it will
-  overload this function alla `trainable`... that might be a terrible idea.
+  then `trainable` lets you specify fields to include, and ignore the rest.
+* We can likewise add restrictions to Functors's `children`,
+  but this is not yet written (as this is seldom a good idea).
 
 It also handles overloads of `show` for pretty printing.
 * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
 * If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
-* To disable all `show` overloads, maybe we want a `:ignore` option too.
+* To disable all `show` overloads, there is an `:ignore` option too.
 
 (You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
+
+Note that re-running the macro with different options does not overwrite all methods; you will need to restart.
+
+# Example
+```jldoctest
+julia> struct Trio; a; b; c end
+
+julia> tri = Trio(Dense([1.1 2.2],), Dense([3.3;;], false), Dropout(0.4))
+Trio(Dense(1 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
+
+julia> Flux.destructure(tri)  # parameters not visible to Flux
+(Bool[], Restructure(Trio, ..., 0))
+
+julia> Flux.@layer :expand Trio
+
+julia> Flux.destructure(tri)  # now gpu, train!, etc will see inside too
+([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
+
+julia> tri
+Trio(
+  Dense(2 => 1),                        # 3 parameters
+  Dense(1 => 1; bias=false),            # 1 parameters
+  Dropout(0.4),
+)                   # Total: 3 arrays, 4 parameters, 224 bytes.
+```
+
 """
 macro layer(exs...)
   out = quote end

@@ -40,10 +65,10 @@ macro layer(exs...)
   end
 
   # This function exists only for depwarns when you use @functor directly
-  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) # scope is weird ?? can't use $ on func name?
+  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))
 
-  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
-  if isnothing(i)
+  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest)
+  if isnothing(i)  # then default like @functor Layer
     push!(out.args, _macro_functor(esc(type)))
   else
     push!(out.args, _macro_functor(esc(type), rest[i].args[2]))

@@ -52,13 +77,14 @@ macro layer(exs...)
     j == i && continue
     ex = rest[j]
     Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
-    if ex.args[1] == :trainable
-      push!(out.args, _macro_trainable(type, trainable, ex.args[2])) # pass the function "trainable" not the symbol
+
+    name = if ex.args[1] == :trainable
+      :(Optimisers.trainable)
     else
-      error()
-      # @warn "defining a method for $(ex.args[1]) in your scope" # ??
-      # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+      @warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+      esc(ex.args[1])
     end
+    push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
   end
 
   out

@@ -72,17 +98,16 @@ function _check_new_macro(x::T) where T
 end
 _check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
 _check_new_macro(::NamedTuple) = nothing
-_check_new_macro(::Transpose) = nothing
-_check_new_macro(::Adjoint) = nothing
+_check_new_macro(::AbstractArray) = nothing
 _check_new_macro(::Ref) = nothing
 
 # @layer's code for Functors & Adapt
 # Unlike @functor, _default_functor doesn't need to eval anything
 
 function _macro_functor(type)
   quote
-    Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
-    Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+    Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
+    Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
   end
 end
 

@@ -94,12 +119,13 @@ function _default_functor(::Type{T}, x) where {T}
   if @generated
     F = fieldnames(T)
     args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
-    C = Base.typename(T).name  # constructor
+    C = Base.typename(T).wrapper  # constructor
     recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
     :((NamedTuple{$F}(($(args...),)), $recon))
   else
     # Getting this parameterless type takes about 2μs, every time:
-    namedtuple(x), Base.splat(Base.typename(T).wrapper)
+    spl = VERSION > v"1.9-" ? Splat : Base.splat
+    namedtuple(x), spl(Base.typename(T).wrapper)
   end
 end
 

@@ -117,61 +143,12 @@ function _macro_trainable(type, fun, fields)
   quoted = map(QuoteNode, symbols)
   gets = [:(getfield(x, $f)) for f in quoted]
   quote
-    # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
-    Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),)) # ?? scope is weird
+    $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
+    # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),)) # ?? scope is weird
  end
 end
 _macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,))) # lets you forget a comma
 
 _noquotenode(s::Symbol) = s
 _noquotenode(q::QuoteNode) = q.value # lets you write trainable=(:x,:y) instead of (x,y)
 _noquotenode(ex) = error("expected a symbol, got $ex")
-
-
-
-
-
-
-# @big_show Chain
-# @big_show Parallel
-# @big_show SkipConnection
-# @big_show Recur
-# @big_show Maxout
-
-
-
-
-"""
-    @big_show MyContainer
-
-This macro lets you opt-in to Flux's fancy printing.
-
-When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
-and the printing routine will recursively unfold its children.
-This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
-
-Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
-need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
-
-# Example
-```jldoctest
-julia> struct Trio{A,B,C}; a::A; b::B; c::C end
-
-julia> Flux.@functor Trio
-
-julia> Flux.@big_show Trio
-
-julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
-Trio(
-  Dense(10 => 5, tanh),  # 55 parameters
-  Dense(5 => 2),  # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 492 bytes.
-```
-
-Note that there is no automatic method for 2-arg `show`, and thus
-something like `(tri, tri)` will print all the type parameters.
-
-However, `Chain(tri, tri)` will always use Flux's recursive printing,
-even without using this macro: `Chain` is the entry point.
-"""

src/layers/normalise.jl (2 additions, 0 deletions)

@@ -178,6 +178,8 @@ end
 testmode!(m::AlphaDropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
+Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")")
+
 """
     LayerNorm(size..., λ=identity; affine=true, ϵ=1fe-5)
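The added method just gives `AlphaDropout` the same compact printing as `Dropout`; for instance (illustrative):

```julia
using Flux

ad = AlphaDropout(0.4)
print(ad)   # with the method added here, this prints: AlphaDropout(0.4)
```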

src/layers/show.jl (8 additions, 24 deletions)

@@ -15,7 +15,7 @@ function _macro_big_show(ex)
     end
 
     # Don't show Chain(Tuple(...)), always splat that:
-    _show_children(x::$ex) = _flat_children(x)
+    Flux._show_children(x::$ex) = _flat_children(x)
   end
 end
 

@@ -56,12 +56,10 @@ _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LS
 # _show_leaflike(::Scale) = true # appears inside LayerNorm
 _show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays
 
-_show_children(x) = trainable(x) # except for layers which hide their Tuple:
-# _show_children(c::Chain) = c.layers
-# _show_children(m::Maxout) = m.layers
-# _show_children(p::Parallel) = (p.connection, p.layers...)
-# _show_children(f::PairwiseFusion) = (f.connection, f.layers...)
-
+_show_children(x) = trainable(x)
+# This used to have methods for Chain, Maxout, Parallel, PairwiseFusion. Now @layer instead
+# writes a method to use this function. It flattens the Tuple within Chain etc.
+# (Some still special-cased above, for printing of layer names when NamedTuple.)
 function _flat_children(x)
   alpha = map(f -> getfield(x, f), fieldnames(typeof(x)))
   beta = map(y -> y isa Union{Tuple, NamedTuple} ? y : (y,), alpha)

@@ -79,25 +77,11 @@ function _macro_layer_show(ex)
       show(io, x)
     end
   end
-
-  # Exit from _big_show recursion, do we need this and _show_leaflike?
-  _big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
-  # Since this isn't a container, do not recurse into its children, if any:
-  _show_leaflike(::$ex) = true
+
+  # Exit from _big_show recursion:
+  Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
   end
 end
-# for T in [
-#   :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
-#   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
-# ]
-# @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
-#   if !get(io, :compact, false)
-#     _layer_show(io, x)
-#   else
-#     show(io, x)
-#   end
-# end
-# end
 
 function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
   _str = isnothing(name) ? "" : "$name = "
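For orientation, the `show` support that `_macro_layer_show` generates for a plain (non-container) layer amounts roughly to the commented methods in this sketch; `MyDense` is a made-up type and the expansion is paraphrased, not literal:

```julia
using Flux

struct MyDense   # hypothetical non-container layer, for illustration
    weight
end

Flux.@layer MyDense   # one macro call; its show-related part is roughly:

#   function Base.show(io::IO, m::MIME"text/plain", x::MyDense)
#       if !get(io, :compact, false)
#           Flux._layer_show(io, x)
#       else
#           show(io, x)
#       end
#   end
#   # and, after this commit, an exit from the recursive `_big_show` printing:
#   Flux._big_show(io::IO, obj::MyDense, indent::Int=0, name=nothing) =
#       Flux._layer_show(io, obj, indent, name)

display(MyDense([1.0 2.0; 3.0 4.0]))   # shown with a trailing "# 4 parameters" summary
```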

test/layers/macro.jl (new file, 35 additions)

@@ -0,0 +1,35 @@
+using Flux, Functors, Optimisers
+
+module MacroTest
+  using Flux: @layer
+
+  struct Duo{T,S}; x::T; y::S; end
+  @layer :expand Duo
+
+  struct Trio; a; b; c end
+  @layer Trio trainable=(a,b) test=(c)  # should be (c,) but it lets you forget
+
+  # struct TwoThirds; a; b; c; end
+  # Flux.@layer :expand TwoThirds children=(a,c) trainable=(a)  # should be (a,) but it lets you forget
+
+end
+
+@testset "@layer macro" begin
+  @test !isdefined(MacroTest, :Flux)  # That's why the module, to check scope
+
+  m2 = MacroTest.Duo(Dense(2=>2), Chain(Flux.Scale(2), Dropout(0.2)))
+
+  @test Functors.children(m2) isa NamedTuple{(:x, :y)}
+  @test length(Optimisers.destructure(m2)[1]) == 10
+
+  m3 = MacroTest.Trio([1.0], [2.0], [3.0])
+
+  @test Functors.children(m3) isa NamedTuple{(:a, :b, :c)}
+  @test fmap(zero, m3) isa MacroTest.Trio
+
+  @test Optimisers.trainable(m3) isa NamedTuple{(:a, :b)}
+  @test Optimisers.destructure(m3)[1] == [1, 2]
+
+  @test MacroTest.test(m3) == (c = [3.0],)
+end
+

test/runtests.jl (1 addition, 0 deletions)

@@ -38,6 +38,7 @@ Random.seed!(0)
   include("layers/conv.jl")
   include("layers/upsample.jl")
   include("layers/show.jl")
+  include("layers/macro.jl")
 end
 
 @testset "outputsize" begin

test/utils.jl (1 addition, 1 deletion)

@@ -723,7 +723,7 @@ end
     a::A
     b::A
   end
-  Flux.@functor Model
+  Flux.@layer Model
   (m::Model)(x) = m.a(x) .+ m.b(x)
 
   d = Dense(1, 1)
