Skip to content

Commit 02ea511

Browse files
bors[bot]Michael Abbott
andauthored
Merge #1469
1469: One-arg unsqueeze method r=CarloLucibello a=mcabbott This makes `unsqeeze(3)` return a function. In fact a `Base.Fix2` so that it has a predictable type, although overloading `show` turns out to be a pain for `<:Function` (perhaps it would be simpler to give it its own `struct`?), right now this is pretty within `Chain` but not when used by itself. The only tests of the existing method seem to be doctests, so this adds a few more. And cross-links to `flatten`. Plus a test for `outputsize`, now. Co-authored-by: Michael Abbott <me@pseudomac>
2 parents 1ac78b5 + 0619f15 commit 02ea511

File tree

3 files changed

+55
-13
lines changed

3 files changed

+55
-13
lines changed

src/layers/stateless.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
11
"""
22
flatten(x::AbstractArray)
33
4-
Reshape arbitrarly-shaped input into a matrix-shaped output
5-
preserving the last dimension size.
6-
Equivalent to `reshape(x, :, size(x)[end])`.
4+
Reshape arbitrarly-shaped input into a matrix-shaped output,
5+
preserving the size of the last dimension.
6+
7+
See also [`unsqueeze`](@ref).
8+
9+
# Examples
10+
```jldoctest
11+
julia> rand(3,4,5) |> Flux.flatten |> size
12+
(12, 5)
13+
14+
julia> xs = rand(Float32, 10,10,3,7);
15+
16+
julia> m = Chain(Conv((3,3), 3=>4, pad=1), Flux.flatten, Dense(400,33));
17+
18+
julia> xs |> m[1] |> size
19+
(10, 10, 4, 7)
20+
21+
julia> xs |> m |> size
22+
(33, 7)
23+
```
724
"""
825
function flatten(x::AbstractArray)
926
return reshape(x, :, size(x)[end])

src/utils.jl

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,23 @@ create_bias(x, ::Any...) = x
241241
"""
242242
unsqueeze(xs, dim)
243243
244-
Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
244+
Return `xs` reshaped into an array one dimensionality higher than `xs`,
245245
where `dim` indicates in which dimension `xs` is extended.
246246
247+
See also [`flatten`](@ref), [`stack`](@ref).
248+
247249
# Examples
248250
```jldoctest
251+
julia> Flux.unsqueeze([1 2; 3 4], 2)
252+
2×1×2 Array{Int64,3}:
253+
[:, :, 1] =
254+
1
255+
3
256+
257+
[:, :, 2] =
258+
2
259+
4
260+
249261
julia> xs = [[1, 2], [3, 4], [5, 6]]
250262
3-element Array{Array{Int64,1},1}:
251263
[1, 2]
@@ -255,19 +267,29 @@ julia> xs = [[1, 2], [3, 4], [5, 6]]
255267
julia> Flux.unsqueeze(xs, 1)
256268
1×3 Array{Array{Int64,1},2}:
257269
[1, 2] [3, 4] [5, 6]
270+
```
271+
"""
272+
unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
258273

259-
julia> Flux.unsqueeze([1 2; 3 4], 2)
260-
2×1×2 Array{Int64,3}:
261-
[:, :, 1] =
262-
1
263-
3
274+
"""
275+
unsqueeze(dim)
264276
265-
[:, :, 2] =
266-
2
267-
4
277+
Returns a function which, acting on an array, inserts a dimension of size 1 at `dim`.
278+
279+
# Examples
280+
```jldoctest
281+
julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
282+
(21, 1, 22, 23)
283+
284+
julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
285+
286+
julia> rand(Float32, 10, 10) |> m |> size
287+
(10, 10, 7, 1)
268288
```
269289
"""
270-
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
290+
unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim)
291+
292+
Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = print(io, "unsqueeze(", u.x, ")")
271293

272294
"""
273295
stack(xs, dim)

test/outputsize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
m = flatten
2323
@test outputsize(m, (5, 5, 3, 10)) == (75, 10)
2424

25+
m = Flux.unsqueeze(3)
26+
@test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13)
27+
2528
m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10))
2629
@test outputsize(m, (10, 10, 3, 50)) == (10, 50)
2730
@test outputsize(m, (10, 10, 3, 2)) == (10, 2)

0 commit comments

Comments
 (0)