Skip to content

Commit 37c838d

Browse files
author
Michael Abbott
committed
one-arg Fix2 unsqueeze method
1 parent 1ac78b5 commit 37c838d

File tree

2 files changed

+54
-13
lines changed

2 files changed

+54
-13
lines changed

src/layers/stateless.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
11
"""
22
flatten(x::AbstractArray)
33
4-
Reshape arbitrarly-shaped input into a matrix-shaped output
5-
preserving the last dimension size.
6-
Equivalent to `reshape(x, :, size(x)[end])`.
4+
Reshape arbitrarly-shaped input into a matrix-shaped output,
5+
preserving the size of the last dimension.
6+
7+
See also [`unsqueeze`](@ref).
8+
9+
# Examples
10+
```jldoctest
11+
julia> rand(3,4,5) |> Flux.flatten |> size
12+
(12, 5)
13+
14+
julia> xs = rand(Float32, 10,10,3,7);
15+
16+
julia> m = Chain(Conv((3,3), 3=>4, pad=1), Flux.flatten, Dense(400,33));
17+
18+
julia> xs |> m[1] |> size
19+
(10, 10, 4, 7)
20+
21+
julia> xs |> m |> size
22+
(33, 7)
23+
```
724
"""
825
function flatten(x::AbstractArray)
926
return reshape(x, :, size(x)[end])

src/utils.jl

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,23 @@ create_bias(x, ::Any...) = x
241241
"""
242242
unsqueeze(xs, dim)
243243
244-
Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
244+
Return `xs` reshaped into an array one dimensionality higher than `xs`,
245245
where `dim` indicates in which dimension `xs` is extended.
246246
247+
See also [`flatten`](@ref), [`stack`](@ref).
248+
247249
# Examples
248250
```jldoctest
251+
julia> Flux.unsqueeze([1 2; 3 4], 2)
252+
2×1×2 Array{Int64,3}:
253+
[:, :, 1] =
254+
1
255+
3
256+
257+
[:, :, 2] =
258+
2
259+
4
260+
249261
julia> xs = [[1, 2], [3, 4], [5, 6]]
250262
3-element Array{Array{Int64,1},1}:
251263
[1, 2]
@@ -255,19 +267,31 @@ julia> xs = [[1, 2], [3, 4], [5, 6]]
255267
julia> Flux.unsqueeze(xs, 1)
256268
1×3 Array{Array{Int64,1},2}:
257269
[1, 2] [3, 4] [5, 6]
270+
```
271+
"""
272+
unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
258273

259-
julia> Flux.unsqueeze([1 2; 3 4], 2)
260-
2×1×2 Array{Int64,3}:
261-
[:, :, 1] =
262-
1
263-
3
274+
"""
275+
unsqueeze(dim)
264276
265-
[:, :, 2] =
266-
2
267-
4
277+
Returns a function which, acting on an array, inserts a dimension of size 1 at `dim`.
278+
279+
# Examples
280+
```jldoctest
281+
julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
282+
(21, 1, 22, 23)
283+
284+
julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
285+
286+
julia> rand(Float32, 10, 10) |> m |> size
287+
(10, 10, 7, 1)
268288
```
269289
"""
270-
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
290+
unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim)
291+
292+
Base.show(io::IO, u::Base.Fix2{typeof(unsqueeze)}) = print(io, "unsqueeze(", u.x, ")")
293+
Base.show(io::IO, ::MIME"text/plain", u::Base.Fix2{typeof(unsqueeze)}) = show(io, u) # at top level
294+
Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = show(io, u) # within Chain etc.
271295

272296
"""
273297
stack(xs, dim)

0 commit comments

Comments
 (0)