
Commit ea26f45

bors[bot] and cossio authored
Merge #1759

1759: Make unsqueeze type stable r=CarloLucibello a=cossio

This PR makes Flux.unsqueeze type stable and improves its performance. Closes #1737. Please see the linked issue for a comparison. I also added some tests.

Co-authored-by: cossio <j.cossio.diaz@gmail.com>
2 parents: 69afb67 + 78dd3f6

File tree: 2 files changed, +31 -19 lines

  src/utils.jl   (+8 -5)
  test/utils.jl  (+23 -14)

src/utils.jl

Lines changed: 8 additions & 5 deletions

@@ -285,7 +285,7 @@ sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_i
 """
     identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)
 
-Return an `Array` of size `dims` which yields an identity mapping when used as parameters in
+Return an `Array` of size `dims` which yields an identity mapping when used as parameters in
 most Flux layers. Use `gain` to scale the identity by a constant.
 
 Often useful in the context of transfer learning, i.e when one wants to add more capacity to
@@ -297,10 +297,10 @@ Equivalent to `Base.circshift(identity(dims...), shift)`.
 Some caveats: Not all layers will be identity mapping when used with this init. Exceptions
 include recurrent layers, `DepthwiseConv` and normalization layers.
 
-Also note that layers must have `input_size == output_size` for identity mapping to be
+Also note that layers must have `input_size == output_size` for identity mapping to be
 possible. When this is not the case, extra dimensions of the array are padded with zeros.
 
-For convolutional layers, in addition to the above, the kernel sizes must also be odd and
+For convolutional layers, in addition to the above, the kernel sizes must also be odd and
 padding must be applied so that output feature maps have the same size as input feature maps,
 e.g by using [`SamePad`](@ref).
 
@@ -420,7 +420,10 @@ julia> Flux.unsqueeze(xs, 1)
  [1, 2]  [3, 4]  [5, 6]
 ```
 """
-unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+function unsqueeze(xs::AbstractArray, dim::Integer)
+    sz = ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1)
+    return reshape(xs, sz)
+end
 
 """
     unsqueeze(dim)
@@ -574,7 +577,7 @@ See also [`unstack`](@ref).
 # Examples
 
 ```jldoctest
-julia> Flux.unbatch([1 3 5 7;
+julia> Flux.unbatch([1 3 5 7;
                      2 4 6 8])
 4-element Vector{Vector{Int64}}:
  [1, 2]
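The type instability in the old one-liner comes from the shape tuple: `size(xs)[1:dim-1]` and `size(xs)[dim:end]` have lengths that depend on the runtime value of `dim`, so the compiler cannot infer how many dimensions the `reshape` result has. The merged version builds the shape with an `ntuple` of length `ndims(xs) + 1`, which is known from the type of `xs` alone. A minimal side-by-side sketch (the `unsqueeze_old`/`unsqueeze_new` names are illustrative, not part of the codebase):

```julia
using Test: @inferred

# Old definition: the splatted ranges have value-dependent lengths,
# so the dimensionality of the reshaped array is opaque to inference.
unsqueeze_old(xs::AbstractArray, dim::Integer) =
    reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

# New definition: the ntuple length ndims(xs) + 1 is a constant for a
# given array type, so the result infers as Array{T, N+1}.
function unsqueeze_new(xs::AbstractArray, dim::Integer)
    sz = ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1)
    return reshape(xs, sz)
end

x = randn(2, 3)
@inferred unsqueeze_new(x, 2)    # passes: inferred as Array{Float64, 3}
# @inferred unsqueeze_old(x, 2)  # would error: return type not inferred
```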

test/utils.jl

Lines changed: 23 additions & 14 deletions

@@ -1,11 +1,20 @@
 using Flux
 using Flux: throttle, nfan, glorot_uniform, glorot_normal,
-            kaiming_normal, kaiming_uniform, orthogonal,
-            sparse_init, stack, unstack, Zeros, batch, unbatch
+            kaiming_normal, kaiming_uniform, orthogonal,
+            sparse_init, stack, unstack, Zeros, batch, unbatch,
+            unsqueeze
 using StatsBase: var, std
 using Random
 using Test
 
+@testset "unsqueeze" begin
+  x = randn(2, 3, 2)
+  @test @inferred(unsqueeze(x, 1)) == reshape(x, 1, 2, 3, 2)
+  @test @inferred(unsqueeze(x, 2)) == reshape(x, 2, 1, 3, 2)
+  @test @inferred(unsqueeze(x, 3)) == reshape(x, 2, 3, 1, 2)
+  @test @inferred(unsqueeze(x, 4)) == reshape(x, 2, 3, 2, 1)
+end
+
 @testset "Throttle" begin
   @testset "default behaviour" begin
     a = []
@@ -178,10 +187,10 @@ end
 
 @testset "$layer ID mapping with kernelsize $kernelsize" for layer in (Conv, ConvTranspose, CrossCor), kernelsize in (
   (1,),
-  (3,),
-  (1, 3),
-  (3, 5),
-  (3, 5, 7))
+  (3,),
+  (1, 3),
+  (3, 5),
+  (3, 5, 7))
   nch = 3
   l = layer(kernelsize, nch=>nch, init=identity_init, pad=SamePad())
 
@@ -333,9 +342,9 @@ end
 
 
 @testset "Batching" begin
-  stacked_array=[ 8 9 3 5
-                  9 6 6 9
-                  9 1 7 2
+  stacked_array=[ 8 9 3 5
+                  9 6 6 9
+                  9 1 7 2
                   7 4 10 6 ]
  unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
  @test unbatch(stacked_array) == unstacked_array
@@ -445,7 +454,7 @@ end
 
  modules = Flux.modules(Chain(SkipConnection(
                    Conv((2,3), 4=>5; pad=6, stride=7),
-                   +),
+                   +),
                  LayerNorm(8)))
  @test length(modules) == 5
end
@@ -475,16 +484,16 @@ end
@testset "early stopping" begin
  @testset "args & kwargs" begin
    es = Flux.early_stopping((x; y = 1) -> x + y, 10; min_dist=3)
-
+
    n_iter = 0
    while n_iter < 99
      es(-n_iter; y=-n_iter) && break
      n_iter += 1
    end
-
+
    @test n_iter == 9
  end
-
+
  @testset "distance" begin
    es = Flux.early_stopping(identity, 10; distance=(best_score, score) -> score - best_score)
 
@@ -496,7 +505,7 @@ end
 
    @test n_iter == 99
  end
-
+
  @testset "init_score" begin
    es = Flux.early_stopping(identity, 10; init_score=10)
 