
Commit 8fa4317

Authored by bors[bot], logankilpatrick, and mcabbott
Merge #1786
1786: Add docstring for `params` r=mcabbott a=logankilpatrick

`params` is used all over the model zoo and tutorials in Flux. There should be a docstring or we should not publicly use it. Not sure if the proposed docstring gets the job done 100%, but this seems like a step in the right direction.

Co-authored-by: Logan Kilpatrick <23kilpatrick23@gmail.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
2 parents 6d0e123 + 73351c7 commit 8fa4317

File tree

2 files changed: +35 additions, -0 deletions


docs/src/training/training.md

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,10 @@ Such an object contains a reference to the model's parameters, not a copy, such

Handling all the parameters on a layer by layer basis is explained in the [Layer Helpers](../models/basics.md) section. Also, for freezing model parameters, see the [Advanced Usage Guide](../models/advanced.md).

+```@docs
+Flux.params
+```

## Datasets

The `data` argument of `train!` provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy dataset with only one data point:
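As an illustration of that shape (a minimal sketch, not part of this diff — the model, loss, and the names `x`, `y`, and `data` are made up here), such a one-point dataset could look like:

```julia
using Flux

# Hypothetical one-point dataset: a collection holding a single (input, target) pair.
x = rand(Float32, 3)        # one input sample
y = rand(Float32, 2)        # its target output
data = [(x, y)]             # train! iterates over this collection

m = Dense(3, 2)             # toy model matching the sizes above
loss(x, y) = Flux.Losses.mse(m(x), y)

# One pass over `data` using the implicit-parameter API documented in this commit.
Flux.train!(loss, Flux.params(m), data, Flux.Descent(0.1))
```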

src/functor.jl

Lines changed: 31 additions & 0 deletions
@@ -48,6 +48,37 @@ function params!(p::Params, x, seen = IdSet())
  end
end

+"""
+    params(model)
+    params(layers...)
+
+Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters.
+
+This can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function.
+
+The behaviour of `params` on custom types can be customized using [`Functors.@functor`](@ref) or [`Flux.trainable`](@ref).
+
+# Examples
+```jldoctest
+julia> params(Chain(Dense(ones(2,3)), softmax))  # unpacks Flux models
+Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
+
+julia> bn = BatchNorm(2, relu)
+BatchNorm(2, relu)  # 4 parameters, plus 4 non-trainable
+
+julia> params(bn)  # only the trainable parameters
+Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])
+
+julia> params([1, 2, 3], [4.0])  # one or more arrays of numbers
+Params([[1, 2, 3], [4.0]])
+
+julia> params([[1, 2, 3], [4.0]])  # unpacks array of arrays
+Params([[1, 2, 3], [4.0]])
+
+julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin))  # ignores scalars, unpacks NamedTuples
+Params([[2 2], [3, 3, 3]])
+```
+"""
function params(m...)
  ps = Params()
  params!(ps, m)
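For context (a hedged sketch, not part of the commit): the two usages the new docstring points at — feeding a `Params` collection to `gradient`, and customising what gets collected for a user-defined type via `@functor`/`trainable` — might look roughly like this. The layer `Affine`, the sizes, and the loss are invented for illustration.

```julia
using Flux

# --- Collecting parameters and taking gradients (implicit-params style) ---
m = Chain(Dense(3, 2), softmax)        # toy model; sizes are arbitrary
ps = Flux.params(m)                    # Params object over the trainable arrays

x = rand(Float32, 3)
# Zygote-style implicit gradients: differentiate a zero-argument closure
# with respect to everything referenced in `ps`.
gs = gradient(() -> sum(m(x)), ps)
gs[m[1].weight]                        # gradient for the first Dense layer's weight matrix

# --- Customising what params collects for a user-defined type ---
struct Affine                          # hypothetical layer, not part of Flux
  W
  b
end
Flux.@functor Affine                   # let Flux recurse into the struct

Flux.trainable(a::Affine) = (a.W,)     # report only W as trainable

Flux.params(Affine(ones(2, 2), zeros(2)))   # contains only the 2×2 weight matrix
```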
