@@ -50,18 +50,19 @@ using .NilNumber: Nil, nil
50
50
"""
51
51
outputsize(m, inputsize::Tuple; padbatch=false)
52
52
53
- Calculate the output size of model `m` given the input size.
53
+ Calculate the size of the output from model `m`, given the size of the input.
54
54
Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`.
55
- Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
55
+
56
+ Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
56
57
returns the final size including this extra batch dimension.
57
58
58
- This should be faster than calling `size(m(x))`. It uses a trivial number type,
59
- and thus should work out of the box for custom layers.
59
+ This should be faster than calling `size(m(x))`. It uses a trivial number type,
60
+ which should work out of the box for custom layers.
60
61
61
62
If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.
62
63
63
64
# Examples
64
- ```jldoctest
65
+ ```julia-repl
65
66
julia> using Flux: outputsize
66
67
67
68
julia> outputsize(Dense(10, 4), (10,); padbatch=true)
@@ -79,32 +80,70 @@ julia> outputsize(m, (10, 10, 3, 64))
79
80
(6, 6, 32, 64)
80
81
81
82
julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
83
+ ┌ Error: layer Conv((3, 3), 3=>16), index 1 in Chain, gave an error with input of size (10, 10, 7, 64)
84
+ └ @ Flux ~/.julia/dev/Flux/src/outputsize.jl:114
82
85
DimensionMismatch("Input channels must match! (7 vs. 3)")
83
86
84
- julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
87
+ julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1)) # Vector of layers becomes a Chain
85
88
(2, 1)
89
+ ```
90
+ """
91
function outputsize(m, inputsizes::Tuple...; padbatch=false)
  # Build `nil`-filled placeholder array(s) of the given size(s), run the model
  # once on them, and read off the result's size. The `Nil` number type makes
  # this cheap: no real arithmetic is performed.
  x = nil_input(padbatch, inputsizes...)
  return size(m(x))
end
86
95
87
- julia> using LinearAlgebra: norm
96
# Construct placeholder input(s) filled with `nil`.
# With `pad=true`, a trailing batch dimension of 1 is appended to the size.
nil_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? fill(nil, (s..., 1)) : fill(nil, s)
# Several sizes produce a tuple of placeholder arrays, for multi-input layers.
nil_input(pad::Bool, multi::Tuple{Vararg{Integer}}...) = nil_input.(pad, multi)
# A tuple of size-tuples is splatted through to the vararg method above.
nil_input(pad::Bool, tup::Tuple{Vararg{Tuple}}) = nil_input(pad, tup...)
99
+
100
function outputsize(m::Chain, inputsizes::Tuple{Vararg{Integer}}...; padbatch=false)
  x = nil_input(padbatch, inputsizes...)
  # Walk the chain layer by layer so that a failure can be attributed to the
  # specific layer (and the size it was given) before the error propagates.
  for (i, lay) in enumerate(m.layers)
    try
      x = lay(x)
    catch err
      str = x isa AbstractArray ? "with input of size $(size(x))" : ""
      @error "layer $lay, index $i in Chain, gave an error $str"
      rethrow(err)
    end
  end
  return size(x)
end
88
113
89
- julia> f(x) = x ./ norm.(eachcol(x));
114
"""
    outputsize(m, x_size, y_size, ...; padbatch=false)

For model or layer `m` accepting multiple arrays as input,
this returns `size(m((x, y, ...)))` given `size_x = size(x)`, etc.

# Examples
```jldoctest
julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);

julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));

julia> Flux.outputsize(par, (5, 64), (7, 64))
(20, 64)

julia> m = Chain(par, Dense(20, 13), softmax);

julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
(13, 1)

julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
true
```

Notice that `Chain` only accepts multiple arrays as a tuple,
while `Parallel` also accepts them as multiple arguments;
`outputsize` always supplies the tuple.
"""
outputsize
103
142
104
143
## make tuples and vectors be like Chains
105
144
106
# A `Tuple` or `Vector` of layers is treated as `Chain(m...)`: elements are
# applied in sequence. Varargs `input` sizes are forwarded unchanged.
outputsize(m::Tuple, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch)
outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch)
108
147
109
148
# # bypass statistics in normalization layers
110
149
0 commit comments