Skip to content

Commit d32722b

Browse files
authored
Properly allocate workspace for conv algo search (#643)
1 parent 51921f7 commit d32722b

File tree

3 files changed

+15
-24
lines changed

3 files changed

+15
-24
lines changed

src/array.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,13 @@ function Base.copyto!(
135135
amount == 0 && return dest
136136
@boundscheck checkbounds(dest, d_offset + amount - 1)
137137
@boundscheck checkbounds(source, s_offset + amount - 1)
138+
stm = stream()
138139
Mem.download!(
139140
pointer(dest, d_offset),
140141
Mem.view(convert(Mem.AbstractAMDBuffer, source.buf[]),
141142
(source.offset + s_offset - 1) * sizeof(T)),
142-
amount * sizeof(T); stream=stream(), async)
143+
amount * sizeof(T); stream=stm)
144+
async || synchronize(stm)
143145
dest
144146
end
145147

src/dnn/convolution.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function get_workspace_size(
6363
wsize_ref = Ref{Csize_t}(0)
6464
get_workspace_size_func(conv_type)(
6565
handle, a_desc.handle, b_desc.handle,
66-
conv_desc.handle, c_desc.handle, wsize_ref) |> check
66+
conv_desc.handle, c_desc.handle, wsize_ref) # NOTE: do not |> check...
6767
wsize_ref[]
6868
end
6969

@@ -93,9 +93,12 @@ function find_algorithm(
9393
cache = get_benchmark_cache(conv_type, conv_args)
9494
isnothing(cache) || return cache
9595

96-
workspace = ROCArray{UInt8}(undef, 0)
96+
wsize = get_workspace_size(conv_type; handle, a_desc, b_desc, conv_desc, c_desc)
97+
workspace = ROCArray{UInt8}(undef, wsize)
9798
perf_results = find_conv_algo(conv_type;
9899
handle, workspace, a, a_desc, b, b_desc, conv_desc, c, c_desc)
100+
AMDGPU.unsafe_free!(workspace)
101+
99102
set_benchmark_cache!(conv_type, conv_args, perf_results)
100103
workspace = ROCArray{UInt8}(undef, perf_results.memory)
101104

src/runtime/memory/hip.jl

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,8 @@ function HIPBuffer(bytesize; stream::HIP.HIPStream)
3838
ptr = alloc_or_retry!(isnothing; stream) do
3939
try
4040
# Try to allocate.
41-
# NOTE Async is ~300x slower for small (≤ 16 bytes) allocations:
42-
# https://github.com/ROCm/HIP/issues/3370#issuecomment-1842938966
43-
if bytesize > 16
44-
HIP.hipMallocAsync(ptr_ref, bytesize, stream) |> HIP.check
45-
# HIP.hipMallocFromPoolAsync(ptr_ref, bytesize, pool, stream) |> HIP.check
46-
else
47-
HIP.hipMalloc(ptr_ref, bytesize) |> HIP.check
48-
end
41+
HIP.hipMallocAsync(ptr_ref, bytesize, stream) |> HIP.check
42+
# HIP.hipMallocFromPoolAsync(ptr_ref, bytesize, pool, stream) |> HIP.check
4943

5044
ptr = ptr_ref[]
5145
ptr == C_NULL && throw(HIP.HIPError(HIP.hipErrorOutOfMemory))
@@ -78,11 +72,7 @@ function free(buf::HIPBuffer; stream::HIP.HIPStream)
7872
buf.own || return
7973

8074
buf.ptr == C_NULL && return
81-
if buf.bytesize > 16
82-
HIP.hipFreeAsync(buf, stream) |> HIP.check
83-
else
84-
HIP.hipFree(buf) |> HIP.check
85-
end
75+
HIP.hipFreeAsync(buf, stream) |> HIP.check
8676
AMDGPU.account!(AMDGPU.memory_stats(buf.device), -buf.bytesize)
8777
return
8878
end
@@ -93,13 +83,9 @@ function upload!(dst::HIPBuffer, src::Ptr, bytesize::Int; stream::HIP.HIPStream)
9383
return
9484
end
9585

96-
function download!(dst::Ptr, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream, async::Bool)
86+
function download!(dst::Ptr, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream)
9787
bytesize == 0 && return
98-
if async
99-
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream) |> HIP.check
100-
else
101-
HIP.hipMemcpyDtoH(dst, src, bytesize) |> HIP.check
102-
end
88+
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream) |> HIP.check
10389
return
10490
end
10591

@@ -157,10 +143,10 @@ upload!(dst::HostBuffer, src::Ptr, sz::Int; stream::HIP.HIPStream) =
157143
upload!(dst::HostBuffer, src::HIPBuffer, sz::Int; stream::HIP.HIPStream) =
158144
HIP.memcpy(dst, src, sz, HIP.hipMemcpyDeviceToHost, stream)
159145

160-
download!(dst::Ptr, src::HostBuffer, sz::Int; stream::HIP.HIPStream, async::Bool) =
146+
download!(dst::Ptr, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
161147
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToHost, stream)
162148

163-
download!(dst::HIPBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream, async::Bool) =
149+
download!(dst::HIPBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
164150
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToDevice, stream)
165151

166152
transfer!(dst::HostBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =

0 commit comments

Comments
 (0)