@@ -241,11 +241,23 @@ create_bias(x, ::Any...) = x
241
241
"""
242
242
unsqueeze(xs, dim)
243
243
244
- Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
244
+ Return `xs` reshaped into an array one dimensionality higher than `xs`,
245
245
where `dim` indicates in which dimension `xs` is extended.
246
246
247
+ See also [`flatten`](@ref), [`stack`](@ref).
248
+
247
249
# Examples
248
250
```jldoctest
251
+ julia> Flux.unsqueeze([1 2; 3 4], 2)
252
+ 2×1×2 Array{Int64,3}:
253
+ [:, :, 1] =
254
+ 1
255
+ 3
256
+
257
+ [:, :, 2] =
258
+ 2
259
+ 4
260
+
249
261
julia> xs = [[1, 2], [3, 4], [5, 6]]
250
262
3-element Array{Array{Int64,1},1}:
251
263
[1, 2]
@@ -255,19 +267,31 @@ julia> xs = [[1, 2], [3, 4], [5, 6]]
255
267
julia> Flux.unsqueeze(xs, 1)
256
268
1×3 Array{Array{Int64,1},2}:
257
269
[1, 2] [3, 4] [5, 6]
270
+ ```
271
+ """
272
+ unsqueeze (xs:: AbstractArray , dim:: Integer ) = reshape (xs, (size (xs)[1 : dim- 1 ]. .. , 1 , size (xs)[dim: end ]. .. ))
258
273
259
- julia> Flux.unsqueeze([1 2; 3 4], 2)
260
- 2×1×2 Array{Int64,3}:
261
- [:, :, 1] =
262
- 1
263
- 3
274
+ """
275
+ unsqueeze(dim)
264
276
265
- [:, :, 2] =
266
- 2
267
- 4
277
+ Returns a function which, acting on an array, inserts a dimension of size 1 at `dim`.
278
+
279
+ # Examples
280
+ ```jldoctest
281
+ julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
282
+ (21, 1, 22, 23)
283
+
284
+ julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
285
+
286
+ julia> rand(Float32, 10, 10) |> m |> size
287
+ (10, 10, 7, 1)
268
288
```
269
289
"""
270
- unsqueeze (xs, dim) = reshape (xs, (size (xs)[1 : dim- 1 ]. .. , 1 , size (xs)[dim: end ]. .. ))
290
+ unsqueeze (dim:: Integer ) = Base. Fix2 (unsqueeze, dim)
291
+
292
+ Base. show (io:: IO , u:: Base.Fix2{typeof(unsqueeze)} ) = print (io, " unsqueeze(" , u. x, " )" )
293
+ Base. show (io:: IO , :: MIME"text/plain" , u:: Base.Fix2{typeof(unsqueeze)} ) = show (io, u) # at top level
294
+ Base. show_function (io:: IO , u:: Base.Fix2{typeof(unsqueeze)} , :: Bool ) = show (io, u) # within Chain etc.
271
295
272
296
"""
273
297
stack(xs, dim)
0 commit comments