@@ -56,6 +56,7 @@ function Event(::CUDA)
 end
 
 wait(ev::CudaEvent, progress=nothing) = wait(CPU(), ev, progress)
+
 function wait(::CPU, ev::CudaEvent, progress=nothing)
     if progress === nothing
         CUDAdrv.synchronize(ev.event)
@@ -68,30 +69,24 @@ function wait(::CPU, ev::CudaEvent, progress=nothing)
 end
 
 # Use this to synchronize between computation using the CuDefaultStream
-function wait(::CUDA, ev::CudaEvent, progress=nothing)
-    CUDAdrv.wait(ev.event, CUDAdrv.CuDefaultStream())
-end
+wait(::CUDA, ev::CudaEvent, progress=nothing, stream=CUDAdrv.CuDefaultStream()) = CUDAdrv.wait(ev.event, stream)
+wait(::CUDA, ev::NoneEvent, progress=nothing, stream=nothing) = nothing
 
 # There is no efficient wait for CPU->GPU synchronization, so instead we
 # do a CPU wait, and therefore block anyone from submitting more work.
 # We maybe could do a spinning wait on the GPU and atomic flag to signal from the CPU,
 # but which stream would we target?
-wait(::CUDA, ev::CPUEvent, progress=nothing) = wait(CPU(), ev, progress)
-
-function __waitall(::CUDA, dependencies, progress, stream)
-    if dependencies isa Event
-        dependencies = (dependencies,)
+wait(::CUDA, ev::CPUEvent, progress=nothing, stream=nothing) = wait(CPU(), ev, progress)
+
+function wait(::CUDA, ev::MultiEvent, progress=nothing, stream=CUDAdrv.CuDefaultStream())
+    dependencies = collect(ev.events)
+    cudadeps = filter(d->d isa CudaEvent, dependencies)
+    otherdeps = filter(d->!(d isa CudaEvent), dependencies)
+    for event in cudadeps
+        CUDAdrv.wait(event.event, stream)
     end
-    if dependencies !== nothing
-        dependencies = collect(dependencies)
-        cudadeps = filter(d->d isa CudaEvent, dependencies)
-        otherdeps = filter(d->!(d isa CudaEvent), dependencies)
-        for event in cudadeps
-            CUDAdrv.wait(event.event, stream)
-        end
-        for event in otherdeps
-            wait(CUDA(), event, progress)
-        end
+    for event in otherdeps
+        wait(CUDA(), event, progress)
     end
 end
 
@@ -119,7 +114,7 @@ function async_copy!(::CUDA, A, B; dependencies=nothing)
     B isa Array && __pin!(B)
 
     stream = next_stream()
-    __waitall(CUDA(), dependencies, yield, stream)
+    wait(CUDA(), MultiEvent(dependencies), yield, stream)
     event = CuEvent(CUDAdrv.EVENT_DISABLE_TIMING)
     GC.@preserve A B begin
         destptr = pointer(A)
@@ -145,12 +140,9 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
     if workgroupsize isa Integer
         workgroupsize = (workgroupsize, )
     end
-    if dependencies isa Event
-        dependencies = (dependencies,)
-    end
 
     stream = next_stream()
-    __waitall(CUDA(), dependencies, yield, stream)
+    wait(CUDA(), MultiEvent(dependencies), yield, stream)
 
     if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
         # TODO: allow for NDRange{1, DynamicSize, DynamicSize}(nothing, nothing)
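For orientation, here is a rough usage sketch of the reworked wait API, not part of the patch itself: MultiEvent and NoneEvent are defined elsewhere in this PR, some_kernel, A, and B are placeholders, and both async_copy! and the kernel launch are assumed to return a CudaEvent.

# Sketch only, assuming MultiEvent simply wraps a collection of events
ev1 = async_copy!(CUDA(), A, B)                  # assumed to return a CudaEvent
ev2 = some_kernel(CUDA(), 256)(A; ndrange=length(A), dependencies=ev1)
wait(CUDA(), MultiEvent((ev1, ev2)))             # enqueue waits on CuDefaultStream()
wait(ev2)                                        # or block the host via the CPU path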