Skip to content

Commit 2e3c150

Browse files
committed
forgotten changes
1 parent 34d569d commit 2e3c150

File tree

2 files changed

+48
-41
lines changed

2 files changed

+48
-41
lines changed

src/Fluxperimental.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export shinkansen!
1212
include("chain.jl")
1313

1414
include("compact.jl")
15+
export @compact
1516

1617
include("noshow.jl")
1718
export NoShow

src/compact.jl

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,64 +4,66 @@ import Flux: _big_show
44
@compact(forward::Function; name=nothing, parameters...)
55
66
Creates a layer by specifying some `parameters`, in the form of keywords,
7-
and (usually as a `do` block) a function for the forward pass.
7+
and a function for the forward pass (often as a `do` block).
8+
89
You may think of `@compact` as a specialized `let` block creating local variables
910
that are trainable in Flux.
1011
Declared variable names may be used within the body of the `forward` function.
1112
12-
Here is a linear model:
13+
# Examples
14+
15+
Here is a linear model, equivalent to `Flux.Scale`:
1316
1417
```
15-
r = @compact(w = rand(3)) do x
16-
w .* x
17-
end
18-
r([1, 1, 1]) # x is set to [1, 1, 1].
18+
using Flux, Fluxperimental
19+
20+
w = rand(3)
21+
sc = @compact(x -> x .* w; w)
22+
23+
y = sc([1 10 100]) # 3×3 Matrix as output.
24+
y ≈ Flux.Scale(w)([1 10 100]) # equivalent Flux layer
1925
```
2026
21-
Here is a linear model with bias and activation:
27+
Here is a linear model with bias and activation, equivalent to Flux's `Dense` layer.
28+
The forward pass function is now written as a do block, instead of `x -> begin y = W * x; ...`
2229
2330
```
24-
d_in = 5
31+
d_in = 3
2532
d_out = 7
26-
d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
33+
layer = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
2734
y = W * x
2835
act.(y .+ b)
2936
end
30-
d(ones(5, 10)) # 7×10 Matrix as output.
31-
d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5]) # Equivalent to a dense layer
37+
38+
den = Dense(layer.variables.W, zeros(7), relu) # equivalent Flux layer
39+
layer(ones(3, 10)) ≈ den(ones(3, 10)) # both 7×10 Matrix as output.
3240
```
33-
```
3441
35-
Finally, here is a simple MLP:
42+
Finally, here is a simple MLP, equivalent to a `Chain` with 5 `Dense` layers:
3643
3744
```
38-
using Flux
39-
40-
n_in = 1
41-
n_out = 1
45+
d_in = 1
4246
nlayers = 3
4347
4448
model = @compact(
45-
w1=Dense(n_in, 128),
46-
w2=[Dense(128, 128) for i=1:nlayers],
47-
w3=Dense(128, n_out),
48-
act=relu
49+
lay1 = Dense(d_in => 64),
50+
lay234 = [Dense(64 => 64) for i=1:nlayers],
51+
wlast = rand32(64),
4952
) do x
50-
embed = act(w1(x))
51-
for w in w2
52-
embed = act(w(embed))
53+
y = tanh.(lay1(x))
54+
for lay in lay234
55+
y = relu.(lay(y))
5356
end
54-
out = w3(embed)
55-
return out
57+
return wlast' * y
5658
end
5759
58-
model(randn(n_in, 32)) # 1×32 Matrix as output.
60+
model(randn(Float32, d_in, 8)) # 1×8 array as output.
5961
```
6062
61-
We can train this model just like any `Chain`:
63+
We can train this model just like any `Chain`, for example:
6264
6365
```
64-
data = [([x], 2x-x^3) for x in -2:0.1f0:2]
66+
data = [([x], [2x-x^3]) for x in -2:0.1f0:2]
6567
optim = Flux.setup(Adam(), model)
6668
6769
for epoch in 1:1000
@@ -71,19 +73,23 @@ end
7173
To specify a custom printout for the model, you may find [`NoShow`](@ref) useful.
7274
"""
7375
macro compact(_exs...)
76+
_compact(_exs...) |> esc
77+
end
78+
79+
function _compact(_exs...)
7480
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
75-
isempty(_exs) && error("expects at least two expressions: a function and at least one keyword")
81+
isempty(_exs) && error("@compact expects at least two expressions: a function and at least one keyword")
7682
if Meta.isexpr(_exs[1], :parameters)
77-
length(_exs) >= 2 || error("expects an anonymous function")
83+
length(_exs) >= 2 || error("@compact expects an anonymous function")
7884
fex = _exs[2]
7985
_kwexs = (_exs[1], _exs[3:end]...)
8086
else
8187
fex = _exs[1]
8288
_kwexs = _exs[2:end]
8389
end
84-
Meta.isexpr(fex, :(->)) || error("expects an anonymous function")
85-
isempty(_kwexs) && error("expects keyword arguments")
86-
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")
90+
Meta.isexpr(fex, :(->)) || error("@compact expects an anonymous function")
91+
isempty(_kwexs) && error("@compact expects keyword arguments")
92+
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("@compact expects only keyword arguments")
8793

8894
# process keyword arguments
8995
if Meta.isexpr(_kwexs[1], :parameters) # handle keyword arguments provided after semicolon
@@ -101,20 +107,20 @@ macro compact(_exs...)
101107
fex_args = fex.args[1]
102108
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
103109
catch e
104-
@warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
105-
""
110+
@warn """@compact's function stringifying does not yet handle all cases. Falling back to "?" """ maxlog=1
111+
"?"
106112
end
107-
block = string(Base.remove_linenums!(fex).args[2])
113+
block = string(Base.remove_linenums!(fex).args[2]) # TODO make this remove macro comments
108114

109115
# edit expressions
110116
vars = map(ex -> ex.args[1], kwexs)
111-
fex = supportself(fex, vars)
117+
fex = _supportself(fex, vars)
112118

113119
# assemble
114-
return esc(:($CompactLayer($fex, ($input, $block); $(kwexs...))))
120+
return :($CompactLayer($fex, ($input, $block); $(kwexs...)))
115121
end
116122

117-
function supportself(fex::Expr, vars)
123+
function _supportself(fex::Expr, vars)
118124
@gensym self
119125
@gensym curried_f
120126
# To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
@@ -174,7 +180,7 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
174180
print(io, " "^indent, post)
175181
end
176182

177-
input != "" && print(io, " do ", input)
183+
print(io, " do ", input)
178184
if block != ""
179185
block_to_print = block[6:end]
180186
# Increase indentation of block according to `indent`:

0 commit comments

Comments
 (0)