
Commit bd91d81

tidy up, add NEWS
1 parent cd28cc7 commit bd91d81

5 files changed (+22 / -23 lines)


NEWS.md

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 ## v0.13.7
 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
+* New macro `Flux.@layer` which should be used in place of `@functor`.
+  This also adds `show` methods for pretty printing.
 
 ## v0.13.4
 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
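As a minimal sketch of the change this entry describes (assuming a hypothetical user-defined `MyLayer`, which is not part of Flux), a layer that previously called `@functor` would now call `@layer`:

```julia
using Flux

struct MyLayer          # hypothetical layer, for illustration only
    weight
    bias
end
(m::MyLayer)(x) = m.weight * x .+ m.bias

# Previously one would write:  Flux.@functor MyLayer
Flux.@layer MyLayer     # marks fields as parameters, and also adds `show` methods
```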

src/layers/macro.jl

Lines changed: 17 additions & 15 deletions
@@ -4,42 +4,43 @@
     @layer :expand Chain
     @layer BatchNorm trainable=(β,γ)
     @layer Struct children=(α,β) trainable=(β,)
-
-This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
+
+This macro replaces most uses of `@functor`. Its basic purpose is the same:
 When you define a new layer, this tells Flux to explore inside it
 to see the parameters it trains, and also to move them to the GPU, change precision, etc.
 Like `@functor`, this assumes your struct has the default constructor, to enable re-building.
 
-Some "keywords" allow control of the recursion:
+Some "keywords" allow control of the recursion.
 * If some fields look like parameters but should not be trained,
   then `trainable` lets you specify fields to include, and ignore the rest.
 * You can likewise add restrictions to Functors' `children` (although this is seldom a good idea).
 
+The defaults are `fieldnames(T)` for both. Any fields you specify must be a subset of this, and `trainable` must be a subset of `children`.
+
 It also handles overloads of `show` for pretty printing.
 * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
 * If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
 * To disable all `show` overloads, there is an `:ignore` option too.
 
+Note that re-running the macro with different options does not overwrite all methods; you will need to restart.
 (You probably still want to define 2-arg `show(io::IO, x::Layer)`; the macro does not touch this.)
 
-Note that re-running the macro with different options does not overwrite all methods, you will need to restart.
-
 # Example
 ```jldoctest
 julia> struct Trio; a; b; c end
 
-julia> tri = Trio(Dense([1.1 2.2],), Dense([3.3;;], false), Dropout(0.4))
-Trio(Dense(1 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
+julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense([3.3;;], false), Dropout(0.4))
+Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
 
-julia> Flux.destructure(tri)  # parameters not visible to Flux
+julia> Flux.destructure(tri)  # parameters are not yet visible to Flux
 (Bool[], Restructure(Trio, ..., 0))
 
 julia> Flux.@layer :expand Trio
 
 julia> Flux.destructure(tri)  # now gpu, train!, etc will see inside too
 ([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
 
-julia> tri
+julia> tri  # and layer is printed like Chain
 Trio(
   Dense(2 => 1),                        # 3 parameters
   Dense(1 => 1; bias=false),            # 1 parameters
@@ -58,7 +59,7 @@ macro layer(exs...)
   elseif exs[1] == QuoteNode(:ignore)
     exs[2:end]
   elseif exs[1] isa QuoteNode
-    error("before the type, only accepted options are `:expand` and `:ignore`")
+    error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)")
   else
     push!(out.args, _macro_layer_show(esc(exs[1])))
     exs
@@ -76,12 +77,14 @@ macro layer(exs...)
   for j in 1:length(rest)
     j == i && continue
     ex = rest[j]
-    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
+    Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects `keyword = (fields...,)` here, got $ex")
 
     name = if ex.args[1] == :trainable
       :(Optimisers.trainable)
+    elseif ex.args[1] == :functor
+      error("Can't use `functor=(...)` as a keyword to `@layer`. Use `children=(...)` to define a method for `functor`.")
     else
-      @warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+      @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
       esc(ex.args[1])
     end
     push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
@@ -94,7 +97,7 @@ end
 
 function _check_new_macro(x::T) where T
   Functors.isleaf(x) && return
-  @warn "This type should now use Flux.@layer instead of @functor" T maxlog=1 _id=hash(T)
+  Base.depwarn("This type should probably now use `Flux.@layer` instead of `@functor`: $T", Symbol("@functor"))
 end
 _check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
 _check_new_macro(::NamedTuple) = nothing
@@ -159,11 +162,10 @@ function _macro_trainable(type, fun, fields)
   gets = [:(getfield(x, $f)) for f in quoted]
   quote
     $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
-    # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
   end
 end
 _macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma
 
 _noquotenode(s::Symbol) = s
 _noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
-_noquotenode(ex) = error("expected a symbol, got $ex")
+_noquotenode(ex) = error("expected a symbol here, as a field name, but got $ex")
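For orientation, here is a rough sketch of what the `trainable=` keyword generates, following the `quote` block in `_macro_trainable` above; `Affine` is a hypothetical struct used only for illustration:

```julia
using Flux, Optimisers

struct Affine; W; b; σ; end        # hypothetical layer, for illustration only

Flux.@layer Affine trainable=(W,)

# The keyword above defines, approximately, this method (cf. `_macro_trainable`):
#   Optimisers.trainable(x::Affine) = NamedTuple{(:W,)}((getfield(x, :W),))

layer = Affine([1.0 2.0], [0.0], identity)
Optimisers.trainable(layer)        # expected: (W = [1.0 2.0],) -- b and σ are ignored by optimisers
```

Since `children` still defaults to all fields, `gpu` and `f32` continue to move `b` as well; only the optimiser's view is restricted.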

src/layers/normalise.jl

Lines changed: 0 additions & 5 deletions
@@ -104,7 +104,6 @@ function Dropout(p; dims=:, rng = default_rng_value())
 end
 
 @layer Dropout trainable=()
-# trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
   _isactive(a) || return x
@@ -159,7 +158,6 @@ AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
 AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
 
 @layer AlphaDropout trainable=()
-# trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x
@@ -355,7 +353,6 @@ function BatchNorm(chs::Int, λ=identity;
 end
 
 @layer BatchNorm trainable=(β,γ)
-# trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
 
 function (BN::BatchNorm)(x)
   @assert size(x, ndims(x)-1) == BN.chs
@@ -449,7 +446,6 @@ function InstanceNorm(chs::Int, λ=identity;
 end
 
 @layer InstanceNorm trainable=(β,γ)
-# trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
 
 function (l::InstanceNorm)(x)
   @assert ndims(x) > 2
@@ -528,7 +524,6 @@ mutable struct GroupNorm{F,V,N,W}
 end
 
 @layer GroupNorm trainable=(β,γ)
-# trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
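A quick sketch of the effect of the `trainable=(β,γ)` declarations above (an illustration, not part of the diff): only the affine parameters are exposed to optimisers, while the running statistics remain ordinary children and are still moved by `gpu`, `f32`, etc.

```julia
using Flux, Optimisers

bn = BatchNorm(3)                 # fields include β, γ and the running statistics μ, σ²
keys(Optimisers.trainable(bn))    # expected: (:β, :γ)
length(Flux.params(bn))           # expected: 2 -- only β and γ are trained
```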

src/layers/show.jl

Lines changed: 0 additions & 1 deletion
@@ -55,7 +55,6 @@ end
 _show_leaflike(x) = isleaf(x)  # mostly follow Functors, except for:
 _show_leaflike(::Tuple{Vararg{<:Number}}) = true  # e.g. stride of Conv
 _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true  # e.g. parameters of LSTMcell
-# _show_leaflike(::Scale) = true  # appears inside LayerNorm
 _show_leaflike(::AbstractArray{<:Number}) = true  # e.g. transposed arrays
 
 _show_children(x) = trainable(x)

test/layers/macro.jl

Lines changed: 3 additions & 2 deletions
@@ -10,8 +10,6 @@ module MacroTest
   @layer Trio trainable=(a,b) test=(c)  # should be (c,) but it lets you forget
 
   struct TwoThirds; a; b; c; end
-  @layer :expand TwoThirds children=(a,c) trainable=(a)
-
 end
 
 @testset "@layer macro" begin
@@ -33,6 +31,9 @@ end
   @test MacroTest.test(m3) == (c = [3.0],)
 
   m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6])
+  # Check that we can use the macro with a qualified type name, outside the defining module:
+  Flux.@layer :expand MacroTest.TwoThirds children=(:a,:c) trainable=(:a)  # documented as (a,c) but allow quotes
+
   @test Functors.children(m23) == (a = [1 2], c = [5 6])
   m23re = Functors.functor(m23)[2]((a = [10 20], c = [50 60]))
   @test m23re isa MacroTest.TwoThirds
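The added test exercises a detail worth noting: `@layer` can be applied outside the module that defines the struct, using a qualified type name, and field names may be written bare (as documented) or as quoted symbols. A sketch of the same pattern in user code, with a hypothetical `MyModule`:

```julia
module MyModule                   # hypothetical module, mirroring MacroTest above
    struct TwoThirds; a; b; c; end
end

using Flux

# Qualified type name, applied from outside the defining module;
# `trainable` must be a subset of `children`.
Flux.@layer :expand MyModule.TwoThirds children=(a, c) trainable=(a,)
```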

0 commit comments
