@@ -5,8 +5,10 @@ export CuIterator
5
5
6
6
Return a `CuIterator` that can iterate through the provided `batches` via `Base.iterate`.
7
7
8
- Upon each iteration, the current `batch` is adapted to the GPU (via `map(x -> adapt(CuArray, x), batch)`)
8
+ Upon each iteration, the current `batch` is copied to the GPU,
9
9
and the previous iteration is marked as freeable from GPU memory (via `unsafe_free!`).
10
+ Both of these use `adapt`, so that each `batch` can be an array, an array of arrays,
11
+ or a more complex object such as a nested set of NamedTuples, which is explored recursively.
10
12
11
13
This abstraction is useful for batching data into GPU memory in a manner that
12
14
allows old iterations to potentially be freed (or marked as reusable) earlier
20
22
21
23
function Base. iterate (c:: CuIterator , state... )
22
24
item = iterate (c. batches, state... )
23
- isdefined (c, :previous ) && foreach (unsafe_free! , c. previous)
25
+ isdefined (c, :previous ) && adapt (CuIteratorFree , c. previous)
24
26
item === nothing && return nothing
25
27
batch, next_state = item
26
- cubatch = map (x -> adapt (CuArray, x) , batch)
28
+ cubatch = adapt (CuIterator , batch)
27
29
c. previous = cubatch
28
30
return cubatch, next_state
29
31
end
@@ -35,3 +37,19 @@ Base.axes(c::CuIterator) = axes(c.batches) # required for HasShape{N}
35
37
Base. IteratorEltype (:: Type{CuIterator{B}} ) where {B} = Base. IteratorEltype (B)
36
38
Base. eltype (c:: CuIterator ) = eltype (c. batches) # required for HasEltype
37
39
40
+ # This struct exists to control adapt for clean-up-afterwards step:
41
+ struct CuIteratorFree end
42
+ Adapt. adapt_storage (:: Type{CuIteratorFree} , x:: CuArray ) = unsafe_free! (x)
43
+
44
+ # We re-purpose struct CuIterator for the matching transfer-before-use step,
45
+ # mostly fall back to adapt(CuArray, x) which recurses into Tuples etc:
46
+ Adapt. adapt_storage (:: Type{<:CuIterator} , x) = adapt (CuArray, x)
47
+
48
+ # But unlike adapt(CuArray, x), returse into arrays of arrays:
49
+ function Adapt. adapt_storage (:: Type{<:CuIterator} , xs:: AbstractArray{T} ) where T
50
+ isbitstype (T) ? adapt (CuArray, xs) : map (adapt (CuArray), xs)
51
+ end
52
+ function Adapt. adapt_storage (:: Type{CuIteratorFree} , xs:: AbstractArray{T} ) where T
53
+ foreach (adapt (CuIteratorFree), xs)
54
+ xs
55
+ end
0 commit comments