
Commit 99b0260

review suggestions
1 parent bb46e26 commit 99b0260


4 files changed: +25 -16 lines


docs/src/models/advanced.md

Lines changed: 18 additions & 11 deletions
@@ -18,8 +18,8 @@ function (m::CustomModel)(x)
   return m.chain(x) + x
 end
 
-# Call @functor to allow for training. Described below in more detail.
-Flux.@functor CustomModel
+# Call @layer to allow for training. Described below in more detail.
+Flux.@layer CustomModel
 ```
 
 You can then use the model like:
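For readers skimming the diff, here is a rough, self-contained sketch of the pattern these changed lines describe. It is an illustration of mine, not part of the commit, and it assumes a wrapper struct shaped like the `CustomModel` defined earlier in advanced.md:

```julia
using Flux

# Assumed shape of the struct from earlier in the docs page (illustration only):
struct CustomModel
  chain::Chain
end

(m::CustomModel)(x) = m.chain(x) + x    # forward pass with a skip connection

Flux.@layer CustomModel                 # the new macro, replacing Flux.@functor

model = CustomModel(Chain(Dense(8 => 8, relu), Dense(8 => 8)))
x = rand(Float32, 8, 16)
y = model(x)                            # usable like any other Flux layer
```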
@@ -41,29 +41,36 @@ By default all the fields in the `Affine` type are collected as its parameters,
 The first way of achieving this is through overloading the `trainable` function.
 
 ```julia-repl
-julia> @functor Affine
+julia> @layer Affine
 
 julia> a = Affine(rand(3,3), rand(3))
 Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])
 
 julia> Flux.params(a) # default behavior
 Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
 
-julia> Flux.trainable(a::Affine) = (a.W,)
+julia> Flux.trainable(a::Affine) = (W = a.W,) # must return a NamedTuple
 
 julia> Flux.params(a)
 Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
 ```
 
 Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
 
-Another way of achieving this is through the `@functor` macro directly. Here, we can mark the fields we are interested in by grouping them in the second argument:
+The exact same method of `trainable` can also be defined using the macro, for convenience:
 
 ```julia
-Flux.@functor Affine (W,)
+Flux.@layer Affine trainable=(W,)
 ```
 
-However, doing this requires the `struct` to have a corresponding constructor that accepts those parameters.
+There is a second, more severe, kind of restriction possible:
+
+```
+Flux.@layer Affine children=(W,)
+```
+
+This is equivalent to `Functors.@functor Affine (W,)`. It means that no exploration of the model will ever visit the other fields: they will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This is not usually recommended.
+
 
 ## Freezing Layer Parameters
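To make the distinction between the two keywords concrete, here is a rough sketch of my own (not part of the commit), assuming a two-field `Affine` struct like the one in these docs:

```julia
using Flux

# Hypothetical two-field layer, matching the docs' Affine example:
struct Affine; W; b; end

Flux.@layer Affine trainable=(W,)   # only W will receive gradients

a = Affine(rand(Float32, 3, 3), rand(Float32, 3))
Flux.params(a)   # collects only the 3x3 matrix W
Flux.gpu(a)      # still moves both W and b, because b remains a child
# By contrast, `Flux.@layer Affine children=(W,)` would also hide b from gpu/f32.
```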

@@ -127,9 +134,9 @@ Join(combine, paths...) = Join(combine, paths)
 ```
 Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
 
-The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
+The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
 ```julia
-Flux.@functor Join
+Flux.@layer Join
 ```
 
 Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results.
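To illustrate the claim about `params` above, a sketch of my own (not part of the commit), assuming the `Join` definition from the surrounding docs page:

```julia
using Flux

# Assumed definition of Join from the docs page:
struct Join{T, F}
  combine::F
  paths::T
end
Join(combine, paths...) = Join(combine, paths)

Flux.@layer Join

(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)

j = Join(vcat, Dense(2 => 3), Dense(4 => 1))
Flux.params(j)                          # weights and biases of both Dense paths
j(rand(Float32, 2), rand(Float32, 4))   # a length-4 vector: vcat of both outputs
```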
@@ -182,7 +189,7 @@ model(xs)
 
 Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs.
 
-We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass.
+We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass.
 ```julia
 using Flux
 using CUDA
@@ -194,7 +201,7 @@ end
 
 Split(paths...) = Split(paths)
 
-Flux.@functor Split
+Flux.@layer Split
 
 (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
 ```
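Purely as a usage illustration of mine (not part of the commit), the `Split` defined above can be exercised like this:

```julia
using Flux

struct Split{T}
  paths::T
end
Split(paths...) = Split(paths)

Flux.@layer Split

(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

s = Split(Dense(5 => 1, tanh), Dense(5 => 3, tanh))
s(rand(Float32, 5))   # a Tuple with one output array per path
```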

src/layers/macro.jl

Lines changed: 3 additions & 2 deletions
@@ -84,8 +84,9 @@ macro layer(exs...)
   elseif ex.args[1] == :functor
     error("Can't use `functor=(...)` as a keyword to `@layer`. Use `children=(...)` to define a method for `functor`.")
   else
-    @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
-    esc(ex.args[1])
+    error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.")
+    # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+    # esc(ex.args[1])
   end
   push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
 end
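A quick sketch of what the behaviour change means in practice (my illustration; the error text is the one added above, and the struct name is hypothetical):

```julia
using Flux

struct Foo; a; b; end

# Before this commit, an unrecognised keyword only produced a warning.
# With this change, macro expansion fails instead:
Flux.@layer Foo something=(a,)
# ERROR: `@layer` cannot define a method for `something` at the moment, sorry.
```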

src/layers/recurrent.jl

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{
   return h, reshape_cell_output(h, x)
 end
 
-@layer RNNCell  # trainable=(Wi, Wh, b)
+@layer RNNCell  # state0 is trainable, see issue 807 about this.
 
 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
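For context on that comment (my addition, relying only on the comment's claim that `state0` stays trainable when no `trainable=` restriction is given):

```julia
using Flux

r = Flux.RNN(3 => 5)

# With a plain `@layer RNNCell`, every array field is collected,
# including the learnable initial state state0 (see Flux.jl issue 807):
Flux.params(r)   # expected to contain Wi, Wh, b and state0
```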

test/layers/macro.jl

Lines changed: 3 additions & 2 deletions
@@ -7,7 +7,8 @@ module MacroTest
   @layer :expand Duo
 
   struct Trio; a; b; c end
-  @layer Trio trainable=(a,b) test=(c)  # should be (c,) but it lets you forget
+  # @layer Trio trainable=(a,b) test=(c)  # should be (c,) but it lets you forget
+  @layer Trio trainable=(a,b)  # defining a method for test is made an error, for now
 
   struct TwoThirds; a; b; c; end
 end
@@ -28,7 +29,7 @@ end
   @test Optimisers.trainable(m3) isa NamedTuple{(:a, :b)}
   @test Optimisers.destructure(m3)[1] == [1, 2]
 
-  @test MacroTest.test(m3) == (c = [3.0],)
+  # @test MacroTest.test(m3) == (c = [3.0],)  # removed, for now
 
   m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6])
   # Check that we can use the macro with a qualified type name, outside the defining module:
