"""
    @layer Dense
    @layer :expand Chain
    @layer BatchNorm trainable=(β,γ)
    @layer Struct functor=(α,β) trainable=(β,)

This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
when you define a new layer, this tells Flux to explore inside it
to see the parameters it trains, and also to move them to the GPU, change precision, etc.

Some "keywords" allow control of the recursion:
* If some fields look like parameters but should not be trained,
  then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
* We can likewise add restrictions to `Functors.functor`, but this is not yet written.
* In fact you can provide an arbitrary keyword with this syntax, and it will
  overload this function à la `trainable`... that might be a terrible idea.

It also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
* To disable all `show` overloads, maybe we want a `:ignore` option too.

(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
"""
macro layer(exs...)
  out = quote end

  # These functions are defined in show.jl, and each return an expression overloading Base.show.
  # The leading `:expand`/`:ignore` option (if any) is stripped off here, leaving `type` and
  # the remaining `keyword=fields` expressions in `rest`.
  type, rest... = if exs[1] == QuoteNode(:expand)
    push!(out.args, _macro_big_show(esc(exs[2])))
    exs[2:end]
  elseif exs[1] == QuoteNode(:ignore)
    exs[2:end]
  elseif exs[1] isa QuoteNode
    error("before the type, only accepted options are `:expand` and `:ignore`")
  else
    push!(out.args, _macro_layer_show(esc(exs[1])))
    exs
  end

  # This function exists only for depwarns when you use @functor directly
  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?

  # The optional `functor=(...)` keyword restricts which fields Functors recurses into;
  # without it, every field is included via _default_functor.
  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
  if isnothing(i)
    push!(out.args, _macro_functor(esc(type)))
  else
    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
  end

  # Every remaining keyword must (for now) be `trainable=fields`.
  for j in eachindex(rest)
    j == i && continue
    ex = rest[j]
    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
    if ex.args[1] == :trainable
      push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
    else
      error("`@layer` currently only accepts the keyword `trainable`, got `$(ex.args[1])`")
      # @warn "defining a method for $(ex.args[1]) in your scope"  # ??
      # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
    end
  end

  out
end
# Temporary depwarn function:

"""
    _check_new_macro(x)

Warn (once per type, keyed by `hash(T)`) that a non-leaf type was set up with the old
`@functor` rather than `@layer`. Types processed by `@layer` receive their own
`_check_new_macro` method returning `nothing`, which silences this warning.
"""
function _check_new_macro(x::T) where T
  Functors.isleaf(x) && return
  @warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
end
_check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
_check_new_macro(::NamedTuple) = nothing
_check_new_macro(::Transpose) = nothing
_check_new_macro(::Adjoint) = nothing
_check_new_macro(::Ref) = nothing
# @layer's code for Functors & Adapt
# Unlike @functor, _default_functor doesn't need to eval anything

"""
    _macro_functor(type)

Return an expression which overloads `Functors.functor` for `type` (delegating the
actual field-splitting to `_default_functor`), and gives it an
`Adapt.adapt_structure` method based on `fmap`.
"""
function _macro_functor(type)
  quote
    Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
    Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
  end
end
"""
    _macro_functor(type, fields)

Restricting `Functors.functor` to a subset of fields — the `@layer` equivalent of
`@functor Layer (x,)` — is not implemented yet, so this method always throws.
"""
function _macro_functor(type, fields)
  error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
end
"""
    _default_functor(::Type{T}, x) -> (NamedTuple, reconstructor)

What `Functors.functor` calls for any type set up by `@layer`: splits `x` into a
`NamedTuple` of all its fields, plus a splatting function which re-assembles such a
tuple of values back into a `T`. The `@generated` path builds both without runtime
reflection; the fallback path pays ~2μs per call for the reflection.
"""
function _default_functor(::Type{T}, x) where {T}
  if @generated
    F = fieldnames(T)
    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
    C = Base.typename(T).name  # constructor; NOTE(review): this is a Symbol resolved in the defining module — confirm it finds user types defined elsewhere
    # Qualified as Base.Splat: the type exists but is not exported on Julia 1.9+,
    # so a bare `Splat` would not resolve when this expression is evaluated.
    recon = VERSION > v"1.9-" ? :(Base.Splat($C)) : :(Base.splat($C))
    :((NamedTuple{$F}(($(args...),)), $recon))
  else
    # Getting this parameterless type takes about 2μs, every time:
    namedtuple(x), Base.splat(Base.typename(T).wrapper)
  end
end
"""
    namedtuple(x)

Convert any struct `x` into a `NamedTuple` of all its fields, in field order.
"""
function namedtuple(x::T) where T
  F = fieldnames(T)
  NamedTuple{F}(map(sy -> getfield(x, sy), F))
end
# @layer's code for Optimisers.trainable, and perhaps anything else,
# with the pattern that keywords mean function names & what fields they pick.

"""
    _macro_trainable(type, fun, fields)

Return an expression defining `Flux.trainable(x::type)` to return a `NamedTuple` of
just the given `fields`. `fields` is a tuple expression of field names, quoted or
not (both `trainable=(β,γ)` and `trainable=(:β,:γ)` are accepted); a lone symbol is
wrapped into a 1-tuple.
"""
function _macro_trainable(type, fun, fields)
  Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
  symbols = Tuple(map(_noquotenode, fields.args))
  quoted = map(QuoteNode, symbols)
  gets = [:(getfield(x, $f)) for f in quoted]
  quote
    # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
    Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
  end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma

_noquotenode(s::Symbol) = s
_noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
_noquotenode(ex) = error("expected a symbol, got $ex")
# @big_show Chain
# @big_show Parallel
# @big_show SkipConnection
# @big_show Recur
# @big_show Maxout
"""
    @big_show MyContainer

This macro lets you opt-in to Flux's fancy printing.

When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
and the printing routine will recursively unfold its children.
This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.

Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.

# Example
```jldoctest
julia> struct Trio{A,B,C}; a::A; b::B; c::C end

julia> Flux.@functor Trio

julia> Flux.@big_show Trio

julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
Trio(
  Dense(10 => 5, tanh),  # 55 parameters
  Dense(5 => 2),         # 12 parameters
  NNlib.softmax,
)  # Total: 4 arrays, 67 parameters, 492 bytes.
```

Note that there is no automatic method for 2-arg `show`, and thus
something like `(tri, tri)` will print all the type parameters.

However, `Chain(tri, tri)` will always use Flux's recursive printing,
even without using this macro: `Chain` is the entry point.
"""