Skip to content

Commit 2c8c3c7

Browse files
Merge #1726
1726: add unbatch r=CarloLucibello a=CarloLucibello Add `unbatch` for convenience and consistency (we currently have `stack/unstack` but only `batch` with no `unbatch`). Can be specialized by downstream packages on their custom batched types. ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [x] Documentation, if applicable - [ ] API changes require approval from a committer (different from the author, if applicable) Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
2 parents 934b9db + 61293d5 commit 2c8c3c7

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

docs/src/utilities.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Flux.unstack
1414
Flux.chunk
1515
Flux.frequencies
1616
Flux.batch
17+
Flux.unbatch
1718
Flux.batchseq
1819
Base.rpad(v::AbstractVector, n::Integer, p)
1920
```

src/utils.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ squeezebatch(x) = reshape(x, head(size(x)))
542542
543543
Batch the arrays in `xs` into a single array.
544544
545+
See also [`unbatch`](@ref)
546+
545547
# Examples
546548
```jldoctest
547549
julia> Flux.batch([[1,2,3],[4,5,6]])
@@ -561,6 +563,28 @@ function batch(xs)
561563
return data
562564
end
563565

566+
"""
567+
unbatch(x)
568+
569+
Reverse of the [`batch`](@ref) operation,
570+
unstacking the last dimension of the array `x`.
571+
572+
See also [`unstack`](@ref).
573+
574+
# Examples
575+
576+
```jldoctest
577+
julia> Flux.unbatch([1 3 5 7;
578+
2 4 6 8])
579+
4-element Vector{Vector{Int64}}:
580+
[1, 2]
581+
[3, 4]
582+
[5, 6]
583+
[7, 8]
584+
"""
585+
unbatch(x::AbstractArray) = unstack(x, ndims(x))
586+
unbatch(x::AbstractVector) = x
587+
564588
"""
565589
Return the given sequence padded with `p` up to a maximum length of `n`.
566590

test/utils.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Flux
2-
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, sparse_init, stack, unstack, Zeros
2+
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
3+
kaiming_normal, kaiming_uniform, orthogonal,
4+
sparse_init, stack, unstack, Zeros, batch, unbatch
35
using StatsBase: var, std
46
using Random
57
using Test
@@ -329,6 +331,25 @@ end
329331
@test stack(unstack(stacked_array, 1), 1) == stacked_array
330332
end
331333

334+
335+
@testset "Batching" begin
336+
stacked_array=[ 8 9 3 5
337+
9 6 6 9
338+
9 1 7 2
339+
7 4 10 6 ]
340+
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
341+
@test unbatch(stacked_array) == unstacked_array
342+
@test batch(unstacked_array) == stacked_array
343+
344+
# no-op for vector of non-arrays
345+
@test batch([1,2,3]) == [1,2,3]
346+
@test unbatch([1,2,3]) == [1,2,3]
347+
348+
# generic iterable
349+
@test batch(ones(2) for i=1:3) == ones(2, 3)
350+
@test unbatch(ones(2, 3)) == [ones(2) for i=1:3]
351+
end
352+
332353
@testset "Param remapping" begin
333354
ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
334355
dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))

0 commit comments

Comments
 (0)