    @layer Dense
    @layer :expand Chain
    @layer BatchNorm trainable=(β,γ)
-     @layer Struct functor=(α,β) trainable=(β,)
+     @layer Struct children=(α,β) trainable=(β,)

This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
When you define a new layer, this tells Flux to explore inside it
to see the parameters it trains, and also to move them to the GPU, change precision, etc.

Some "keywords" allow control of the recursion:
* If some fields look like parameters but should not be trained,
-   then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
- * We can likewise add restrictions to `Functors.functor`, but not yet written.
- * In fact you can provide an arbitrary keyword with this syntax, and it will
-   overload this function à la `trainable`... that might be a terrible idea.
+   then `trainable` lets you specify fields to include, and ignore the rest.
+ * We can likewise add restrictions to Functors' `children`,
+   but this is not yet written (as this is seldom a good idea).
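For instance, restricting `trainable` keeps a field out of gradient updates while it is still moved by `gpu`, `f32`, etc. A minimal sketch (editor's illustration, not part of this commit; `MyScale` is a hypothetical layer):

```julia
using Flux

struct MyScale
  log_scale   # trained
  centre      # moved to GPU / converted by f32, but not trained
end
MyScale(n::Integer) = MyScale(zeros(Float32, n), randn(Float32, n))
(m::MyScale)(x) = exp.(m.log_scale) .* (x .- m.centre)

Flux.@layer MyScale trainable=(log_scale,)
```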

It also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
- * To disable all `show` overloads, maybe we want a `:ignore` option too.
+ * To disable all `show` overloads, there is an `:ignore` option too.

(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
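Concretely, a one-line `show` for the hypothetical `MyScale` above might look like this (editor's sketch, not from this commit):

```julia
# The macro supplies the multi-line REPL printing (3-arg show);
# the compact one-line form used inside other printouts is still yours to write:
Base.show(io::IO, m::MyScale) = print(io, "MyScale(", length(m.centre), ")")
```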
+
+ Note that re-running the macro with different options does not overwrite all methods; you will need to restart Julia.
+
+ # Example
+ ```jldoctest
+ julia> struct Trio; a; b; c end
+
+ julia> tri = Trio(Dense([1.1 2.2]), Dense([3.3;;], false), Dropout(0.4))
+ Trio(Dense(2 => 1), Dense(1 => 1; bias=false), Dropout(0.4))
+
+ julia> Flux.destructure(tri)  # parameters not visible to Flux
+ (Bool[], Restructure(Trio, ..., 0))
+
+ julia> Flux.@layer :expand Trio
+
+ julia> Flux.destructure(tri)  # now gpu, train!, etc will see inside too
+ ([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
+
+ julia> tri
+ Trio(
+   Dense(2 => 1),                # 3 parameters
+   Dense(1 => 1; bias=false),    # 1 parameters
+   Dropout(0.4),
+ )                   # Total: 3 arrays, 4 parameters, 224 bytes.
+ ```
+
"""
macro layer(exs...)
  out = quote end
@@ -40,10 +65,10 @@ macro layer(exs...)
  end

  # This function exists only for depwarns when you use @functor directly
-   push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?
+   push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))

-   i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
-   if isnothing(i)
+   i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest)
+   if isnothing(i)  # then default like @functor Layer
    push!(out.args, _macro_functor(esc(type)))
  else
    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
@@ -52,13 +77,14 @@ macro layer(exs...)
    j == i && continue
    ex = rest[j]
    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
-     if ex.args[1] == :trainable
-       push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
+
+     name = if ex.args[1] == :trainable
+       :(Optimisers.trainable)
    else
-       error()
-       # @warn "defining a method for $(ex.args[1]) in your scope"  # ??
-       # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+       @warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+       esc(ex.args[1])
    end
+     push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
  end

  out
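To make the keyword handling above concrete, this is roughly what `out` accumulates for a call like `Flux.@layer Foo trainable=(b,)` (an editor's sketch; `Foo` is hypothetical, the `show` methods and the `_check_new_macro` hook are omitted, and the real macro splices in function objects rather than names):

```julia
using Flux, Functors, Adapt, Optimisers

struct Foo; a; b; end   # hypothetical layer

# Roughly the methods assembled by `Flux.@layer Foo trainable=(b,)`:
Functors.functor(::Type{T}, x) where {T<:Foo} = Flux._default_functor(T, x)
Adapt.adapt_structure(to, layer::Foo) = Functors.fmap(Adapt.adapt(to), layer)
Optimisers.trainable(x::Foo) = NamedTuple{(:b,)}((getfield(x, :b),))
```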
@@ -72,17 +98,16 @@ function _check_new_macro(x::T) where T
end
_check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
_check_new_macro(::NamedTuple) = nothing
- _check_new_macro(::Transpose) = nothing
- _check_new_macro(::Adjoint) = nothing
+ _check_new_macro(::AbstractArray) = nothing
_check_new_macro(::Ref) = nothing

# @layer's code for Functors & Adapt
# Unlike @functor, _default_functor doesn't need to eval anything

function _macro_functor(type)
  quote
-     Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
-     Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+     Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
+     Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
  end
end

@@ -94,12 +119,13 @@ function _default_functor(::Type{T}, x) where {T}
  if @generated
    F = fieldnames(T)
    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
-     C = Base.typename(T).name  # constructor
+     C = Base.typename(T).wrapper  # constructor
    recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
    :((NamedTuple{$F}(($(args...),)), $recon))
  else
    # Getting this parameterless type takes about 2μs, every time:
-     namedtuple(x), Base.splat(Base.typename(T).wrapper)
+     spl = VERSION > v"1.9-" ? Splat : Base.splat
+     namedtuple(x), spl(Base.typename(T).wrapper)
  end
end

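As a rough intuition for what `_default_functor` returns, here is a sketch reusing the `Trio` struct from the docstring example above (editor's illustration; assumes the `@layer` macro from this commit):

```julia
using Flux, Functors

struct Trio; a; b; c end
Flux.@layer :expand Trio

tri = Trio(Dense(2 => 1), Dense(1 => 1; bias=false), Dropout(0.4))

# The children come back as a NamedTuple, plus a function that re-assembles the struct:
children, rebuild = Functors.functor(Trio, tri)
children.a          # the first Dense layer
rebuild(children)   # splats the NamedTuple's values back into Trio(a, b, c)
```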
@@ -117,61 +143,12 @@ function _macro_trainable(type, fun, fields)
  quoted = map(QuoteNode, symbols)
  gets = [:(getfield(x, $f)) for f in quoted]
  quote
-     # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
-     Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
+     $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
+     # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
  end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma

_noquotenode(s::Symbol) = s
_noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
_noquotenode(ex) = error("expected a symbol, got $ex")
-
-
-
-
-
-
- # @big_show Chain
- # @big_show Parallel
- # @big_show SkipConnection
- # @big_show Recur
- # @big_show Maxout
-
-
-
-
- """
-     @big_show MyContainer
-
- This macro lets you opt-in to Flux's fancy printing.
-
- When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
- and the printing routine will recursively unfold its children.
- This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
-
- Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
- need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
-
- # Example
- ```jldoctest
- julia> struct Trio{A,B,C}; a::A; b::B; c::C end
-
- julia> Flux.@functor Trio
-
- julia> Flux.@big_show Trio
-
- julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
- Trio(
-   Dense(10 => 5, tanh),  # 55 parameters
-   Dense(5 => 2),  # 12 parameters
-   NNlib.softmax,
- )                   # Total: 4 arrays, 67 parameters, 492 bytes.
- ```
-
- Note that there is no automatic method for 2-arg `show`, and thus
- something like `(tri, tri)` will print all the type parameters.
-
- However, `Chain(tri, tri)` will always use Flux's recursive printing,
- even without using this macro: `Chain` is the entry point.
- """