Skip to content

Commit c85f497

Browse files
committed
make DataLoader's docstring a doctest
1 parent e7686b2 commit c85f497

File tree

1 file changed

+46
-37
lines changed

1 file changed

+46
-37
lines changed

src/data/dataloader.jl

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,59 +13,68 @@ struct DataLoader{D,R<:AbstractRNG}
1313
end
1414

1515
"""
16-
DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
16+
Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
1717
18-
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
18+
An object that iterates over mini-batches of `data`,
19+
each mini-batch containing `batchsize` observations
1920
(except possibly the last one).
2021
2122
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
22-
The last dimension in each tensor is considered to be the observation dimension.
23+
The last dimension in each tensor is the observation dimension, i.e. the one
24+
divided into mini-batches.
2325
2426
If `shuffle=true`, shuffles the observations each time iterations are re-started.
2527
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
2628
2729
The original data is preserved in the `data` field of the DataLoader.
2830
29-
Usage example:
31+
# Examples
32+
```jldoctest
33+
julia> Xtrain = rand(10, 100);
3034
31-
Xtrain = rand(10, 100)
32-
train_loader = DataLoader(Xtrain, batchsize=2)
33-
# iterate over 50 mini-batches of size 2
34-
for x in train_loader
35-
@assert size(x) == (10, 2)
36-
...
37-
end
35+
julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
3836
39-
train_loader.data # original dataset
37+
julia> for x in array_loader
38+
@assert size(x) == (10, 2)
39+
# do something with x, 50 times
40+
end
4041
41-
# similar, but yielding tuples
42-
train_loader = DataLoader((Xtrain,), batchsize=2)
43-
for (x,) in train_loader
44-
@assert size(x) == (10, 2)
45-
...
46-
end
42+
julia> array_loader.data === Xtrain
43+
true
4744
48-
Xtrain = rand(10, 100)
49-
Ytrain = rand(100)
50-
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
51-
for epoch in 1:100
52-
for (x, y) in train_loader
53-
@assert size(x) == (10, 2)
54-
@assert size(y) == (2,)
55-
...
56-
end
57-
end
45+
julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
5846
59-
# train for 10 epochs
60-
using IterTools: ncycle
61-
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
47+
julia> for x in tuple_loader
48+
@assert x isa Tuple{Matrix}
49+
@assert size(x[1]) == (10, 2)
50+
end
6251
63-
# can use NamedTuple to name tensors
64-
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
65-
for datum in train_loader
66-
@assert size(datum.images) == (10, 2)
67-
@assert size(datum.labels) == (2,)
68-
end
52+
julia> Ytrain = rand('a':'z', 100); # now make a DataLoader returning 2-element named tuples
53+
54+
julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
55+
56+
julia> for epoch in 1:100
57+
for (x, y) in train_loader # access via tuple destructuring
58+
@assert size(x) == (10, 5)
59+
@assert size(y) == (5,)
60+
# loss += f(x, y) # etc, runs 100 * 20 times
61+
end
62+
end
63+
64+
julia> first(train_loader) isa NamedTuple{(:data, :label)}
65+
true
66+
67+
julia> first(train_loader).label isa Vector{Char} # access via property name
68+
true
69+
70+
julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
71+
false
72+
73+
julia> foreach(println∘size, Flux.DataLoader(rand(10, 64), batchsize=30)) # partial=false would omit the last batch
74+
(10, 30)
75+
(10, 30)
76+
(10, 4)
77+
```
6978
"""
7079
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
7180
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))

0 commit comments

Comments
 (0)