Commit dd5d110

bors[bot] and Michael Abbott authored
Merge #1486
1486: Make `outputsize` understand multiple inputs r=mcabbott a=mcabbott

Closes #1466

Co-authored-by: Michael Abbott <me@escbook>
2 parents f4b01a2 + 666d6b5 commit dd5d110

File tree: 2 files changed, +72 −19 lines

  src/outputsize.jl
  test/outputsize.jl

src/outputsize.jl

Lines changed: 58 additions & 19 deletions
@@ -50,18 +50,19 @@ using .NilNumber: Nil, nil
 """
     outputsize(m, inputsize::Tuple; padbatch=false)

-Calculate the output size of model `m` given the input size.
+Calculate the size of the output from model `m`, given the size of the input.
 Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`.
-Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
+
+Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
 returns the final size including this extra batch dimension.

-This should be faster than calling `size(m(x))`. It uses a trivial number type,
-and thus should work out of the box for custom layers.
+This should be faster than calling `size(m(x))`. It uses a trivial number type,
+which should work out of the box for custom layers.

 If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.

 # Examples
-```jldoctest
+```julia-repl
 julia> using Flux: outputsize

 julia> outputsize(Dense(10, 4), (10,); padbatch=true)
@@ -79,32 +80,70 @@ julia> outputsize(m, (10, 10, 3, 64))
 (6, 6, 32, 64)

 julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
+┌ Error: layer Conv((3, 3), 3=>16), index 1 in Chain, gave an error with input of size (10, 10, 7, 64)
+└ @ Flux ~/.julia/dev/Flux/src/outputsize.jl:114
 DimensionMismatch("Input channels must match! (7 vs. 3)")

-julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
+julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1)) # Vector of layers becomes a Chain
 (2, 1)
+```
+"""
+function outputsize(m, inputsizes::Tuple...; padbatch=false)
+  x = nil_input(padbatch, inputsizes...)
+  return size(m(x))
+end

-julia> using LinearAlgebra: norm
+nil_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? fill(nil, (s...,1)) : fill(nil, s)
+nil_input(pad::Bool, multi::Tuple{Vararg{Integer}}...) = nil_input.(pad, multi)
+nil_input(pad::Bool, tup::Tuple{Vararg{Tuple}}) = nil_input(pad, tup...)
+
+function outputsize(m::Chain, inputsizes::Tuple{Vararg{Integer}}...; padbatch=false)
+  x = nil_input(padbatch, inputsizes...)
+  for (i,lay) in enumerate(m.layers)
+    try
+      x = lay(x)
+    catch err
+      str = x isa AbstractArray ? "with input of size $(size(x))" : ""
+      @error "layer $lay, index $i in Chain, gave an error $str"
+      rethrow(err)
+    end
+  end
+  return size(x)
+end

-julia> f(x) = x ./ norm.(eachcol(x));
+"""
+    outputsize(m, x_size, y_size, ...; padbatch=false)

-julia> outputsize(f, (10, 1)) # manually specify batch size as 1
-(10, 1)
+For model or layer `m` accepting multiple arrays as input,
+this returns `size(m((x, y, ...)))` given `size_x = size(x)`, etc.

-julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
-(10, 1)
+# Examples
+```jldoctest
+julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);
+
+julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));
+
+julia> Flux.outputsize(par, (5, 64), (7, 64))
+(20, 64)
+
+julia> m = Chain(par, Dense(20, 13), softmax);
+
+julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
+(13, 1)
+
+julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
+true
 ```
+Notice that `Chain` only accepts multiple arrays as a tuple,
+while `Parallel` also accepts them as multiple arguments;
+`outputsize` always supplies the tuple.
 """
-function outputsize(m, inputsize::Tuple; padbatch=false)
-  inputsize = padbatch ? (inputsize..., 1) : inputsize
-
-  return size(m(fill(nil, inputsize)))
-end
+outputsize

 ## make tuples and vectors be like Chains

-outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
-outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
+outputsize(m::Tuple, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch)
+outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch)

 ## bypass statistics in normalization layers
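
The three `nil_input` methods added above are what let one `outputsize` signature accept a single size tuple, several size tuples, or one tuple of size tuples. A minimal sketch of the same dispatch pattern, using ordinary `zeros` arrays in place of Flux's `Nil` number type (the name `fake_input` is purely illustrative and not part of Flux):

```julia
# Hypothetical stand-in for nil_input: same dispatch, but filling with real zeros.
fake_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? zeros(Float32, s..., 1) : zeros(Float32, s)
fake_input(pad::Bool, multi::Tuple{Vararg{Integer}}...) = fake_input.(pad, multi)  # several sizes -> tuple of arrays
fake_input(pad::Bool, tup::Tuple{Vararg{Tuple}}) = fake_input(pad, tup...)         # one tuple of sizes -> splat

size(fake_input(false, (5, 64)))                # (5, 64): one size tuple gives one array
map(size, fake_input(false, (5, 64), (7, 64)))  # ((5, 64), (7, 64)): several tuples give a tuple of arrays
map(size, fake_input(true, ((5,), (7,))))       # ((5, 1), (7, 1)): wrapped sizes work too, batch dim padded
```

With real `nil` arrays the same shape bookkeeping happens, but every element is the trivial `nil` number, which is why `outputsize` can be faster than running `size(m(x))` on real data.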

test/outputsize.jl

Lines changed: 14 additions & 0 deletions
@@ -36,6 +36,20 @@
   @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
 end

+@testset "multiple inputs" begin
+  m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))
+  @test outputsize(m, (2,), (3,)) == (10,)
+  @test outputsize(m, ((2,), (3,))) == (10,)
+  @test outputsize(m, (2,), (3,); padbatch=true) == (10, 1)
+  @test outputsize(m, (2,7), (3,7)) == (10, 7)
+
+  m = Chain(m, Dense(10, 13, tanh), softmax)
+  @test outputsize(m, (2,), (3,)) == (13,)
+  @test outputsize(m, ((2,), (3,))) == (13,)
+  @test outputsize(m, (2,), (3,); padbatch=true) == (13, 1)
+  @test outputsize(m, (2,7), (3,7)) == (13, 7)
+end
+
 @testset "activations" begin
   @testset for f in [celu, elu, gelu, hardsigmoid, hardtanh,
                      leakyrelu, lisht, logcosh, logσ, mish,
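
For readers trying this locally, the property these tests exercise is the one the docstring promises: `outputsize` on the size tuples should agree with a real forward pass on arrays of those sizes. A small sketch, assuming a Flux version that includes this commit, with the layer sizes copied from the testset above:

```julia
using Flux, Test

x, y = rand(Float32, 2, 7), rand(Float32, 3, 7)
par = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))
m = Chain(par, Dense(10, 13, tanh), softmax)

# Chain takes the multiple arrays as one tuple; outputsize takes just their sizes.
@test Flux.outputsize(m, size(x), size(y)) == size(m((x, y)))  # both (13, 7)
```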
