@@ -77,7 +77,7 @@ or [`update!`](@ref).
77
77
julia> m = (x = rand(3), y = (true, false), z = tanh);
78
78
79
79
julia> Optimisers.setup(Momentum(), m) # same field names as m
80
- (x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing) , z = nothing )
80
+ (x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()) , z = () )
81
81
```
82
82
83
83
The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -90,15 +90,15 @@ julia> struct Layer; mat; fun; end
90
90
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
91
91
92
92
julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
93
- (lay = nothing , vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
93
+ (lay = () , vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
94
94
95
95
julia> destructure(model)
96
96
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
97
97
98
98
julia> using Functors; @functor Layer # annotate this type as containing parameters
99
99
100
100
julia> Optimisers.setup(Momentum(), model)
101
- (lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing ), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
101
+ (lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = () ), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
102
102
103
103
julia> destructure(model)
104
104
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -120,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
120
120
julia> m = (x = Float32[1,2,3], y = tanh);
121
121
122
122
julia> t = Optimisers.setup(Descent(0.1f0), m)
123
- (x = Leaf(Descent{Float32}(0.1), nothing), y = nothing )
123
+ (x = Leaf(Descent{Float32}(0.1), nothing), y = () )
124
124
125
125
julia> g = (x = [1,1,1], y = nothing); # fake gradient
126
126
127
127
julia> Optimisers.update(t, m, g)
128
- ((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing ), (x = Float32[0.9, 1.9, 2.9], y = tanh))
128
+ ((x = Leaf(Descent{Float32}(0.1), nothing), y = () ), (x = Float32[0.9, 1.9, 2.9], y = tanh))
129
129
```
130
130
"""
131
131
update
0 commit comments