Skip to content

Commit b0aed59

Browse files
Adds CUDA support
1 parent 5116372 commit b0aed59

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -531,14 +531,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
531531
uint32_t numSyncPointsInWaitList,
532532
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
533533
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
534-
(void)hCommandBuffer;
535-
(void)numSyncPointsInWaitList;
536-
(void)pSyncPointWaitList;
537-
(void)pSyncPoint;
538-
539-
detail::ur::die("Experimental Command-buffer feature is not "
540-
"implemented for CUDA adapter.");
541-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
534+
// The prefetch command is not supported by CUDA Graphs.
535+
// We implement it as an empty node to enforce dependencies.
536+
ur_result_t Result = UR_RESULT_SUCCESS;
537+
CUgraphNode GraphNode;
538+
539+
std::vector<CUgraphNode> DepsList;
540+
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
541+
pSyncPointWaitList, DepsList));
542+
543+
try {
544+
// Add an empty node to preserve dependencies.
545+
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
546+
DepsList.data(), DepsList.size()));
547+
548+
// Create a sync point and register the CUDA graph node with it.
549+
*pSyncPoint =
550+
hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
551+
} catch (ur_result_t Err) {
552+
Result = Err;
553+
}
554+
return Result;
542555
}
543556

544557
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
@@ -547,14 +560,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
547560
uint32_t numSyncPointsInWaitList,
548561
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
549562
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
550-
(void)hCommandBuffer;
551-
(void)numSyncPointsInWaitList;
552-
(void)pSyncPointWaitList;
553-
(void)pSyncPoint;
554-
555-
detail::ur::die("Experimental Command-buffer feature is not "
556-
"implemented for CUDA adapter.");
557-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
563+
// The mem-advise command is not supported by CUDA Graphs.
564+
// We implement it as an empty node to enforce dependencies.
565+
ur_result_t Result = UR_RESULT_SUCCESS;
566+
CUgraphNode GraphNode;
567+
568+
std::vector<CUgraphNode> DepsList;
569+
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
570+
pSyncPointWaitList, DepsList));
571+
572+
try {
573+
// Add an empty node to preserve dependencies.
574+
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
575+
DepsList.data(), DepsList.size()));
576+
577+
// Create a sync point and register the CUDA graph node with it.
578+
*pSyncPoint =
579+
hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
580+
} catch (ur_result_t Err) {
581+
Result = Err;
582+
}
583+
584+
return Result;
558585
}
559586

560587
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(

0 commit comments

Comments
 (0)