Commit e2f58a8

set expand option as default for @layer (#2532)
1 parent e2b3f06 commit e2f58a8

14 files changed, +98 −80 lines
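In short, this commit makes `Flux.@layer` print user-defined containers in the expanded, `Chain`-like style by default; the previous compact behaviour becomes opt-in via the new `:noexpand` option, and `:ignore` still disables the `show` overloads entirely. A minimal sketch of the new usage (the `Sandwich` type and the printed output shown in comments are illustrative, not part of this commit):

```julia
using Flux

struct Sandwich   # hypothetical container holding two sub-layers
    pre
    post
end

Flux.@layer Sandwich              # expanded, Chain-like printing is now the default
# Flux.@layer :noexpand Sandwich  # old compact, one-line printing
# Flux.@layer :ignore Sandwich    # no pretty-printing overloads at all

(s::Sandwich)(x) = s.post(s.pre(x))

s = Sandwich(Dense(4 => 8, relu), Dense(8 => 2))
# At the REPL, `s` now displays roughly as:
# Sandwich(
#   Dense(4 => 8, relu),      # 40 parameters
#   Dense(8 => 2),            # 18 parameters
# )             # Total: 4 arrays, 58 parameters, ...
```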

README.md

Lines changed: 2 additions & 2 deletions

@@ -27,9 +27,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]
 
 model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
 
-optim = Flux.setup(Adam(), model)
+opt_state = Flux.setup(Adam(), model)
 for epoch in 1:1000
-    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
+    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state)
 end
 
 plot(x -> 2x-x^3, -2, 2, legend=false)

docs/src/guide/saving.md

Lines changed: 6 additions & 1 deletion

@@ -21,7 +21,12 @@ julia> Flux.@layer MyModel
 julia> MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)));
 
 julia> model = MyModel()
-MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters
+MyModel(
+  Chain(
+    Dense(10 => 5, relu),     # 55 parameters
+    Dense(5 => 2),            # 12 parameters
+  ),
+) # Total: 4 arrays, 67 parameters, 484 bytes.
 
 julia> model_state = Flux.state(model);
 

src/Flux.jl

Lines changed: 3 additions & 2 deletions

@@ -8,6 +8,7 @@ using MacroTools, Reexport, ProgressLogging, SpecialFunctions
 using MacroTools: @forward
 
 @reexport using NNlib
+using NNlib: conv, ∇conv_data, depthwiseconv, output_size
 using MLUtils
 
 using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update!

@@ -27,7 +28,7 @@ export gradient
     CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
     XLADevice,
     # get_device, # we define get_device here for retrocompatibility
-    # gpu_backend!, # have to define here due to https://github.com/JuliaPackaging/Preferences.jl/issues/39
+    gpu_backend!,
     get_device_type,
     DeviceIterator
 

@@ -118,7 +119,7 @@ include("losses/Losses.jl")
 using .Losses
 
 include("devices.jl")
-export get_device, gpu_backend!
+export get_device
 
 # Distributed Training
 include("distributed/backend.jl")

src/functor.jl

Lines changed: 0 additions & 12 deletions

@@ -102,18 +102,6 @@ julia> m.bias
 """
 cpu(x) = cpu_device()(x)
 
-# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089
-ChainRulesCore.@non_differentiable cpu_device()
-
-
-# Remove when
-# https://github.com/JuliaPackaging/Preferences.jl/issues/39
-# is resolved
-function gpu_backend!(backend::String)
-    @set_preferences!("gpu_backend" => backend)
-    MLDataDevices.gpu_backend!(backend)
-end
-
 """
     gpu(m)
 
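The removal above is internal only: `gpu_backend!` remains available from Flux, now re-exported from MLDataDevices (see the `src/Flux.jl` hunk above) instead of being defined as a local workaround. A brief usage sketch, assuming the chosen backend package (here CUDA.jl) is installed:

```julia
using Flux

# Record the preferred GPU backend as a preference (via Preferences.jl);
# the choice takes effect after restarting Julia.
Flux.gpu_backend!("CUDA")

device = gpu_device()             # device for the selected backend, or CPU fallback
model  = Dense(2 => 3) |> device  # move a layer to that device
```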

src/layers/attention.jl

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
   out_proj::P2
 end
 
-@layer MultiHeadAttention
+@layer :noexpand MultiHeadAttention
 
 function MultiHeadAttention(dims;
     nheads::Int = 8,

src/layers/basic.jl

Lines changed: 5 additions & 5 deletions

@@ -60,7 +60,7 @@ end
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
     Base.iterate, Base.lastindex, Base.keys, Base.firstindex
 
-@layer :expand Chain # the option :expand opts-in to container-style pretty-printing
+@layer Chain
 
 (c::Chain)(x) = _applychain(c.layers, x)
 (c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))

@@ -334,7 +334,7 @@ end
 Maxout(layers...) = Maxout(layers)
 Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
-@layer :expand Maxout
+@layer Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,

@@ -381,7 +381,7 @@ struct SkipConnection{T,F}
   connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@layer :expand SkipConnection
+@layer SkipConnection
 
 function (skip::SkipConnection)(input)
   skip.connection(skip.layers(input), input)

@@ -575,7 +575,7 @@ end
 Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
   throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))
 
-@layer :expand Parallel
+@layer Parallel
 
 (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument
 

@@ -705,7 +705,7 @@ end
 end
 applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
 
-@layer :expand PairwiseFusion
+@layer PairwiseFusion
 
 Base.getindex(m::PairwiseFusion, i) = m.layers[i]
 Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])

src/layers/conv.jl

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-using NNlib: conv, ∇conv_data, depthwiseconv, output_size
 
 # pad dims of x with dims of y until ndims(x) == ndims(y)
 _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)

src/layers/macro.jl

Lines changed: 28 additions & 16 deletions

@@ -1,21 +1,24 @@
 
 """
-    @layer Dense
-    @layer :expand Chain
-    @layer BatchNorm trainable=(β,γ)
-
+    @layer [showtype] MyModel [trainable=(field1,...)]
+
 This macro adds convenience functionality to a custom type to serve
-as a neural network layer, module, or entire model.
+as a neural network layer, as a module, or as an entire model.
 
-The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`.
+The optional keyword `trainable` allows you to specify which fields of your model can be trained,
+instead of assuming all `fieldnames(MyModel)` to be trainable.
 Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
+This can also be done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.
+
+The macro also handles overloads of the 3-arg `show(::IO, ::MIME"text/plain", ::MyModel)` for pretty printing.
+The optional argument `showtype` can take any of the following values:
 
-The macro also handles overloads of `show` for pretty printing.
-* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
-* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
-* To disable all `show` overloads, there is an `:ignore` option too.
+- `:expand` (default): This will expand the representation of container types like `Chain`,
+  while maintaining a compact representation of types like `Dense` containing only arrays.
+- `:noexpand`: This is to be used in case your type contains other layers but you want to keep the representation simple.
+- `:ignore`: To opt out of the pretty printing.
 
-(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
+You probably still want to define 2-arg `show(::IO, ::MyModel)`; the macro does not touch this.
 
 Note that re-running the macro with different options may not remove all methods, you will need to restart.
 

@@ -26,16 +29,22 @@ julia> struct Trio; a; b; c end
 julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
 Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
 
-julia> Flux.@layer :expand Trio
+julia> Flux.@layer Trio
 
 julia> tri # now the layer is printed like Chain
 Trio(
   Dense(2 => 1, tanh), # 3 parameters
   Dense(1 => 1; bias=false), # 1 parameters
   Dropout(0.4),
 ) # Total: 3 arrays, 4 parameters, 240 bytes.
-```
 
+julia> Flux.@layer :noexpand Trio trainable=(a,b)
+
+julia> tri # now the layer is printed compactly
+Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) # 4 parameters
+
+julia> opt_state = Flux.setup(Adam(), tri); # `c` is not in the optimizer state
+```
 """
 macro layer(exs...)
   _layer_macro(exs...)

@@ -46,14 +55,17 @@ function _layer_macro(exs...)
 
   # These functions are defined in show.jl, and each return an expression overloading Base.show
   type, rest... = if exs[1] == QuoteNode(:expand)
-    push!(out.args, _macro_big_show(esc(exs[2])))
+    push!(out.args, _macro_big_show(esc(exs[2])))
+    exs[2:end]
+  elseif exs[1] == QuoteNode(:noexpand)
+    push!(out.args, _macro_layer_show(esc(exs[2])))
     exs[2:end]
   elseif exs[1] == QuoteNode(:ignore)
     exs[2:end]
   elseif exs[1] isa QuoteNode
-    error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)")
+    error("`@layer` accepts only the options `:ignore`, `:noexpand`, and `:expand` before the layer type (to control `show`).")
   else
-    push!(out.args, _macro_layer_show(esc(exs[1])))
+    push!(out.args, _macro_big_show(esc(exs[1])))
     exs
   end
 
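As the updated docstring notes, limiting the trainable fields can also be done without the macro option, by adding a method to `trainable` for your type. A sketch of that alternative for the `Trio` example above, equivalent in effect to `trainable=(a,b)` (this snippet is illustrative and not part of the diff):

```julia
using Flux

struct Trio; a; b; c; end
Flux.@layer Trio

# Optimisers will only visit `a` and `b`; `c` is still handled by
# functions like `gpu`/`cpu` and by `Flux.state`.
Flux.trainable(t::Trio) = (; a = t.a, b = t.b)

tri = Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
opt_state = Flux.setup(Adam(), tri);   # no optimiser state is created for `c`
```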

src/layers/normalise.jl

Lines changed: 1 addition & 1 deletion

@@ -198,7 +198,7 @@ end
 LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
 LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
 
-@layer LayerNorm
+@layer :noexpand LayerNorm
 
 function (a::LayerNorm)(x::AbstractArray)
   ChainRulesCore.@ignore_derivatives if a.diag isa Scale

src/layers/recurrent.jl

Lines changed: 6 additions & 6 deletions

@@ -158,7 +158,7 @@ struct Model
   h0::AbstractVector
 end
 
-Flux.@layer :expand Model
+Flux.@layer Model
 
 (m::Model)(x) = m.rnn(x, m.h0)
 

@@ -169,7 +169,7 @@ struct RNN{M}
   cell::M
 end
 
-@layer :expand RNN
+@layer RNN
 
 function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
   cell = RNNCell(in => out, σ; cell_kwargs...)

@@ -344,7 +344,7 @@ struct Model
   c0::AbstractVector
 end
 
-Flux.@layer :expand Model
+Flux.@layer Model
 
 (m::Model)(x) = m.lstm(x, (m.h0, m.c0))
 

@@ -359,7 +359,7 @@ struct LSTM{M}
   cell::M
 end
 
-@layer :expand LSTM
+@layer LSTM
 
 function LSTM((in, out)::Pair; cell_kwargs...)
   cell = LSTMCell(in => out; cell_kwargs...)

@@ -531,7 +531,7 @@ struct GRU{M}
   cell::M
 end
 
-@layer :expand GRU
+@layer GRU
 
 function GRU((in, out)::Pair; cell_kwargs...)
   cell = GRUCell(in => out; cell_kwargs...)

@@ -669,7 +669,7 @@ struct GRUv3{M}
   cell::M
 end
 
-@layer :expand GRUv3
+@layer GRUv3
 
 function GRUv3((in, out)::Pair; cell_kwargs...)
   cell = GRUv3Cell(in => out; cell_kwargs...)
