Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 731b915

Browse files
committed
Monitor memory coherency.
1 parent 18be296 commit 731b915

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

src/indexing.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,36 @@ import GPUArrays: allowscalar, @allowscalar
33

44
## unified memory indexing
55

6-
# TODO: needs to think about coherency -- otherwise this might crash since it doesn't sync
7-
# also, this optim would be relevant for CuArray<->Array memcpy as well.
6+
# flag consulted before the host touches unified memory directly: `set_coherency` clears
# it, and `force_coherency` sets it again once it has synchronized.
const coherent = Ref(true)
7+
8+
# toggle coherency based on API calls
9+
"""
    set_coherency(apicall)

Callback intended for `CUDAdrv.apicall_hook[]`: clear the `coherent` flag so that the next
`force_coherency()` synchronizes before the host accesses unified memory.

`apicall` is currently ignored (presumably it identifies the API call being made -- to be
confirmed against CUDAdrv); every call is treated as potentially breaking coherency.
"""
function set_coherency(apicall)
    # TODO: whitelist
    setindex!(coherent, false)
    return nothing
end
14+
15+
"""
    force_coherency()

Synchronize, if needed, before the host reads or writes a unified buffer through a raw
pointer.

The `coherent` flag is only trusted while `set_coherency` is installed as
`CUDAdrv.apicall_hook[]`. If no hook is installed, ours is installed so that later calls
can skip the synchronization; if somebody else's hook is installed, it is left alone and
every call synchronizes. Returns `nothing`.
"""
function force_coherency()
    # TODO: not on newer hardware with certain flags

    hook = CUDAdrv.apicall_hook[]
    if hook === nothing
        # nobody else is hooking for CUDA API calls, so we can safely install ours.
        # nothing was tracked up to this point, so the flag cannot be trusted yet.
        CUDAdrv.apicall_hook[] = set_coherency
        coherent[] = false
    elseif hook !== set_coherency
        # we didn't have our API call hook in place, all bets are off
        coherent[] = false
    end

    if !coherent[]
        CUDAdrv.synchronize()
        # set the flag only after synchronizing: with our hook in place the
        # synchronization itself may be reported to `set_coherency` and clear it
        coherent[] = true
    end

    return
end
831

932
function GPUArrays._getindex(xs::CuArray{T}, i::Integer) where T
1033
buf = buffer(xs)
1134
if isa(buf, Mem.UnifiedBuffer)
35+
force_coherency()
1236
ptr = convert(Ptr{T}, buffer(xs))
1337
unsafe_load(ptr, i)
1438
else
@@ -21,6 +45,7 @@ end
2145
function GPUArrays._setindex!(xs::CuArray{T}, v::T, i::Integer) where T
2246
buf = buffer(xs)
2347
if isa(buf, Mem.UnifiedBuffer)
48+
force_coherency()
2449
ptr = convert(Ptr{T}, buffer(xs))
2550
unsafe_store!(ptr, v, i)
2651
else

0 commit comments

Comments
 (0)