Skip to content

Commit 78dd3f6

Browse files
author
cossio
committed
make unsqueeze type stable
1 parent 91d42a9 commit 78dd3f6

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/utils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,10 @@ julia> Flux.unsqueeze(xs, 1)
420420
[1, 2] [3, 4] [5, 6]
421421
```
422422
"""
423-
unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
423+
function unsqueeze(xs::AbstractArray, dim::Integer)
424+
sz = ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1)
425+
return reshape(xs, sz)
426+
end
424427

425428
"""
426429
unsqueeze(dim)

test/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ using Test
99

1010
@testset "unsqueeze" begin
1111
x = randn(2, 3, 2)
12-
@test unsqueeze(x, 1) == reshape(x, 1, 2, 3, 2)
13-
@test unsqueeze(x, 2) == reshape(x, 2, 1, 3, 2)
14-
@test unsqueeze(x, 3) == reshape(x, 2, 3, 1, 2)
15-
@test unsqueeze(x, 4) == reshape(x, 2, 3, 2, 1)
12+
@test @inferred(unsqueeze(x, 1)) == reshape(x, 1, 2, 3, 2)
13+
@test @inferred(unsqueeze(x, 2)) == reshape(x, 2, 1, 3, 2)
14+
@test @inferred(unsqueeze(x, 3)) == reshape(x, 2, 3, 1, 2)
15+
@test @inferred(unsqueeze(x, 4)) == reshape(x, 2, 3, 2, 1)
1616
end
1717

1818
@testset "Throttle" begin

0 commit comments

Comments
 (0)