Skip to content

Commit 43e51a2

Browse files
trainables
1 parent 4132da4 commit 43e51a2

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

src/destructure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function _flatten(x)
7373
o
7474
end
7575
isempty(arrays) && return Bool[], off, 0
76-
reduce(vcat, arrays), off, len[]
76+
return reduce(vcat, arrays), off, len[]
7777
end
7878

7979
struct _TrainableStructWalk <: AbstractWalk end
@@ -174,3 +174,4 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
174174
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
175175
nothing, _ -> (NoT,)
176176
end
177+

src/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ and `trainable(x)` must contain a subset of these.
167167
"""
168168
trainable(x) = functor(x)[1]
169169

170+
# like trainable(x), but also tries to output non-trainable children giving value nothing
170171
_trainable(x) = _trainable(functor(x)[1], trainable(x))
171172
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
172173
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr

src/trainables.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
3+
function trainables1(x)
4+
isnumeric(x) && return [x]
5+
arrays = AbstractArray[]
6+
fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
7+
push!(arrays, y)
8+
return y
9+
end
10+
return arrays
11+
end
12+
13+
############
14+
15+
using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
16+
17+
struct TrainableWalk2 <: AbstractWalk end
18+
19+
function (walk::TrainableWalk2)(recurse, x, ys...)
20+
x_children = _values(Optimisers.trainable(x))
21+
ys_children = map(Optimisers.trainable, ys)
22+
res = _map(recurse, x_children, ys_children...)
23+
@show _values(res)
24+
return _values(res)
25+
end
26+
27+
function trainables2(x)
28+
exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
29+
return execute(ExcludeWalk(TrainableWalk2(), x -> x, exclude), x)
30+
end
31+
32+
using Flux
33+
34+
m = Chain(Dense(2 => 3, relu), BatchNorm(3), Dense(3 => 2))
35+
trainables2(m)

0 commit comments

Comments
 (0)