Skip to content

Commit da11bf2

Browse files
authored
Allow cpu(::DataLoader) (#2388)
1 parent 7e7d4fc commit da11bf2

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

src/functor.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,9 @@ function _metal end
403403

404404
"""
405405
gpu(data::DataLoader)
406+
cpu(data::DataLoader)
406407
407-
Transforms a given `DataLoader` to apply `gpu` to each batch of data,
408+
Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data,
408409
when iterated over. (If no GPU is available, this does nothing.)
409410
410411
# Example
@@ -456,6 +457,18 @@ function gpu(d::MLUtils.DataLoader)
456457
)
457458
end
458459

460+
function cpu(d::MLUtils.DataLoader)
461+
MLUtils.DataLoader(MLUtils.mapobs(cpu, d.data),
462+
d.batchsize,
463+
d.buffer,
464+
d.partial,
465+
d.shuffle,
466+
d.parallel,
467+
d.collate,
468+
d.rng,
469+
)
470+
end
471+
459472
# Defining device interfaces.
460473
"""
461474
Flux.AbstractDevice <: Function

test/data.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Flux: DataLoader
12
using Random
23

34
@testset "DataLoader" begin
@@ -14,6 +15,11 @@ using Random
1415
@test batches[2] == X[:,3:4]
1516
@test batches[3] == X[:,5:5]
1617

18+
d_cpu = d |> cpu # does nothing but shouldn't error
19+
@test d_cpu isa DataLoader
20+
@test first(d_cpu) == X[:,1:2]
21+
@test length(d_cpu) == 3
22+
1723
d = DataLoader(X, batchsize=2, partial=false)
1824
# @inferred first(d)
1925
batches = collect(d)

test/ext_cuda/cuda.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,14 @@ end
182182
X = randn(Float64, 3, 33)
183183
pre1 = Flux.DataLoader(X |> gpu; batchsize=13, shuffle=false)
184184
post1 = Flux.DataLoader(X; batchsize=13, shuffle=false) |> gpu
185+
rev1 = pre1 |> cpu # inverse operation
185186
for epoch in 1:2
186-
for (p, q) in zip(pre1, post1)
187+
for (p, q, a) in zip(pre1, post1, rev1)
187188
@test p isa CuArray{Float32}
188189
@test q isa CuArray{Float32}
189190
@test p q
191+
@test a isa Array{Float32}
192+
@test a Array(p)
190193
end
191194
end
192195

0 commit comments

Comments
 (0)