
Commit 7b56813

Replace unrolled foldl used to evaluate Chain with a better one (#1809)
* use foldl for Chain
* use foldl less often
* second derivative tests
* Revert "use foldl less often" (this reverts commit a74f86d)
* replace foldl with generated expression
* allow unstable Chain{Vector} too
* trailing comma
* fixup
1 parent 4a3483e commit 7b56813

File tree: 4 files changed, +69 -10 lines

src/layers/basic.jl

Lines changed: 21 additions & 4 deletions
@@ -27,8 +27,12 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
 julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
+
+For large models, there is a special type-unstable path which can reduce compilation
+times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
+This feature is somewhat experimental, beware!
 """
-struct Chain{T<:Union{Tuple, NamedTuple}}
+struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
   layers::T
 end
 
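In use, the new vector-of-layers constructor from the docstring above looks like this; a minimal sketch with arbitrary layer sizes, not part of the diff:

```julia
using Flux

layers = [Dense(10, 5, tanh), Dense(5, 2)]

m_tuple = Chain(layers...)  # Tuple-backed: every layer's type is known to the compiler
m_vec   = Chain(layers)     # Vector-backed: type-unstable, but cheaper to compile

x = randn(Float32, 10, 7)
m_tuple(x) ≈ m_vec(x)       # true: same layers, applied in the same order
```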

@@ -44,10 +48,22 @@ end
 
 @functor Chain
 
-applychain(::Tuple{}, x) = x
-applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
+(c::Chain)(x) = applychain(c.layers, x)
+
+@generated function applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
+  symbols = vcat(:x, [gensym() for _ in 1:N])
+  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
+  Expr(:block, calls...)
+end
+
+applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)
 
-(c::Chain)(x) = applychain(Tuple(c.layers), x)
+function applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
+  for f in layers
+    x = f(x)
+  end
+  x
+end
 
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
 Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
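The `@generated` method above unrolls the chain into straight-line code at compile time, replacing the old recursive `applychain(tail(fs), first(fs)(x))` definition. For `N = 3` it produces an expression equivalent to the sketch below; the real variable names are `gensym`s, so `y1`/`y2`/`y3` are stand-ins:

```julia
# Effective body of applychain(layers::Tuple{Vararg{<:Any,3}}, x):
y1 = layers[1](x)
y2 = layers[2](y1)
y3 = layers[3](y2)  # the block's last value is returned
```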
@@ -60,6 +76,7 @@ function Base.show(io::IO, c::Chain)
 end
 _show_layers(io, layers::Tuple) = join(io, layers, ", ")
 _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
+_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))
 
 # This is a temporary and naive implementation
 # it might be replaced in the future for better performance

src/layers/show.jl

Lines changed: 7 additions & 6 deletions
@@ -14,11 +14,12 @@ for T in [
 end
 
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
+  pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
   children = _show_children(obj)
   if all(_show_leaflike, children)
     _layer_show(io, obj, indent, name)
   else
-    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
+    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
     if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
       # then we insert names -- can this be done more generically?
       for k in Base.keys(obj)
@@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
       end
     end
     if indent == 0  # i.e. this is the outermost container
-      print(io, ")")
+      print(io, rpad(post, 2))
       _big_finale(io, obj)
     else
-      println(io, " "^indent, "),")
+      println(io, " "^indent, post, ",")
     end
   end
 end
@@ -90,18 +91,18 @@ function _big_finale(io::IO, m)
     noncnt = _childarray_sum(_->1, m) - length(ps)
     if noncnt > 0
       nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
-      printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
+      printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
       println(io, pars, " parameters,")
       printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
       print(io, bytes, ".")
     else
-      printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black)
+      printstyled(io, " "^18, "# Total: ", length(ps), " arrays, "; color=:light_black)
       print(io, pars, " parameters, ", bytes, ".")
     end
   end
 end
 
-_childarray_sum(f, x::AbstractArray) = f(x)
+_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
 _childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
 
 # utility functions
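With `pre`/`post` switched to `("([", "])")` and the new `_show_layers` method, a vector-backed `Chain` prints with square brackets. Roughly, as a sketch (exact alignment and the byte count will differ):

```julia
julia> Chain([Dense(2, 3), Dense(3, 1)])
Chain([
  Dense(2, 3),                          # 9 parameters
  Dense(3, 1),                          # 4 parameters
])                  # Total: 4 arrays, 13 parameters, ... bytes.
```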

test/layers/basic.jl

Lines changed: 40 additions & 0 deletions
@@ -29,6 +29,8 @@ import Flux: activations
     @test m == fmap(identity, m) # does not forget names
 
     @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name
+
+    @test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers
   end
 
   @testset "Activations" begin
@@ -297,3 +299,41 @@
     @test_throws DimensionMismatch m(OneHotVector(3, 1000))
   end
 end
+
+@testset "second derivatives" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])
+
+  m1v = Chain([m1[1], m1[2]]) # vector of layers
+  @test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3])
+  @test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3])
+
+  # NNlib's softmax gradient writes in-place
+  m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
+  @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])
+
+  # https://github.com/FluxML/NNlib.jl/issues/362
+  m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
+  x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
+  @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
+end
+
+@testset "gradients of Chain{Vector}" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  m1v = Chain([m1[1], m1[2]])
+  @test sum(length, params(m1)) == sum(length, params(m1v))
+
+  x1 = randn(Float32,3,5)
+  @test m1(x1) ≈ m1v(x1)
+
+  y1 = rand(Bool,2,5)
+  g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
+  g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
+  @test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
+  @test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
+
+  @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1]
+  z1 = rand(22);
+  @test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight
+  # Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
+end
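For context on the test names: `Zygote.hessian_dual` computes the Hessian by forward-mode differentiation of a reverse-mode gradient, while `Zygote.hessian_reverse` applies reverse mode twice; the tests check that the two paths agree. A compressed sketch of the pattern, not the tests' literal code:

```julia
using Flux, Zygote

m1 = Chain(Dense(3, 4, tanh; bias=false), Dense(4, 2))
f = sum ∘ m1  # scalar output, so the Hessian w.r.t. the input is well defined

H_fwd = Zygote.hessian_dual(f, [1, 2, 3])     # forward-over-reverse
H_rev = Zygote.hessian_reverse(f, [1, 2, 3])  # reverse-over-reverse
H_fwd ≈ H_rev                                 # agreement checked by the testset
```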

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ using Flux: params
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
+using Zygote
 using CUDA
 
 Random.seed!(0)
