Skip to content

Commit 94b96b4

Browse files
committed
Add CUDA sync inside host_task
Synchronizing the CUDA work launched inside a host_task used to be handled inadvertently by an implementation detail that has since changed.
1 parent 0d3e261 commit 94b96b4

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

examples/cuda_interop/vec_add.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ int main(int argc, char *argv[]) {
7474
gridSize = static_cast<int>(ceil(static_cast<float>(n) / blockSize));
7575
// Call the CUDA kernel directly from SYCL
7676
vecAdd<<<gridSize, blockSize>>>(dA, dB, dC, n);
77+
// Interop with host_task doesn't add a CUDA event to the task graph,
78+
// so we must manually sync here.
79+
cudaDeviceSynchronize();
7780
});
7881
});
7982

examples/sgemm_interop/sycl_sgemm.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ int main() {
8181
auto d_C = b_C.get_access<sycl::access::mode::write>(h);
8282

8383
h.host_task([=](sycl::interop_handle ih) {
84+
auto cuStream = ih.get_native_queue<backend::ext_oneapi_cuda>();
85+
cublasSetStream(handle, cuStream);
8486
cuCtxSetCurrent(ih.get_native_context<backend::cuda>());
8587
cublasSetStream(handle, ih.get_native_queue<backend::cuda>());
8688
auto cuA = reinterpret_cast<float *>(ih.get_native_mem<backend::cuda>(d_A));
@@ -90,6 +92,7 @@ int main() {
9092
CHECK_ERROR(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, WIDTH, HEIGHT,
9193
WIDTH, &ALPHA, cuA, WIDTH, cuB, WIDTH, &BETA,
9294
cuC, WIDTH));
95+
cuStreamSynchronize(cuStream);
9396
});
9497
});
9598
}

examples/sgemm_interop/sycl_sgemm_usm.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,14 @@ int main() {
8686
h.host_task([=](sycl::interop_handle ih) {
8787

8888
// Set the correct cuda context & stream
89-
cuCtxSetCurrent(ih.get_native_context<backend::cuda>());
90-
cublasSetStream(handle, ih.get_native_queue<backend::cuda>());
89+
auto cuStream = ih.get_native_queue<backend::ext_oneapi_cuda>();
90+
cublasSetStream(handle, cuStream);
9191

9292
// Call generalised matrix-matrix multiply
9393
CHECK_ERROR(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, WIDTH, HEIGHT,
9494
WIDTH, &ALPHA, d_A, WIDTH, d_B, WIDTH, &BETA,
9595
d_C, WIDTH));
96+
cuStreamSynchronize(cuStream);
9697
});
9798
}).wait();
9899

0 commit comments

Comments
 (0)