Skip to content

Commit 61293d5

Browse files
unbatch vector
1 parent 29079cd commit 61293d5

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ julia> Flux.unbatch([1 3 5 7;
583583
[7, 8]
584584
"""
585585
unbatch(x::AbstractArray) = unstack(x, ndims(x))
586+
unbatch(x::AbstractVector) = x
586587

587588
"""
588589
Return the given sequence padded with `p` up to a maximum length of `n`.

test/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,14 @@ end
340340
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
341341
@test unbatch(stacked_array) == unstacked_array
342342
@test batch(unstacked_array) == stacked_array
343+
344+
# no-op for vector of non-arrays
345+
@test batch([1,2,3]) == [1,2,3]
346+
@test unbatch([1,2,3]) == [1,2,3]
347+
348+
# generic iterable
349+
@test batch(ones(2) for i=1:3) == ones(2, 3)
350+
@test unbatch(ones(2, 3)) == [ones(2) for i=1:3]
343351
end
344352

345353
@testset "Param remapping" begin

0 commit comments

Comments
 (0)