Skip to content

Commit b54f8bd

Browse files
mcabbottmaleadt
andauthored
Use adapt symmetrically in CuIterator (#1769)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent f78cd73 commit b54f8bd

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/iterator.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ export CuIterator
55
66
Return a `CuIterator` that can iterate through the provided `batches` via `Base.iterate`.
77
8-
Upon each iteration, the current `batch` is adapted to the GPU (via `map(x -> adapt(CuArray, x), batch)`)
8+
Upon each iteration, the current `batch` is copied to the GPU,
99
and the previous iteration is marked as freeable from GPU memory (via `unsafe_free!`).
10+
Both of these use `adapt`, so that each `batch` can be an array, an array of arrays,
11+
or a more complex object such as a nested set of NamedTuples, which is explored recursively.
1012
1113
This abstraction is useful for batching data into GPU memory in a manner that
1214
allows old iterations to potentially be freed (or marked as reusable) earlier
@@ -20,10 +22,10 @@ end
2022

2123
function Base.iterate(c::CuIterator, state...)
2224
item = iterate(c.batches, state...)
23-
isdefined(c, :previous) && foreach(unsafe_free!, c.previous)
25+
isdefined(c, :previous) && adapt(CuIteratorFree, c.previous)
2426
item === nothing && return nothing
2527
batch, next_state = item
26-
cubatch = map(x -> adapt(CuArray, x), batch)
28+
cubatch = adapt(CuIterator, batch)
2729
c.previous = cubatch
2830
return cubatch, next_state
2931
end
@@ -35,3 +37,19 @@ Base.axes(c::CuIterator) = axes(c.batches) # required for HasShape{N}
3537
Base.IteratorEltype(::Type{CuIterator{B}}) where {B} = Base.IteratorEltype(B)
3638
Base.eltype(c::CuIterator) = eltype(c.batches) # required for HasEltype
3739

40+
# This struct exists to control adapt for clean-up-afterwards step:
41+
struct CuIteratorFree end
42+
Adapt.adapt_storage(::Type{CuIteratorFree}, x::CuArray) = unsafe_free!(x)
43+
44+
# We re-purpose struct CuIterator for the matching transfer-before-use step,
45+
# mostly fall back to adapt(CuArray, x) which recurses into Tuples etc:
46+
Adapt.adapt_storage(::Type{<:CuIterator}, x) = adapt(CuArray, x)
47+
48+
# But unlike adapt(CuArray, x), returse into arrays of arrays:
49+
function Adapt.adapt_storage(::Type{<:CuIterator}, xs::AbstractArray{T}) where T
50+
isbitstype(T) ? adapt(CuArray, xs) : map(adapt(CuArray), xs)
51+
end
52+
function Adapt.adapt_storage(::Type{CuIteratorFree}, xs::AbstractArray{T}) where T
53+
foreach(adapt(CuIteratorFree), xs)
54+
xs
55+
end

test/iterator.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,18 @@ end
2525
@test eltype(cubatches) == eltype(batch for batch in batches) == Any
2626
@test Base.IteratorEltype(typeof(CuIterator(batches))) isa Base.HasEltype
2727
@test eltype(CuIterator(batches)) == eltype(batches) # Vector
28+
29+
it_nt = CuIterator((x=Float32[i,i/2], y=i) for i in 1:4)
30+
@test first(it_nt).x isa CuArray{Float32}
31+
batch1, state = iterate(it_nt)
32+
@test batch1.x == cu([1,1/2])
33+
batch2, _ = iterate(it_nt, state)
34+
@test batch2.x == cu([2,2/2])
35+
@test batch1.x.storage === nothing # unsafe_free! has worked inside
36+
37+
it_vec = CuIterator([[i,i/2], [i/3, i/4]] for i in 1:4)
38+
@test first(it_vec)[1] isa CuArray{Float64}
39+
40+
using StaticArrays: SVector, SA
41+
it_static = CuIterator([SA[i,i/2], SA[i/3, i/4]] for i in 1:4)
42+
@test first(it_static) isa CuArray{<:SVector}

0 commit comments

Comments
 (0)