@@ -531,14 +531,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
531
531
uint32_t numSyncPointsInWaitList,
532
532
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
533
533
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
534
- (void )hCommandBuffer;
535
- (void )numSyncPointsInWaitList;
536
- (void )pSyncPointWaitList;
537
- (void )pSyncPoint;
538
-
539
- detail::ur::die (" Experimental Command-buffer feature is not "
540
- " implemented for CUDA adapter." );
541
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
534
+ // Prefetch cmd is not supported by Cuda Graph.
535
+ // We implement it as an empty node to enforce dependencies.
536
+ ur_result_t Result = UR_RESULT_SUCCESS;
537
+ CUgraphNode GraphNode;
538
+
539
+ std::vector<CUgraphNode> DepsList;
540
+ UR_CALL (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
541
+ pSyncPointWaitList, DepsList));
542
+
543
+ try {
544
+ // Add an empty node to preserve dependencies.
545
+ UR_CHECK_ERROR (cuGraphAddEmptyNode (&GraphNode, hCommandBuffer->CudaGraph ,
546
+ DepsList.data (), DepsList.size ()));
547
+
548
+ // Get sync point and register the cuNode with it.
549
+ *pSyncPoint =
550
+ hCommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
551
+ } catch (ur_result_t Err) {
552
+ Result = Err;
553
+ }
554
+ return Result;
542
555
}
543
556
544
557
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp (
@@ -547,14 +560,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
547
560
uint32_t numSyncPointsInWaitList,
548
561
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
549
562
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
550
- (void )hCommandBuffer;
551
- (void )numSyncPointsInWaitList;
552
- (void )pSyncPointWaitList;
553
- (void )pSyncPoint;
554
-
555
- detail::ur::die (" Experimental Command-buffer feature is not "
556
- " implemented for CUDA adapter." );
557
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
563
+ // Mem-Advise cmd is not supported by Cuda Graph.
564
+ // We implement it as an empty node to enforce dependencies.
565
+ ur_result_t Result = UR_RESULT_SUCCESS;
566
+ CUgraphNode GraphNode;
567
+
568
+ std::vector<CUgraphNode> DepsList;
569
+ UR_CALL (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
570
+ pSyncPointWaitList, DepsList));
571
+
572
+ try {
573
+ // Add an empty node to preserve dependencies.
574
+ UR_CHECK_ERROR (cuGraphAddEmptyNode (&GraphNode, hCommandBuffer->CudaGraph ,
575
+ DepsList.data (), DepsList.size ()));
576
+
577
+ // Get sync point and register the cuNode with it.
578
+ *pSyncPoint =
579
+ hCommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
580
+ } catch (ur_result_t Err) {
581
+ Result = Err;
582
+ }
583
+
584
+ return Result;
558
585
}
559
586
560
587
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp (
0 commit comments