Skip to content

Commit 3b16df2

Browse files
committed
tweak
1 parent 66bfa20 commit 3b16df2

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/data/dataloader.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
2323
The last dimension in each tensor is the observation dimension, i.e. the one
2424
divided into mini-batches.
2525
26-
If `shuffle=true`, shuffles the observations each time iterations are re-started.
27-
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
26+
If `shuffle=true`, it shuffles the observations each time iterations are re-started.
27+
If `partial=false` and the number of observations is not divisible by the batchsize,
28+
then the last mini-batch is dropped.
2829
2930
The original data is preserved in the `data` field of the DataLoader.
3031
@@ -70,10 +71,10 @@ true
7071
julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
7172
false
7273
73-
julia> foreach(println∘size, Flux.DataLoader(rand(10, 64), batchsize=30)) # partial=false would omit last
74-
(10, 30)
75-
(10, 30)
76-
(10, 4)
74+
julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
75+
10×30 Matrix{Int8}
76+
10×30 Matrix{Int8}
77+
10×4 Matrix{Int8}
7778
```
7879
"""
7980
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

0 commit comments

Comments
 (0)