Skip to content

Commit 6829998

Browse files
authored
Make it possible to customize the CuIterator adaptor. (#2297)
1 parent 890fb04 commit 6829998

File tree

2 files changed

+48
-28
lines changed

2 files changed

+48
-28
lines changed

src/iterator.jl

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,69 @@
11
export CuIterator
22

33
"""
4-
CuIterator(batches)
4+
CuIterator([to], batches)
55
6-
Return a `CuIterator` that can iterate through the provided `batches` via `Base.iterate`.
6+
Create a `CuIterator` that iterates through the provided `batches` via `iterate`. Upon each
7+
iteration, the current `batch` is copied to the GPU, and the previous iteration is marked as
8+
freeable from GPU memory (via `unsafe_free!`).
79
8-
Upon each iteration, the current `batch` is copied to the GPU,
9-
and the previous iteration is marked as freeable from GPU memory (via `unsafe_free!`).
10-
Both of these use `adapt`, so that each `batch` can be an array, an array of arrays,
11-
or a more complex object such as a nested set of NamedTuples, which is explored recursively.
10+
The conversion to GPU memory is done recursively, using Adapt.jl, so that each batch can be
11+
an array, an array of arrays, or more complex iterable objects. To customize the conversion,
12+
an adaptor can be specified as the first argument, e.g., to change the element type:
1213
13-
This abstraction is useful for batching data into GPU memory in a manner that
14-
allows old iterations to potentially be freed (or marked as reusable) earlier
15-
than they otherwise would via CuArray's internal polling mechanism.
14+
```julia
15+
julia> first(CuIterator([[1.]]))
16+
1-element CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}:
17+
1.0
18+
19+
julia> first(CuIterator(CuArray{Float32}, [[1.]]))
20+
1-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
21+
1.0
22+
```
23+
24+
This abstraction is useful for batching data into GPU memory in a manner that allows old
25+
iterations to potentially be freed (or marked as reusable) earlier than they otherwise would
26+
via `CuArray`'s internal polling mechanism.
1627
"""
17-
mutable struct CuIterator{B}
28+
mutable struct CuIterator{T,B}
29+
to::T
1830
batches::B
1931
previous::Any
20-
CuIterator(batches) = new{typeof(batches)}(batches)
32+
33+
CuIterator(batches) = CuIterator(nothing, batches)
34+
CuIterator(to, batches) = new{typeof(to),typeof(batches)}(to, batches)
2135
end
2236

2337
function Base.iterate(c::CuIterator, state...)
2438
item = iterate(c.batches, state...)
25-
isdefined(c, :previous) && adapt(CuIteratorFree, c.previous)
39+
isdefined(c, :previous) && adapt(CuIteratorFree(), c.previous)
2640
item === nothing && return nothing
2741
batch, next_state = item
28-
cubatch = adapt(CuIterator, batch)
42+
cubatch = adapt(c, batch)
2943
c.previous = cubatch
3044
return cubatch, next_state
3145
end
3246

33-
Base.IteratorSize(::Type{CuIterator{B}}) where {B} = Base.IteratorSize(B)
47+
Base.IteratorSize(::Type{CuIterator{T,B}}) where {T,B} = Base.IteratorSize(B)
3448
Base.length(c::CuIterator) = length(c.batches) # required for HasLength
3549
Base.axes(c::CuIterator) = axes(c.batches) # required for HasShape{N}
3650

37-
Base.IteratorEltype(::Type{CuIterator{B}}) where {B} = Base.IteratorEltype(B)
51+
Base.IteratorEltype(::Type{CuIterator{T,B}}) where {T,B} = Base.IteratorEltype(B)
3852
Base.eltype(c::CuIterator) = eltype(c.batches) # required for HasEltype
3953

40-
# This struct exists to control adapt for clean-up-afterwards step:
41-
struct CuIteratorFree end
42-
Adapt.adapt_storage(::Type{CuIteratorFree}, x::CuArray) = unsafe_free!(x)
43-
44-
# We re-purpose struct CuIterator for the matching transfer-before-use step,
45-
# mostly fall back to adapt(CuArray, x) which recurses into Tuples etc:
46-
Adapt.adapt_storage(::Type{<:CuIterator}, x) = adapt(CuArray, x)
47-
48-
# But unlike adapt(CuArray, x), returse into arrays of arrays:
49-
function Adapt.adapt_storage(::Type{<:CuIterator}, xs::AbstractArray{T}) where T
50-
isbitstype(T) ? adapt(CuArray, xs) : map(adapt(CuArray), xs)
54+
# adaptor for uploading
55+
Adapt.adapt_storage(c::CuIterator, x) = adapt(something(c.to, CuArray), x)
56+
## unlike adapt(CuArray, x), recurse into arrays of arrays
57+
function Adapt.adapt_storage(c::CuIterator, xs::AbstractArray{T}) where T
58+
to = something(c.to, CuArray)
59+
isbitstype(T) ? adapt(to, xs) : map(to, xs)
5160
end
52-
function Adapt.adapt_storage(::Type{CuIteratorFree}, xs::AbstractArray{T}) where T
53-
foreach(adapt(CuIteratorFree), xs)
61+
62+
# adaptor for clean-up
63+
struct CuIteratorFree end
64+
Adapt.adapt_storage(::CuIteratorFree, x::CuArray) = unsafe_free!(x)
65+
function Adapt.adapt_storage(::CuIteratorFree, xs::AbstractArray{T}) where T
66+
foreach(adapt(CuIteratorFree()), xs)
5467
xs
5568
end
69+

test/base/iterator.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ batch2, _ = iterate(it_nt, state)
3737
it_vec = CuIterator([[i,i/2], [i/3, i/4]] for i in 1:4)
3838
@test first(it_vec)[1] isa CuArray{Float64}
3939

40+
# test element type conversion using a custom adaptor
41+
it_float64 = CuIterator([[1.0]])
42+
@test first(it_float64) isa CuArray{Float64}
43+
it_float32 = CuIterator(CuArray{Float32}, [[1.0]])
44+
@test first(it_float32) isa CuArray{Float32}
45+
4046
using StaticArrays: SVector, SA
4147
it_static = CuIterator([SA[i,i/2], SA[i/3, i/4]] for i in 1:4)
4248
@test first(it_static) isa CuArray{<:SVector}

0 commit comments

Comments
 (0)