@@ -13,59 +13,68 @@ struct DataLoader{D,R<:AbstractRNG}
13
13
end
14
14
15
15
"""
16
- DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
16
+ Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
17
17
18
- An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
18
+ An object that iterates over mini-batches of `data`,
19
+ each mini-batch containing `batchsize` observations
19
20
(except possibly the last one).
20
21
21
22
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
22
- The last dimension in each tensor is considered to be the observation dimension.
23
+ The last dimension in each tensor is the observation dimension, i.e. the one
24
+ divided into mini-batches.
23
25
24
26
If `shuffle=true`, shuffles the observations each time iterations are re-started.
25
27
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
26
28
27
29
The original data is preserved in the `data` field of the DataLoader.
28
30
29
- Usage example:
31
+ # Examples
32
+ ```jldoctest
33
+ julia> Xtrain = rand(10, 100);
30
34
31
- Xtrain = rand(10, 100)
32
- train_loader = DataLoader(Xtrain, batchsize=2)
33
- # iterate over 50 mini-batches of size 2
34
- for x in train_loader
35
- @assert size(x) == (10, 2)
36
- ...
37
- end
35
+ julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
38
36
39
- train_loader.data # original dataset
37
+ julia> for x in array_loader
38
+ @assert size(x) == (10, 2)
39
+ # do something with x, 50 times
40
+ end
40
41
41
- # similar, but yielding tuples
42
- train_loader = DataLoader((Xtrain,), batchsize=2)
43
- for (x,) in train_loader
44
- @assert size(x) == (10, 2)
45
- ...
46
- end
42
+ julia> array_loader.data === Xtrain
43
+ true
47
44
48
- Xtrain = rand(10, 100)
49
- Ytrain = rand(100)
50
- train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
51
- for epoch in 1:100
52
- for (x, y) in train_loader
53
- @assert size(x) == (10, 2)
54
- @assert size(y) == (2,)
55
- ...
56
- end
57
- end
45
+ julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
58
46
59
- # train for 10 epochs
60
- using IterTools: ncycle
61
- Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
47
+ julia> for x in tuple_loader
48
+ @assert x isa Tuple{Matrix}
49
+ @assert size(x[1]) == (10, 2)
50
+ end
62
51
63
- # can use NamedTuple to name tensors
64
- train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
65
- for datum in train_loader
66
- @assert size(datum.images) == (10, 2)
67
- @assert size(datum.labels) == (2,)
68
- end
52
+ julia> Ytrain = rand('a':'z', 100); # now make a DataLoader returning 2-element named tuples
53
+
54
+ julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
55
+
56
+ julia> for epoch in 1:100
57
+ for (x, y) in train_loader # access via tuple destructuring
58
+ @assert size(x) == (10, 5)
59
+ @assert size(y) == (5,)
60
+ # loss += f(x, y) # etc, runs 100 * 20 times
61
+ end
62
+ end
63
+
64
+ julia> first(train_loader) isa NamedTuple{(:data, :label)}
65
+ true
66
+
67
+ julia> first(train_loader).label isa Vector{Char} # access via property name
68
+ true
69
+
70
+ julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
71
+ false
72
+
73
+ julia> foreach(println∘size, Flux.DataLoader(rand(10, 64), batchsize=30)) # partial=false would omit the last one
74
+ (10, 30)
75
+ (10, 30)
76
+ (10, 4)
77
+ ```
69
78
"""
70
79
function DataLoader (data; batchsize= 1 , shuffle= false , partial= true , rng= GLOBAL_RNG)
71
80
batchsize > 0 || throw (ArgumentError (" Need positive batchsize" ))
0 commit comments