Commit f49e81e

Print channel dimensions of Dense like those of Conv (#1658)
* print channel dims of Dense like Conv, and accept as input (see the sketch below)
* do the same for Bilinear
* fix tests
* fix tests
* docstring
* change a few more
* update
* docs
* rm circular ref
* fixup
* news + fixes
1 parent b35b23b commit f49e81e
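
As a quick illustration of the headline change (a sketch, assuming a Flux version that includes this commit, roughly v0.13): `Dense` now accepts and prints its channel dimensions with the same `in => out` notation as `Conv`, while the older positional form keeps working.

```julia
using Flux

d_new = Dense(2 => 3, σ)   # Conv-style notation: 2 input channels => 3 output channels
d_old = Dense(2, 3, σ)     # the previous positional form still constructs the same kind of layer

size(d_new.weight)         # (3, 2): the weight matrix maps 2 inputs to 3 outputs
d_new                      # now prints as Dense(2 => 3, σ), matching how Conv layers print
```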

File tree: 17 files changed, +142 / -130 lines


NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ been removed in favour of MLDatasets.jl.
 * `flatten` is not exported anymore due to clash with Iterators.flatten.
 * Remove Juno.jl progress bar support as it is now obsolete.
 * `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
+* Notation `Dense(2 => 3, σ)` for channels matches `Conv`; the equivalent `Dense(2, 3, σ)` still works.
 * Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
 * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
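
The commit message also mentions `Bilinear`. A hedged sketch of the matching notation there, assuming the pair form `(in1, in2) => out` was introduced alongside the `Dense` change:

```julia
using Flux

# assumption: Bilinear accepts the pair notation after this change
b = Flux.Bilinear((5, 5) => 7, tanh)   # two 5-dimensional inputs, 7-dimensional output

x, y = rand(Float32, 5), rand(Float32, 5)
b(x, y)                                # 7-element vector
```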

docs/src/gpu.md

Lines changed: 4 additions & 4 deletions
@@ -39,12 +39,12 @@ Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`)
 If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.

 ```julia
-d = Dense(10, 5, σ)
+d = Dense(10 => 5, σ)
 d = fmap(cu, d)
 d.weight # CuArray
 d(cu(rand(10))) # CuArray output

-m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+m = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax)
 m = fmap(cu, m)
 d(cu(rand(10)))
 ```
@@ -54,8 +54,8 @@ As a convenience, Flux provides the `gpu` function to convert models and data to
 ```julia
 julia> using Flux, CUDA

-julia> m = Dense(10,5) |> gpu
-Dense(10, 5)
+julia> m = Dense(10, 5) |> gpu
+Dense(10 => 5)

 julia> x = rand(10) |> gpu
 10-element CuArray{Float32,1}:
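
For reference, a minimal end-to-end version of the snippet being updated above; this is a sketch that assumes CUDA.jl is installed and a CUDA-capable GPU is available:

```julia
using Flux, CUDA

m = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax) |> gpu   # move the parameters to the GPU
x = rand(Float32, 10) |> gpu                                  # move the input too

m(x)   # 2-element CuArray; the model now prints its layers as Dense(10 => 5, σ), etc.
```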

docs/src/models/advanced.md

Lines changed: 15 additions & 14 deletions
@@ -74,10 +74,10 @@ this using the slicing features `Chain` provides:

 ```julia
 m = Chain(
-  Dense(784, 64, relu),
-  Dense(64, 64, relu),
-  Dense(32, 10)
-)
+  Dense(784 => 64, relu),
+  Dense(64 => 64, relu),
+  Dense(32 => 10)
+);

 ps = Flux.params(m[3:end])
 ```
@@ -142,10 +142,11 @@ Lastly, we can test our new layer. Thanks to the proper abstractions in Julia, o
 ```julia
 model = Chain(
   Join(vcat,
-    Chain(Dense(1, 5), Dense(5, 1)),  # branch 1
-    Dense(1, 2),                      # branch 2
-    Dense(1, 1)),                     # branch 3
-  Dense(4, 1)
+    Chain(Dense(1 => 5, relu), Dense(5 => 1)),  # branch 1
+    Dense(1 => 2),                              # branch 2
+    Dense(1 => 1)                               # branch 3
+  ),
+  Dense(4 => 1)
 ) |> gpu

 xs = map(gpu, (rand(1), rand(1), rand(1)))
@@ -164,11 +165,11 @@ Join(combine, paths...) = Join(combine, paths)
 # use vararg/tuple version of Parallel forward pass
 model = Chain(
   Join(vcat,
-    Chain(Dense(1, 5), Dense(5, 1)),
-    Dense(1, 2),
-    Dense(1, 1)
+    Chain(Dense(1 => 5, relu), Dense(5 => 1)),
+    Dense(1 => 2),
+    Dense(1 => 1)
   ),
-  Dense(4, 1)
+  Dense(4 => 1)
 ) |> gpu

 xs = map(gpu, (rand(1), rand(1), rand(1)))
@@ -201,8 +202,8 @@ Flux.@functor Split
 Now we can test to see that our `Split` does indeed produce multiple outputs.
 ```julia
 model = Chain(
-  Dense(10, 5),
-  Split(Dense(5, 1), Dense(5, 3), Dense(5, 2))
+  Dense(10 => 5),
+  Split(Dense(5 => 1, tanh), Dense(5 => 3, tanh), Dense(5 => 2))
 ) |> gpu

 model(gpu(rand(10)))
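
The hunks above only change layer sizes inside the custom `Join`/`Split` examples; for context, here is a minimal sketch of those containers along the lines of what that page defines (paraphrased, not the full documentation text):

```julia
using Flux

# Join: apply one path per input, then combine the results
struct Join{T, F}
  combine::F
  paths::T
end
Join(combine, paths...) = Join(combine, paths)
Flux.@functor Join
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)

# Split: apply every path to the same input, returning a tuple of outputs
struct Split{T}
  paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
```

With these definitions, the `Chain(Join(vcat, ...), Dense(4 => 1))` and `Split` examples above run as written (the three branches produce 1 + 2 + 1 = 4 values, matching `Dense(4 => 1)`).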

docs/src/models/basics.md

Lines changed: 6 additions & 6 deletions
@@ -158,14 +158,14 @@ a(rand(10)) # => 5-element vector

 Congratulations! You just built the `Dense` layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily.

-(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10, 5, σ)`.)
+(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, σ)`.)

 ## Stacking It Up

 It's pretty common to write models that look something like:

 ```julia
-layer1 = Dense(10, 5, σ)
+layer1 = Dense(10 => 5, σ)
 # ...
 model(x) = layer3(layer2(layer1(x)))
 ```
@@ -175,7 +175,7 @@ For long chains, it might be a bit more intuitive to have a list of layers, like
 ```julia
 using Flux

-layers = [Dense(10, 5, σ), Dense(5, 2), softmax]
+layers = [Dense(10 => 5, σ), Dense(5 => 2), softmax]

 model(x) = foldl((x, m) -> m(x), layers, init = x)

@@ -186,8 +186,8 @@ Handily, this is also provided for in Flux:

 ```julia
 model2 = Chain(
-  Dense(10, 5, σ),
-  Dense(5, 2),
+  Dense(10 => 5, σ),
+  Dense(5 => 2),
   softmax)

 model2(rand(10)) # => 2-element vector
@@ -198,7 +198,7 @@ This quickly starts to look like a high-level deep learning library; yet you can
 A nice property of this approach is that because "models" are just functions (possibly with trainable parameters), you can also see this as simple function composition.

 ```julia
-m = Dense(5, 2) ∘ Dense(10, 5, σ)
+m = Dense(5 => 2) ∘ Dense(10 => 5, σ)

 m(rand(10))
 ```
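
Putting the updated snippets together, a runnable sketch of the three equivalent stacking styles discussed on that page (list-plus-`foldl`, `Chain`, and plain function composition), using the sizes shown above:

```julia
using Flux

layers = [Dense(10 => 5, σ), Dense(5 => 2), softmax]
model(x) = foldl((x, m) -> m(x), layers, init = x)          # fold the input through each layer

model2 = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax)   # same idea, provided by Flux

m = Dense(5 => 2) ∘ Dense(10 => 5, σ)                       # composition reads right-to-left

x = rand(Float32, 10)
model(x), model2(x), softmax(m(x))   # three 2-element vectors (each model has its own random weights)
```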

docs/src/models/overview.md

Lines changed: 5 additions & 5 deletions
@@ -43,8 +43,8 @@ Normally, your training and test data come from real world observations, but thi
 Now, build a model to make predictions with `1` input and `1` output:

 ```julia
-julia> model = Dense(1, 1)
-Dense(1, 1)
+julia> model = Dense(1 => 1)
+Dense(1 => 1)

 julia> model.weight
 1×1 Matrix{Float32}:
@@ -58,10 +58,10 @@ julia> model.bias
 Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weights' matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:

 ```julia
-julia> predict = Dense(1, 1)
+julia> predict = Dense(1 => 1)
 ```

-`Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.
+`Dense(1 => 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.

 This model will already make predictions, though not accurate ones yet:

@@ -185,7 +185,7 @@ The predictions are good. Here's how we got there.

 First, we gathered real-world data into the variables `x_train`, `y_train`, `x_test`, and `y_test`. The `x_*` data defines inputs, and the `y_*` data defines outputs. The `*_train` data is for training the model, and the `*_test` data is for verifying the model. Our data was based on the function `4x + 2`.

-Then, we built a single input, single output predictive model, `predict = Dense(1, 1)`. The initial predictions weren't accurate, because we had not trained the model yet.
+Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet.

 After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
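
A compressed sketch of the overview workflow those hunks belong to, with toy data generated from `4x + 2` as the page describes (the exact epoch count is illustrative):

```julia
using Flux

x_train = Float32.(hcat(0:5...))   # 1×6 matrix of inputs
y_train = 4 .* x_train .+ 2        # targets from the function 4x + 2

predict = Dense(1 => 1)            # single input, single output
loss(x, y) = Flux.Losses.mse(predict(x), y)

opt = Descent()
for epoch in 1:200
  Flux.train!(loss, Flux.params(predict), [(x_train, y_train)], opt)
end

loss(x_train, y_train)             # should be far smaller than before training
```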

docs/src/models/recurrence.md

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also availabl
 Using these tools, we can now build the model shown in the above diagram with:

 ```julia
-m = Chain(RNN(2, 5), Dense(5, 1))
+m = Chain(RNN(2, 5), Dense(5 => 1))
 ```
 In this example, each output has only one component.
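
A small sketch of driving that recurrent chain over a sequence, one feature vector per time step; `Flux.reset!` clears the hidden state between sequences:

```julia
using Flux

m = Chain(RNN(2, 5), Dense(5 => 1))

seq = [rand(Float32, 2) for _ in 1:3]   # a sequence of 3 time steps, 2 features each

Flux.reset!(m)                          # start from a fresh hidden state
y = [m(x) for x in seq]                 # one 1-element output per time step
```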

docs/src/models/regularisation.md

Lines changed: 6 additions & 6 deletions
@@ -9,7 +9,7 @@ For example, say we have a simple regression.
 ```julia
 using Flux
 using Flux.Losses: logitcrossentropy
-m = Dense(10, 5)
+m = Dense(10 => 5)
 loss(x, y) = logitcrossentropy(m(x), y)
 ```

@@ -39,9 +39,9 @@ Here's a larger example with a multi-layer perceptron.

 ```julia
 m = Chain(
-  Dense(28^2, 128, relu),
-  Dense(128, 32, relu),
-  Dense(32, 10))
+  Dense(28^2 => 128, relu),
+  Dense(128 => 32, relu),
+  Dense(32 => 10))

 sqnorm(x) = sum(abs2, x)

@@ -55,8 +55,8 @@ One can also easily add per-layer regularisation via the `activations` function:
 ```julia
 julia> using Flux: activations

-julia> c = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+julia> c = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax)
+Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax)

 julia> activations(c, rand(10))
 3-element Array{Any,1}:
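
Assembled from the hunks above, a runnable sketch of the L2-penalty pattern that page describes (the penalty weight 0.01 is illustrative):

```julia
using Flux
using Flux.Losses: logitcrossentropy

m = Chain(Dense(28^2 => 128, relu), Dense(128 => 32, relu), Dense(32 => 10))

sqnorm(x) = sum(abs2, x)
penalty() = sum(sqnorm, Flux.params(m))           # sum of squared weights and biases
loss(x, y) = logitcrossentropy(m(x), y) + 0.01f0 * penalty()

loss(rand(Float32, 28^2), Flux.onehot(3, 0:9))    # scalar loss including the penalty
```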

docs/src/saving.md

Lines changed: 7 additions & 7 deletions
@@ -11,8 +11,8 @@ julia> using Flux

 julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
 Chain(
-  Dense(10, 5, relu),  # 55 parameters
-  Dense(5, 2),  # 12 parameters
+  Dense(10 => 5, relu),  # 55 parameters
+  Dense(5 => 2),  # 12 parameters
   NNlib.softmax,
 ) # Total: 4 arrays, 67 parameters, 524 bytes.

@@ -32,8 +32,8 @@ julia> @load "mymodel.bson" model

 julia> model
 Chain(
-  Dense(10, 5, relu),  # 55 parameters
-  Dense(5, 2),  # 12 parameters
+  Dense(10 => 5, relu),  # 55 parameters
+  Dense(5 => 2),  # 12 parameters
   NNlib.softmax,
 ) # Total: 4 arrays, 67 parameters, 524 bytes.

@@ -59,7 +59,7 @@ model parameters.
 ```Julia
 julia> using Flux

-julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
+julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
 Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)

 julia> weights = Flux.params(model);
@@ -74,7 +74,7 @@ You can easily load parameters back into a model with `Flux.loadparams!`.
 ```julia
 julia> using Flux

-julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
+julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
 Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)

 julia> using BSON: @load
@@ -94,7 +94,7 @@ In longer training runs it's a good idea to periodically save your model, so tha
 using Flux: throttle
 using BSON: @save

-m = Chain(Dense(10,5,relu),Dense(5,2),softmax)
+m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)

 evalcb = throttle(30) do
   # Show loss
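
A compact sketch of the parameter save / load round trip those hunks update (assumes BSON.jl is installed; `"mymodel.bson"` is a placeholder path):

```julia
using Flux
using BSON: @save, @load

model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
weights = Flux.params(model)
@save "mymodel.bson" weights          # write the parameters to disk

model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)   # same architecture, fresh weights
@load "mymodel.bson" weights          # brings `weights` back into scope
Flux.loadparams!(model2, weights)     # copy the saved values into model2
```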

docs/src/training/training.md

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ We can also define an objective in terms of some model:

 ```julia
 m = Chain(
-  Dense(784, 32, σ),
-  Dense(32, 10), softmax)
+  Dense(784 => 32, σ),
+  Dense(32 => 10), softmax)

 loss(x, y) = Flux.Losses.mse(m(x), y)
 ps = Flux.params(m)
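
To round out the objective defined above, a sketch of one explicit optimisation step with the implicit-parameters API used on that page (the data here is random, just to make it runnable):

```julia
using Flux

m = Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax)
loss(x, y) = Flux.Losses.mse(m(x), y)
ps = Flux.params(m)

x, y = rand(Float32, 784), Flux.onehot(3, 0:9)   # dummy input and target
gs = gradient(() -> loss(x, y), ps)              # gradients with respect to ps
Flux.Optimise.update!(ADAM(), ps, gs)            # one optimiser step
```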

docs/src/utilities.md

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ function make_model(width, height, inchannels, nclasses;

   # the input dimension to Dense is programmatically calculated from
   # width, height, and nchannels
-  return Chain(conv_layers..., Dense(prod(conv_outsize), nclasses))
+  return Chain(conv_layers..., Dense(prod(conv_outsize) => nclasses))
 end
 ```
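
The hunk above comes from an example that sizes the final `Dense` programmatically with `Flux.outputsize`; here is a standalone sketch of that idea, with layer sizes chosen only for illustration:

```julia
using Flux

conv_layers = Chain(
  Conv((3, 3), 3 => 8, relu),
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu))

# infer the conv output shape for a 32×32 RGB input without running real data
conv_outsize = Flux.outputsize(conv_layers, (32, 32, 3, 1))

model = Chain(conv_layers, Flux.flatten, Dense(prod(conv_outsize) => 10))
model(rand(Float32, 32, 32, 3, 1))   # 10×1 output
```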
