
Commit fe803a1

bors[bot] and mcabbott authored
Merge #1794
1794: Tidy up `Maxout` r=mcabbott a=mcabbott

Maxout is from #698. This:

* adds pretty printing
* changes the explicit signature to `Maxout(layer, layer, layer)`, rather than providing a tuple, to be more like other layers (with deprecation)
* adds more examples to the docstring, and combines the two
* changes not to use `mapreduce`. I see now this was a performance choice at the time, discussed in #647 (comment), but with Zygote this is much slower.

Before:

```
julia> using Flux

julia> m3 = Maxout(() -> Dense(5, 7, tanh), 3)
Maxout{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}((Dense(5, 7, tanh), Dense(5, 7, tanh), Dense(5, 7, tanh)))

julia> x = rand(Float32, 5, 11);

julia> @btime gradient(sum∘m3, $x);
  min 112.792 μs, mean 123.774 μs (930 allocations, 49.09 KiB. GC mean 3.71%)
```

After:

```
julia> m3 = Maxout(() -> Dense(5, 7, tanh), 3)
Maxout(
  Dense(5, 7, tanh),                    # 42 parameters
  Dense(5, 7, tanh),                    # 42 parameters
  Dense(5, 7, tanh),                    # 42 parameters
)                   # Total: 6 arrays, 126 parameters, 888 bytes.

julia> x = rand(Float32, 5, 11);

julia> @btime gradient(sum∘m3, $x);
  min 34.541 μs, mean 38.448 μs (493 allocations, 32.48 KiB. GC mean 6.63%)
```

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
2 parents 48596ef + 74eb83b commit fe803a1
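
The `@btime` timings quoted in the commit message rely on BenchmarkTools, whose import the message elides. A minimal sketch to reproduce them, assuming BenchmarkTools is installed and a Flux version of this era (v0.12/0.13), where `Dense(5, 7, tanh)` is the positional constructor; absolute numbers are machine-dependent:

```julia
# Sketch of the benchmark from the commit message above.
# Assumes: BenchmarkTools installed; Flux v0.12/0.13-era API.
using Flux, BenchmarkTools

m3 = Maxout(() -> Dense(5, 7, tanh), 3)
x = rand(Float32, 5, 11);

@btime gradient(sum∘m3, $x);  # timings will differ from those quoted
```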

4 files changed, +55 −27 lines changed

src/deprecations.jl (4 additions, 0 deletions)

```diff
@@ -33,3 +33,7 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
 
 ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
 zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
+
+
+# v0.13 deprecations
+@deprecate Maxout(layers::Tuple) Maxout(layers...)
```
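
As a quick illustration (not part of the diff): the deprecated tuple form still works but warns and forwards to the new splatted constructor. Layer sizes here are arbitrary:

```julia
using Flux

layers = (Dense(2, 3), Dense(2, 3))

m_old = Maxout(layers)      # hits @deprecate: warns, then calls Maxout(layers...)
m_new = Maxout(layers...)   # the replacement spelling

x = rand(Float32, 2, 5)
@assert m_old(x) == m_new(x)  # both wrap the same Dense objects
```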

src/layers/basic.jl (48 additions, 24 deletions)

````diff
@@ -211,48 +211,67 @@ function Base.show(io::IO, l::Diagonal)
 end
 
 """
-    Maxout(over)
+    Maxout(layers...)
+    Maxout(f, n_alts)
 
-The [Maxout](https://arxiv.org/abs/1302.4389) layer has a number of
-internal layers which all receive the same input. It returns the elementwise
-maximum of the internal layers' outputs.
+This contains a number of internal layers, each of which receives the same input.
+Its output is the elementwise maximum of the internal layers' outputs.
 
-Maxout over linear dense layers satisfies the univeral approximation theorem.
-"""
-struct Maxout{FS<:Tuple}
-  over::FS
-end
+Instead of defining layers individually, you can provide a zero-argument function
+which constructs them, and the number to construct.
 
-"""
-    Maxout(f, n_alts)
+Maxout over linear dense layers satisfies the universal approximation theorem.
+See Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks"
+[1302.4389](https://arxiv.org/abs/1302.4389).
 
-Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
-The function takes no arguments and should return some callable layer.
-Conventionally, this is a linear dense layer.
+See also [`Parallel`](@ref) to reduce with other operators.
 
 # Examples
+```
+julia> m = Maxout(x -> abs2.(x), x -> x .* 3);
 
-This constructs a `Maxout` layer over 4 internal dense linear layers, each
-identical in structure (784 inputs, 128 outputs):
-```jldoctest
-julia> insize = 784;
+julia> m([-2 -1 0 1 2])
+1×5 Matrix{Int64}:
+ 4  1  0  3  6
 
-julia> outsize = 128;
+julia> m3 = Maxout(() -> Dense(5, 7, tanh), 3)
+Maxout(
+  Dense(5, 7, tanh),                    # 42 parameters
+  Dense(5, 7, tanh),                    # 42 parameters
+  Dense(5, 7, tanh),                    # 42 parameters
+)                   # Total: 6 arrays, 126 parameters, 888 bytes.
 
-julia> Maxout(()->Dense(insize, outsize), 4);
+julia> Flux.outputsize(m3, (5, 11))
+(7, 11)
 ```
 """
-function Maxout(f, n_alts)
+struct Maxout{FS<:Tuple}
+  over::FS
+  Maxout(layers...) = new{typeof(layers)}(layers)
+end
+
+function Maxout(f::Function, n_alts::Integer)
   over = Tuple(f() for _ in 1:n_alts)
-  return Maxout(over)
+  return Maxout(over...)
 end
 
 @functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
-  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+  # Perhaps surprisingly, pairwise max broadcast is often faster,
+  # even with Zygote. See #698 and #1794
+  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+end
+
+trainable(mo::Maxout) = mo.over
+
+function Base.show(io::IO, mo::Maxout)
+  print(io, "Maxout(")
+  _show_layers(io, mo.over)
+  print(io, ")")
 end
 
+
 """
     SkipConnection(layer, connection)
 
@@ -277,6 +296,8 @@ julia> sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3));
 julia> size(sm(x)) == (5, 5, 11, 10)
 true
 ```
+
+See also [`Parallel`](@ref), [`Maxout`](@ref).
 """
 struct SkipConnection{T,F}
   layers::T
@@ -390,7 +411,7 @@ end
     Parallel(connection, layers...)
     Parallel(connection; name = layer, ...)
 
-Create a 'Parallel' layer that passes an input array to each path in
+Create a `Parallel` layer that passes an input array to each path in
 `layers`, before reducing the output with `connection`.
 
 Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
@@ -399,6 +420,9 @@ If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
 
+See also [`SkipConnection`](@ref) which is `Parallel` with one `identity`,
+and [`Maxout`](@ref) which reduces by broadcasting `max`.
+
 # Examples
 
 ```jldoctest
````
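
A short sketch of the two constructor forms introduced above; the `Dense(2, 3)` sizes are arbitrary. It also checks that the pairwise `max.` fold in the forward pass agrees with a single three-argument broadcast, since `max` is associative:

```julia
using Flux

m1 = Maxout(Dense(2, 3), Dense(2, 3), Dense(2, 3))  # explicit layers
m2 = Maxout(() -> Dense(2, 3), 3)                   # three layers built from a thunk

x = rand(Float32, 2, 5)
@assert m1(x) == max.(m1.over[1](x), m1.over[2](x), m1.over[3](x))
@assert size(m2(x)) == (3, 5)                       # 3 outputs, batch of 5
```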

src/layers/show.jl (1 addition, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 
 for T in [
-  :Chain, :Parallel, :SkipConnection, :Recur  # container types
+  :Chain, :Parallel, :SkipConnection, :Recur, :Maxout  # container types
 ]
   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
     if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
```
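
To see the effect of adding `:Maxout` to that container list (a sketch, not part of the diff): the three-argument `show` used at the REPL's top level now prints a `Maxout`'s children line by line with parameter counts:

```julia
using Flux

m = Maxout(() -> Dense(5, 7, tanh), 3)

# Equivalent to what the REPL does at top level; prints the multi-line
# form shown in the commit message, one child layer per line.
show(stdout, MIME"text/plain"(), m)
```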

test/layers/basic.jl (2 additions, 2 deletions)

```diff
@@ -109,13 +109,13 @@ import Flux: activations
   end
 
   @testset "simple alternatives" begin
-    mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
+    mo = Maxout(x -> x, x -> 2x, x -> 0.5x)
     input = rand(40)
     @test mo(input) == 2*input
   end
 
   @testset "complex alternatives" begin
-    mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
+    mo = Maxout(x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)
     input = [3.0 2.0]
     target = [0.5, 0.7].*input
     @test mo(input) == target
```
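
The first test's expectation rests on a one-line argument: `rand(40)` produces non-negative entries, so the elementwise maximum of `x`, `2x` and `0.5x` is always `2x`. A standalone sketch of the same check:

```julia
using Flux, Test

mo = Maxout(x -> x, x -> 2x, x -> 0.5x)
x = rand(40)           # entries in [0, 1), all non-negative
@test mo(x) == 2*x     # max(x, 2x, x/2) == 2x whenever x ≥ 0
```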
