Skip to content

Commit 9d8997c

Browse files
authored
More alloc cache improvements (#583)
1 parent 9457e03 commit 9d8997c

File tree

4 files changed

+130
-62
lines changed

4 files changed

+130
-62
lines changed

lib/JLArrays/src/JLArrays.jl

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -89,15 +89,16 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
8989
check_eltype(T)
9090
maxsize = prod(dims) * sizeof(T)
9191

92-
return GPUArrays.cached_alloc((JLArray, T, dims)) do
92+
ref = GPUArrays.cached_alloc((JLArray, maxsize)) do
9393
data = Vector{UInt8}(undef, maxsize)
94-
ref = DataRef(data) do data
94+
DataRef(data) do data
9595
resize!(data, 0)
9696
end
97-
obj = new{T, N}(ref, 0, dims)
98-
finalizer(unsafe_free!, obj)
99-
return obj
100-
end::JLArray{T, N}
97+
end
98+
99+
obj = new{T, N}(ref, 0, dims)
100+
finalizer(unsafe_free!, obj)
101+
return obj
101102
end
102103

103104
# low-level constructor for wrapping existing data

src/host/abstractarray.jl

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,17 +53,20 @@ end
5353

5454
# per-object state, with a flag to indicate whether the object has been freed.
5555
# this is to support multiple calls to `unsafe_free!` on the same object,
56-
# while only lowering the referene count of the underlying data once.
56+
# while only lowering the reference count of the underlying data once.
5757
mutable struct DataRef{D}
5858
rc::RefCounted{D}
5959
freed::Bool
60+
cached::Bool
6061
end
6162

62-
function DataRef(finalizer, data::D) where {D}
63-
rc = RefCounted{D}(data, finalizer, Threads.Atomic{Int}(1))
64-
DataRef{D}(rc, false)
63+
function DataRef(finalizer, ref::D) where {D}
64+
rc = RefCounted{D}(ref, finalizer, Threads.Atomic{Int}(1))
65+
DataRef{D}(rc, false, false)
6566
end
66-
DataRef(data; kwargs...) = DataRef(nothing, data; kwargs...)
67+
DataRef(ref; kwargs...) = DataRef(nothing, ref; kwargs...)
68+
69+
Base.sizeof(ref::DataRef) = sizeof(ref.rc[])
6770

6871
function Base.getindex(ref::DataRef)
6972
if ref.freed
@@ -77,18 +80,24 @@ function Base.copy(ref::DataRef{D}) where {D}
7780
throw(ArgumentError("Attempt to copy a freed reference."))
7881
end
7982
retain(ref.rc)
80-
return DataRef{D}(ref.rc, false)
83+
# copies of cached references are not managed by the cache, so
84+
# we need to mark them as such to make sure their refcount can drop.
85+
return DataRef{D}(ref.rc, false, false)
8186
end
8287

83-
function unsafe_free!(ref::DataRef, args...)
88+
function unsafe_free!(ref::DataRef)
89+
if ref.cached
90+
# lifetimes of cached references are tied to the cache.
91+
return
92+
end
8493
if ref.freed
8594
# multiple frees *of the same object* are allowed.
8695
# we should only ever call `release` once per object, though,
8796
# as multiple releases of the underlying data is not allowed.
8897
return
8998
end
9099
ref.freed = true
91-
release(ref.rc, args...)
100+
release(ref.rc)
92101
return
93102
end
94103

src/host/alloc_cache.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ end
88

99
mutable struct AllocCache
1010
lock::ReentrantLock
11-
busy::Dict{UInt64, Vector{Any}} # hash(key) => GPUArray[]
12-
free::Dict{UInt64, Vector{Any}}
11+
busy::Dict{UInt64, Vector{DataRef}}
12+
free::Dict{UInt64, Vector{DataRef}}
1313

1414
function AllocCache()
1515
cache = new(
@@ -24,43 +24,48 @@ end
2424
function get_pool!(cache::AllocCache, pool::Symbol, uid::UInt64)
2525
pool = getproperty(cache, pool)
2626
uid_pool = get(pool, uid, nothing)
27-
if uid_pool ≡ nothing
28-
uid_pool = Base.@lock cache.lock pool[uid] = Any[]
27+
if uid_pool === nothing
28+
uid_pool = pool[uid] = DataRef[]
2929
end
3030
return uid_pool
3131
end
3232

3333
function cached_alloc(f, key)
3434
cache = ALLOC_CACHE[]
3535
if cache === nothing
36-
return f()::AbstractGPUArray
36+
return f()::DataRef
3737
end
3838

39-
x = nothing
39+
ref = nothing
4040
uid = hash(key)
4141

42-
busy_pool = get_pool!(cache, :busy, uid)
43-
free_pool = get_pool!(cache, :free, uid)
44-
isempty(free_pool) && (x = f()::AbstractGPUArray)
42+
Base.@lock cache.lock begin
43+
free_pool = get_pool!(cache, :free, uid)
44+
45+
if !isempty(free_pool)
46+
ref = Base.@lock cache.lock pop!(free_pool)
47+
end
48+
end
4549

46-
while !isempty(free_pool) && x ≡ nothing
47-
tmp = Base.@lock cache.lock pop!(free_pool)
48-
# Array was manually freed via `unsafe_free!`.
49-
GPUArrays.storage(tmp).freed && continue
50-
x = tmp
50+
if ref === nothing
51+
ref = f()::DataRef
52+
ref.cached = true
5153
end
5254

53-
x ≡ nothing && (x = f()::AbstractGPUArray)
54-
Base.@lock cache.lock push!(busy_pool, x)
55-
return x
55+
Base.@lock cache.lock begin
56+
busy_pool = get_pool!(cache, :busy, uid)
57+
push!(busy_pool, ref)
58+
end
59+
60+
return ref
5661
end
5762

5863
function free_busy!(cache::AllocCache)
59-
for uid in cache.busy.keys
60-
busy_pool = get_pool!(cache, :busy, uid)
61-
isempty(busy_pool) && continue
64+
Base.@lock cache.lock begin
65+
for uid in keys(cache.busy)
66+
busy_pool = get_pool!(cache, :busy, uid)
67+
isempty(busy_pool) && continue
6268

63-
Base.@lock cache.lock begin
6469
free_pool = get_pool!(cache, :free, uid)
6570
append!(free_pool, busy_pool)
6671
empty!(busy_pool)
@@ -71,14 +76,13 @@ end
7176

7277
function unsafe_free!(cache::AllocCache)
7378
Base.@lock cache.lock begin
74-
for (_, pool) in cache.busy
75-
isempty(pool) || error(
76-
"Invalidating allocations cache that's currently in use. " *
77-
"Invalidating inside `@cached` is not allowed."
78-
)
79+
for pool in values(cache.busy)
80+
isempty(pool) || error("Cannot invalidate a cache that's in active use")
7981
end
80-
for (_, pool) in cache.free
81-
map(unsafe_free!, pool)
82+
for pool in values(cache.free), ref in pool
83+
# release the reference
84+
ref.cached = false
85+
unsafe_free!(ref)
8286
end
8387
empty!(cache.free)
8488
end
@@ -143,13 +147,11 @@ GPUArrays.unsafe_free!(cache)
143147
See [`@uncached`](@ref).
144148
"""
145149
macro cached(cache, expr)
150+
try_expr = :(@with $(esc(ALLOC_CACHE)) => cache $(esc(expr)))
151+
fin_expr = :(free_busy!($(esc(cache))))
146152
return quote
147-
cache = $(esc(cache))
148-
GC.@preserve cache begin
149-
res = @with $(esc(ALLOC_CACHE)) => cache $(esc(expr))
150-
free_busy!(cache)
151-
res
152-
end
153+
local cache = $(esc(cache))
154+
GC.@preserve cache $(Expr(:tryfinally, try_expr, fin_expr))
153155
end
154156
end
155157

test/testsuite/alloc_cache.jl

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,98 @@
22
if AT <: AbstractGPUArray
33
cache = GPUArrays.AllocCache()
44

5+
# first allocation populates the cache
56
T, dims = Float32, (1, 2, 3)
67
GPUArrays.@cached cache begin
7-
x1 = AT(zeros(T, dims))
8+
cached1 = AT(zeros(T, dims))
89
end
9-
@test sizeof(cache) == sizeof(T) * prod(dims)
10+
@test sizeof(cache) == sizeof(cached1)
1011
key = first(keys(cache.free))
1112
@test length(cache.free[key]) == 1
1213
@test length(cache.busy[key]) == 0
13-
@test x1 === cache.free[key][1]
14+
@test cache.free[key][1] === GPUArrays.storage(cached1)
1415

15-
# Second allocation hits cache.
16+
# second allocation hits the cache
1617
GPUArrays.@cached cache begin
17-
x2 = AT(zeros(T, dims))
18-
# Does not hit the cache.
19-
GPUArrays.@uncached x_free = AT(zeros(T, dims))
18+
cached2 = AT(zeros(T, dims))
19+
20+
# explicitly uncached ones don't
21+
GPUArrays.@uncached uncached = AT(zeros(T, dims))
22+
end
23+
@test sizeof(cache) == sizeof(cached2)
24+
key = first(keys(cache.free))
25+
@test length(cache.free[key]) == 1
26+
@test length(cache.busy[key]) == 0
27+
@test cache.free[key][1] === GPUArrays.storage(cached2)
28+
@test uncached !== cached2
29+
30+
# compatible shapes should also hit the cache
31+
dims = (3, 2, 1)
32+
GPUArrays.@cached cache begin
33+
cached3 = AT(zeros(T, dims))
2034
end
21-
@test sizeof(cache) == sizeof(T) * prod(dims)
35+
@test sizeof(cache) == sizeof(cached3)
2236
key = first(keys(cache.free))
2337
@test length(cache.free[key]) == 1
2438
@test length(cache.busy[key]) == 0
25-
@test x2 === cache.free[key][1]
26-
@test x_free !== x2
39+
@test cache.free[key][1] === GPUArrays.storage(cached3)
2740

28-
# Third allocation is of different shape - allocates.
41+
# as should compatible eltypes
42+
T = Int32
43+
GPUArrays.@cached cache begin
44+
cached4 = AT(zeros(T, dims))
45+
end
46+
@test sizeof(cache) == sizeof(cached4)
47+
key = first(keys(cache.free))
48+
@test length(cache.free[key]) == 1
49+
@test length(cache.busy[key]) == 0
50+
@test cache.free[key][1] === GPUArrays.storage(cached4)
51+
52+
# different shapes should trigger a new allocation
2953
dims = (2, 2)
3054
GPUArrays.@cached cache begin
31-
x3 = AT(zeros(T, dims))
55+
cached5 = AT(zeros(T, dims))
56+
57+
# we're allowed to free arrays early, which should be a no-op for cached data
58+
GPUArrays.unsafe_free!(cached5)
3259
end
60+
@test sizeof(cache) == sizeof(cached4) + sizeof(cached5)
3361
_keys = collect(keys(cache.free))
3462
key2 = _keys[findfirst(i -> i != key, _keys)]
3563
@test length(cache.free[key]) == 1
3664
@test length(cache.free[key2]) == 1
37-
@test x3 === cache.free[key2][1]
65+
@test cache.free[key2][1] === GPUArrays.storage(cached5)
66+
67+
# we should be able to re-use the early-freed array
68+
GPUArrays.@cached cache begin
69+
cached5 = AT(zeros(T, dims))
70+
end
71+
72+
# exceptions shouldn't cause issues
73+
@test_throws "Allowed exception" GPUArrays.@cached cache begin
74+
AT(zeros(T, dims))
75+
error("Allowed exception")
76+
end
77+
# NOTE: this should remain the last test before calling `unsafe_free!` below,
78+
# as it caught an erroneous assertion in the original code.
3879

39-
# Freeing all memory held by cache.
80+
# freeing all memory held by cache should free all allocations
81+
@test !GPUArrays.storage(cached1).freed
82+
@test GPUArrays.storage(cached1).cached
83+
@test !GPUArrays.storage(cached5).freed
84+
@test GPUArrays.storage(cached5).cached
85+
@test !GPUArrays.storage(uncached).freed
86+
@test !GPUArrays.storage(uncached).cached
4087
GPUArrays.unsafe_free!(cache)
4188
@test sizeof(cache) == 0
89+
@test GPUArrays.storage(cached1).freed
90+
@test !GPUArrays.storage(cached1).cached
91+
@test GPUArrays.storage(cached5).freed
92+
@test !GPUArrays.storage(cached5).cached
93+
@test !GPUArrays.storage(uncached).freed
94+
## test that the underlying data was freed as well
95+
@test GPUArrays.storage(cached1).rc.count[] == 0
96+
@test GPUArrays.storage(cached5).rc.count[] == 0
97+
@test GPUArrays.storage(uncached).rc.count[] == 1
4298
end
4399
end

0 commit comments

Comments
 (0)