Skip to content

Commit 9c3ed38

Browse files
committed
add wait on device and generic Event constructor
```
ev = Event(CUDA())
ev = kernel(..., dependencies=(ev,))
wait(CUDA(), ev)
```
1 parent 0e8ce58 commit 9c3ed38

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

src/KernelAbstractions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module KernelAbstractions
22

33
export @kernel
44
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize
5-
export Device, GPU, CPU, CUDA
5+
export Device, GPU, CPU, CUDA, Event
66

77
using MacroTools
88
using StaticArrays

src/backends/cpu.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
struct CPUEvent <: Event
2-
task::Core.Task
2+
task::Union{Nothing, Core.Task}
33
end
44

5-
function wait(ev::CPUEvent, progress=nothing)
5+
function Event(::CPU)
6+
return CPUEvent(nothing)
7+
end
8+
9+
wait(ev::CPUEvent, progress=nothing) = wait(CPU(), ev, progress)
10+
function wait(::CPU, ev::CPUEvent, progress=nothing)
11+
ev.task === nothing && return
12+
613
if progress === nothing
714
wait(ev.task)
815
else
@@ -50,7 +57,7 @@ function __run(obj, ndrange, iterspace, args, dependencies)
5057
!isempty(cpu_tasks) && Base.sync_end(cpu_tasks)
5158
for event in dependencies
5259
if !(event isa CPUEvent)
53-
wait(event, ()->yield())
60+
wait(CPU(), event, ()->yield())
5461
end
5562
end
5663
end

src/backends/cuda.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,16 @@ end
4747
struct CudaEvent <: Event
4848
event::CuEvent
4949
end
50-
function wait(ev::CudaEvent, progress=nothing)
50+
51+
function Event(::CUDA)
52+
stream = CUDAdrv.CuDefaultStream()
53+
event = CuEvent(CUDAdrv.EVENT_DISABLE_TIMING)
54+
CUDAdrv.record(event, stream)
55+
CudaEvent(event)
56+
end
57+
58+
wait(ev::CudaEvent, progress=nothing) = wait(CPU(), ev, progress)
59+
function wait(::CPU, ev::CudaEvent, progress=nothing)
5160
if progress === nothing
5261
CUDAdrv.synchronize(ev.event)
5362
else
@@ -58,6 +67,19 @@ function wait(ev::CudaEvent, progress=nothing)
5867
end
5968
end
6069

70+
# Use this to synchronize between computation using the CuDefaultStream
71+
wait(::CUDA, ev::CudaEvent, progress=nothing) = __enqueue_wait(ev, CUDAdrv.CuDefaultStream())
72+
73+
# There is no efficient wait for CPU->GPU synchronization, so instead we
74+
# do a CPU wait, and therefore block anyone from submitting more work.
75+
# We could perhaps do a spinning wait on the GPU, with an atomic flag signalled from the CPU,
76+
# but which stream would we target?
77+
wait(::CUDA, ev::CPUEvent, progress=nothing) = wait(CPU(), ev, progress)
78+
79+
function __enqueue_wait(ev::CudaEvent, stream::CuStream)
80+
CUDAdrv.wait(ev.event, stream)
81+
end
82+
6183
function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, workgroupsize=nothing)
6284
if ndrange isa Integer
6385
ndrange = (ndrange,)
@@ -73,12 +95,12 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
7395
if dependencies !== nothing
7496
for event in dependencies
7597
if event isa CudaEvent
76-
CUDAdrv.wait(event.event, stream)
98+
__enqueue_wait(event, stream)
7799
end
78100
end
79101
for event in dependencies
80102
if !(event isa CudaEvent)
81-
wait(event, ()->yield())
103+
wait(CUDA(), event, ()->yield())
82104
end
83105
end
84106
end

0 commit comments

Comments
 (0)