Skip to content

Commit 9c61ff0

Browse files
authored
Cache FFT handles & more aggressive maybe_collect (#646)
- Make `maybe_collect` more aggressive: always trigger it when memory pressure is `> 0.9` and the previous call freed a lot of memory.
- Add a `soft_memory_limit` preference (similar to `hard_memory_limit`). It sets the memory pool's release threshold, i.e. how much memory the pool may hold on to before unused memory is returned to the OS.
- Actually cache rocFFT handles, which significantly improves dispatch time.
1 parent 857644f commit 9c61ff0

File tree

4 files changed

+59
-19
lines changed

4 files changed

+59
-19
lines changed

src/fft/fft.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
abstract type ROCFFTPlan{T, K, inplace} <: Plan{T} end
22

3+
Base.eltype(::ROCFFTPlan{T}) where T = T
4+
5+
is_inplace(::ROCFFTPlan{<:Any, <:Any, inplace}) where inplace = inplace
6+
37
Base.unsafe_convert(::Type{rocfft_plan}, p::ROCFFTPlan) = p.handle
48

5-
function unsafe_free!(plan::ROCFFTPlan)
6-
rocfft_plan_destroy(plan.handle)
9+
function AMDGPU.unsafe_free!(plan::ROCFFTPlan)
10+
if plan.handle != C_NULL
11+
release_plan!(plan)
12+
plan.handle = C_NULL
13+
end
714
unsafe_free!(plan.workarea)
815
rocfft_execution_info_destroy(plan.execution_info)
916
end
@@ -44,7 +51,7 @@ mutable struct cROCFFTPlan{T, K, inplace, N} <: ROCFFTPlan{T, K, inplace}
4451
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
4552
end
4653
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
47-
finalizer(unsafe_free!, p)
54+
finalizer(AMDGPU.unsafe_free!, p)
4855
p
4956
end
5057
end

src/fft/wrappers.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,20 @@ function get_plan(args...)
1313
return handle, workarea
1414
end
1515

16+
function release_plan!(plan)
17+
key = (
18+
AMDGPU.context(), plan.xtype, plan.sz,
19+
eltype(plan), is_inplace(plan), plan.region)
20+
value = (plan.handle, length(plan.workarea))
21+
push!(IDLE_HANDLES, key, value) do
22+
destroy_plan!(plan)
23+
end
24+
end
25+
26+
function destroy_plan!(plan)
27+
rocfft_plan_destroy(plan.handle)
28+
end
29+
1630
function create_plan(xtype::rocfft_transform_type, xdims, T, inplace, region)
1731
precision = (real(T) == Float64) ?
1832
rocfft_precision_double : rocfft_precision_single
@@ -169,9 +183,3 @@ function create_plan(xtype::rocfft_transform_type, xdims, T, inplace, region)
169183
rocfft_plan_get_work_buffer_size(handle_ref[], worksize_ref)
170184
return handle_ref[], Int(worksize_ref[])
171185
end
172-
173-
function release_plan(plan)
174-
push!(IDLE_HANDLES, plan) do
175-
unsafe_free!(plan)
176-
end
177-
end

src/memory.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,34 @@ end
6767
"""
6868
Set a hard limit for total GPU memory allocations.
6969
"""
70-
set_memory_alloc_limit!(limit::String) =
70+
hard_memory_limit!(limit::String) =
7171
@set_preferences!("hard_memory_limit" => limit)
7272

73+
soft_memory_limit!(limit::String) =
74+
@set_preferences!("soft_memory_limit" => limit)
75+
7376
const HARD_MEMORY_LIMIT = Ref{Union{Nothing, UInt64}}(nothing)
7477
function hard_memory_limit()
75-
l = HARD_MEMORY_LIMIT[]
76-
l nothing && return l
78+
hard_limit = HARD_MEMORY_LIMIT[]
79+
hard_limit !== nothing && return hard_limit
7780

78-
HARD_MEMORY_LIMIT[] = parse_memory_limit(
81+
hard_limit = parse_memory_limit(
7982
@load_preference("hard_memory_limit", "none"))
83+
84+
@debug "Setting hard memory limit: $(Base.format_bytes(hard_limit))"
85+
HARD_MEMORY_LIMIT[] = hard_limit
86+
end
87+
88+
const SOFT_MEMORY_LIMIT = Ref{Union{Nothing, UInt64}}(nothing)
89+
function soft_memory_limit()
90+
soft_limit = SOFT_MEMORY_LIMIT[]
91+
soft_limit !== nothing && return soft_limit
92+
93+
soft_limit = parse_memory_limit(
94+
@load_preference("soft_memory_limit", "none"))
95+
96+
@debug "Setting soft memory limit: $(Base.format_bytes(soft_limit))"
97+
SOFT_MEMORY_LIMIT[] = soft_limit
8098
end
8199

82100
mutable struct MemoryStats
@@ -152,19 +170,23 @@ function maybe_collect(; blocking::Bool = false)
152170
# Tolerate 5% GC time.
153171
max_gc_rate = 0.05
154172
# If freed a lot of memory last time, double max GC rate.
155-
(stats.last_freed > 0.1 * stats.size) && (max_gc_rate *= 2;)
173+
freed_alot = stats.last_freed > 0.1 * stats.size
174+
freed_alot && (max_gc_rate *= 2;)
156175
# Be more aggressive if we are going to block.
157176
blocking && (max_gc_rate *= 2;)
158177

159-
# And even more if we are at a limit.
160-
pressure > 0.9 && (max_gc_rate *= 2;)
161-
pressure > 0.95 && (max_gc_rate *= 2;)
178+
# And even more if the pressure is high.
179+
pressure > 0.5 && (max_gc_rate *= 2;)
180+
pressure > 0.7 && (max_gc_rate *= 2;)
181+
182+
# Always free if pressure is 0.9 and we freed a lot.
183+
pressure > 0.9 && (max_gc_rate *= freed_alot ? Inf : 2;)
162184
gc_rate > max_gc_rate && return
163185

164186
# Call the GC.
165187
Base.@atomic stats.last_time = current_time
166188
pre_gc_live = stats.live
167-
gc_time = Base.@elapsed GC.gc(pressure > 0.9 ? true : false)
189+
gc_time = Base.@elapsed GC.gc(pressure > 0.7 ? true : false)
168190
post_gc_live = stats.live
169191

170192
# Update stats.

src/runtime/memory/hip.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ function pool_create(dev::HIPDevice)
99
max_size = max_size != typemax(UInt64) ? max_size : UInt64(0)
1010

1111
pool = HIP.HIPMemoryPool(dev; max_size)
12-
# TODO set soft threshold?
12+
# Allow pool to use up all device memory.
13+
soft_limit = AMDGPU.soft_memory_limit()
14+
HIP.attribute!(pool, HIP.hipMemPoolAttrReleaseThreshold, soft_limit)
15+
1316
HIP.memory_pool!(dev, pool)
1417
return pool
1518
end

0 commit comments

Comments
 (0)