
Commit 2809636: tidy up maxout

1 parent e8a67b4 commit 2809636

File tree: 4 files changed, +54 -27 lines

src/deprecations.jl
Lines changed: 4 additions & 0 deletions

@@ -33,3 +33,7 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
 
 ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
 zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
+
+
+# v0.13 deprecations
+@deprecate Maxout(layers::Tuple) Maxout(layers...)
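A hedged sketch of what this deprecation means in practice: passing a pre-built tuple still works but now emits a deprecation warning and forwards to the new splatted form. The layers and sizes below are placeholders for illustration, not taken from the commit.

```julia
using Flux

d1, d2 = Dense(2, 3), Dense(2, 3)   # placeholder internal layers

Maxout((d1, d2))   # old tuple form: warns via @deprecate and forwards to ...
Maxout(d1, d2)     # ... the varargs form introduced in this commit
```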

src/layers/basic.jl
Lines changed: 47 additions & 24 deletions

@@ -211,48 +211,66 @@ function Base.show(io::IO, l::Diagonal)
 end
 
 """
-    Maxout(over)
+    Maxout(layers...)
+    Maxout(f, n_alts)
 
-The [Maxout](https://arxiv.org/abs/1302.4389) layer has a number of
-internal layers which all receive the same input. It returns the elementwise
-maximum of the internal layers' outputs.
+This contains a number of internal layers, each of which receives the same input.
+Its output is the elementwise maximum of the internal layers' outputs.
 
-Maxout over linear dense layers satisfies the univeral approximation theorem.
-"""
-struct Maxout{FS<:Tuple}
-  over::FS
-end
+Instead of defining layers individually, you can provide a zero-argument function
+which constructs them, and the number to construct.
 
-"""
-    Maxout(f, n_alts)
+Maxout over linear dense layers satisfies the universal approximation theorem.
+See Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks"
+[1302.4389](https://arxiv.org/abs/1302.4389).
 
-Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
-The function takes no arguments and should return some callable layer.
-Conventionally, this is a linear dense layer.
+See also [`Parallel`](@ref) to reduce with other operators.
 
 # Examples
+```
+julia> m = Maxout(Dense([1;;], false, abs2), Dense([3;;]));
 
-This constructs a `Maxout` layer over 4 internal dense linear layers, each
-identical in structure (784 inputs, 128 outputs):
-```jldoctest
-julia> insize = 784;
+julia> m([-2 -1 0 1 2])
+1×5 Matrix{Int64}:
+ 4  1  0  3  6
 
-julia> outsize = 128;
+julia> m3 = Maxout(() -> Dense(5, 7, tanh), 3)
+Maxout(
+  Dense(5, 7, tanh),                    # 42 parameters
+  Dense(5, 7, tanh),                    # 42 parameters
+  Dense(5, 7, tanh),                    # 42 parameters
+)                   # Total: 6 arrays, 126 parameters, 888 bytes.
 
-julia> Maxout(()->Dense(insize, outsize), 4);
+julia> Flux.outputsize(m3, (5, 11))
+(7, 11)
 ```
 """
-function Maxout(f, n_alts)
+struct Maxout{FS<:Tuple}
+  over::FS
+  Maxout(layers...) = new{typeof(layers)}(layers)
+end
+
+function Maxout(f::Function, n_alts::Integer)
   over = Tuple(f() for _ in 1:n_alts)
-  return Maxout(over)
+  return Maxout(over...)
 end
 
 @functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
-  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+  outs = map(lay -> lay(input), mo.over)
+  return max.(outs...)
+end
+
+trainable(mo::Maxout) = mo.over
+
+function Base.show(io::IO, mo::Maxout)
+  print(io, "Maxout(")
+  _show_layers(io, mo.over)
+  print(io, ")")
 end
 
+
 """
     SkipConnection(layer, connection)
 
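A hedged usage sketch of the two constructors this hunk introduces. The plain-function alternatives and the layer sizes below are illustrative choices, not taken from the commit.

```julia
using Flux

# Varargs form: any callables work as the internal "layers".
m = Maxout(x -> x, x -> -x)
m([1.0, -2.0, 3.0])                    # [1.0, 2.0, 3.0]: elementwise max(x, -x), i.e. abs.(x)

# Constructor-function form: build n_alts identical layers from a zero-argument closure.
m3 = Maxout(() -> Dense(5, 7, tanh), 3)
size(m3(rand(Float32, 5, 11)))         # (7, 11), the shape of any single internal layer's output
```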
@@ -277,6 +295,8 @@ julia> sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3));
 julia> size(sm(x)) == (5, 5, 11, 10)
 true
 ```
+
+See also [`Parallel`](@ref), [`Maxout`](@ref).
 """
 struct SkipConnection{T,F}
   layers::T
@@ -390,7 +410,7 @@ end
     Parallel(connection, layers...)
     Parallel(connection; name = layer, ...)
 
-Create a 'Parallel' layer that passes an input array to each path in
+Create a `Parallel` layer that passes an input array to each path in
 `layers`, before reducing the output with `connection`.
 
 Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
@@ -399,6 +419,9 @@ If called with multiple inputs, they are `zip`ped with the layers, thus `Paralle
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
 
+See also [`SkipConnection`](@ref) which is `Parallel` with one `identity`,
+and [`Maxout`](@ref) which reduces by broadcasting `max`.
+
 # Examples
 
 ```jldoctest
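The new cross-reference states that `SkipConnection` is `Parallel` with one `identity` path. A small hedged sketch of that equivalence; the layer and input sizes are arbitrary choices for illustration.

```julia
using Flux

layer = Dense(4, 4)
x = rand(Float32, 4, 3)

sc = SkipConnection(layer, +)       # computes +(layer(x), x)
pa = Parallel(+, layer, identity)   # computes +(layer(x), identity(x))

sc(x) == pa(x)                      # true: both reduce the same two outputs with +
```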

src/layers/show.jl
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 
 for T in [
-  :Chain, :Parallel, :SkipConnection, :Recur  # container types
+  :Chain, :Parallel, :SkipConnection, :Recur, :Maxout  # container types
 ]
   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
     if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
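Adding `:Maxout` to this list is what gives `Maxout` the multi-line `text/plain` printing shown in the docstring example above. A hedged, self-contained sketch of the `for`/`@eval` pattern itself, using made-up stand-in types rather than the real Flux layers:

```julia
# Stand-in container types, purely to illustrate the code-generation pattern.
struct Box; inner; end
struct Pair2; a; b; end

for T in [:Box, :Pair2]   # mirrors the list of container types in show.jl
  @eval function Base.show(io::IO, ::MIME"text/plain", x::$T)
    # The real Flux methods recurse into the wrapped layers and print parameter counts;
    # here we only show that one method is generated per listed type.
    print(io, $(string(T)), "(...)")
  end
end

Box(1)   # displays as Box(...) at the REPL
```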

test/layers/basic.jl
Lines changed: 2 additions & 2 deletions

@@ -109,13 +109,13 @@ import Flux: activations
   end
 
   @testset "simple alternatives" begin
-    mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
+    mo = Maxout(x -> x, x -> 2x, x -> 0.5x)
     input = rand(40)
     @test mo(input) == 2*input
   end
 
   @testset "complex alternatives" begin
-    mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
+    mo = Maxout(x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)
     input = [3.0 2.0]
     target = [0.5, 0.7].*input
     @test mo(input) == target
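A brief note on why the "simple alternatives" test can expect exact equality with `2*input`: `rand(40)` is elementwise nonnegative, so the `2x` alternative dominates the other two everywhere. A minimal check of that reasoning:

```julia
x = rand(40)                                   # entries in [0, 1), all nonnegative
@assert max.(x, 2 .* x, 0.5 .* x) == 2 .* x    # the elementwise max is always the 2x branch
```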
