Commit faea10f: add doc section
1 parent 06f786c

docs/src/index.md: 26 additions & 0 deletions

@@ -290,3 +290,29 @@ flat, re = destructure(params)

## Collecting all trainable parameters

Sometimes it is useful to collect all trainable parameters in a model,
similarly to what [`destructure`](@ref Optimisers.destructure) does, but keeping
the arrays separate.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

```julia
julia> using Flux, Optimisers

julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));

julia> trainables(model)
6-element Vector{AbstractArray}:
 Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061]
 Float32[0.0, 0.0, 0.0]
 Float32[0.0, 0.0, 0.0]
 Float32[1.0, 1.0, 1.0]
 Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
 Float32[0.0, 0.0]

julia> l2reg(model) = sum([sum(abs2, p) for p in trainables(model)]);

julia> g = gradient(l2reg, model)[1];
```
Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.
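
One can see this by inspecting the layer directly; a minimal sketch, assuming Flux's standard `trainable` overload for `BatchNorm`, which exposes only the affine parameters:

```julia
julia> bn = BatchNorm(3);

julia> Optimisers.trainable(bn)  # only β (shift) and γ (scale); the μ and σ² buffers are excluded
(β = Float32[0.0, 0.0, 0.0], γ = Float32[1.0, 1.0, 1.0])
```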

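As a quick check of the relation to `destructure` mentioned above (a sketch: the flat vector concatenates these same arrays, so the total lengths should agree):

```julia
julia> flat, re = destructure(model);

julia> length(flat) == sum(length, trainables(model))
true
```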