Skip to content

Commit 5d8173a

Browse files
authored
Merge pull request #937 from Bensuo/support-prefetch-advise-cmd-buffers
[EXP][CMDBUF] Add adapters code for Prefetch and Advise commands
2 parents 749d8e5 + 01cd56d commit 5d8173a

File tree

8 files changed

+228
-0
lines changed

8 files changed

+228
-0
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,77 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
525525
return Result;
526526
}
527527

528+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
529+
ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */,
530+
size_t /*Size*/, ur_usm_migration_flags_t /*Flags*/,
531+
uint32_t numSyncPointsInWaitList,
532+
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
533+
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
534+
// Prefetch cmd is not supported by Cuda Graph.
535+
// We implement it as an empty node to enforce dependencies.
536+
ur_result_t Result = UR_RESULT_SUCCESS;
537+
CUgraphNode GraphNode;
538+
539+
std::vector<CUgraphNode> DepsList;
540+
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
541+
pSyncPointWaitList, DepsList),
542+
Result);
543+
544+
try {
545+
// Add an empty node to preserve dependencies.
546+
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
547+
DepsList.data(), DepsList.size()));
548+
549+
// Get sync point and register the cuNode with it.
550+
*pSyncPoint =
551+
hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
552+
553+
setErrorMessage("Prefetch hint ignored and replaced with empty node as "
554+
"prefetch is not supported by CUDA Graph backend",
555+
UR_RESULT_SUCCESS);
556+
Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
557+
} catch (ur_result_t Err) {
558+
Result = Err;
559+
}
560+
return Result;
561+
}
562+
563+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
564+
ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */,
565+
size_t /*Size*/, ur_usm_advice_flags_t /*Advice*/,
566+
uint32_t numSyncPointsInWaitList,
567+
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
568+
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
569+
// Mem-Advise cmd is not supported by Cuda Graph.
570+
// We implement it as an empty node to enforce dependencies.
571+
ur_result_t Result = UR_RESULT_SUCCESS;
572+
CUgraphNode GraphNode;
573+
574+
std::vector<CUgraphNode> DepsList;
575+
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
576+
pSyncPointWaitList, DepsList),
577+
Result);
578+
579+
try {
580+
// Add an empty node to preserve dependencies.
581+
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
582+
DepsList.data(), DepsList.size()));
583+
584+
// Get sync point and register the cuNode with it.
585+
*pSyncPoint =
586+
hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
587+
588+
setErrorMessage("Memory advice ignored and replaced with empty node as "
589+
"memory advice is not supported by CUDA Graph backend",
590+
UR_RESULT_SUCCESS);
591+
Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
592+
} catch (ur_result_t Err) {
593+
Result = Err;
594+
}
595+
596+
return Result;
597+
}
598+
528599
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
529600
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
530601
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,

source/adapters/cuda/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
289289
urCommandBufferAppendMemBufferWriteExp;
290290
pDdiTable->pfnAppendMemBufferWriteRectExp =
291291
urCommandBufferAppendMemBufferWriteRectExp;
292+
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
293+
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
292294
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
293295

294296
return retVal;

source/adapters/hip/command_buffer.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,21 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
122122
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
123123
}
124124

125+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
126+
ur_exp_command_buffer_handle_t, const void *, size_t,
127+
ur_usm_migration_flags_t, uint32_t,
128+
const ur_exp_command_buffer_sync_point_t *,
129+
ur_exp_command_buffer_sync_point_t *) {
130+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
131+
}
132+
133+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
134+
ur_exp_command_buffer_handle_t, const void *, size_t, ur_usm_advice_flags_t,
135+
uint32_t, const ur_exp_command_buffer_sync_point_t *,
136+
ur_exp_command_buffer_sync_point_t *) {
137+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
138+
}
139+
125140
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
126141
ur_exp_command_buffer_handle_t, ur_queue_handle_t, uint32_t,
127142
const ur_event_handle_t *, ur_event_handle_t *) {

source/adapters/hip/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
286286
urCommandBufferAppendMemBufferWriteExp;
287287
pDdiTable->pfnAppendMemBufferWriteRectExp =
288288
urCommandBufferAppendMemBufferWriteRectExp;
289+
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
290+
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
289291
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
290292

291293
return retVal;

source/adapters/level_zero/command_buffer.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,106 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
683683
SyncPointWaitList, SyncPoint);
684684
}
685685

686+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
687+
ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
688+
ur_usm_migration_flags_t Flags, uint32_t NumSyncPointsInWaitList,
689+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
690+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
691+
std::ignore = Flags;
692+
693+
std::vector<ze_event_handle_t> ZeEventList;
694+
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
695+
SyncPointWaitList, ZeEventList));
696+
697+
if (NumSyncPointsInWaitList) {
698+
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
699+
(CommandBuffer->ZeCommandList, NumSyncPointsInWaitList,
700+
ZeEventList.data()));
701+
}
702+
703+
ur_event_handle_t LaunchEvent;
704+
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
705+
LaunchEvent->CommandType = UR_COMMAND_USM_PREFETCH;
706+
707+
// Get sync point and register the event with it.
708+
*SyncPoint = CommandBuffer->GetNextSyncPoint();
709+
CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent);
710+
711+
// Add the prefetch command to the command buffer.
712+
// Note that L0 does not handle migration flags.
713+
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
714+
(CommandBuffer->ZeCommandList, Mem, Size));
715+
716+
// Level Zero does not have a completion "event" with the prefetch API,
717+
// so manually add command to signal our event.
718+
ZE2UR_CALL(zeCommandListAppendSignalEvent,
719+
(CommandBuffer->ZeCommandList, LaunchEvent->ZeEvent));
720+
721+
return UR_RESULT_SUCCESS;
722+
}
723+
724+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
725+
ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
726+
ur_usm_advice_flags_t Advice, uint32_t NumSyncPointsInWaitList,
727+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
728+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
729+
// A memory chunk can be advised with muliple memory advices
730+
// We therefore prefer if statements to switch cases to combine all potential
731+
// flags
732+
uint32_t Value = 0;
733+
if (Advice & UR_USM_ADVICE_FLAG_SET_READ_MOSTLY)
734+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_READ_MOSTLY);
735+
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY)
736+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_READ_MOSTLY);
737+
if (Advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION)
738+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION);
739+
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION)
740+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_PREFERRED_LOCATION);
741+
if (Advice & UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY)
742+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_NON_ATOMIC_MOSTLY);
743+
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY)
744+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_NON_ATOMIC_MOSTLY);
745+
if (Advice & UR_USM_ADVICE_FLAG_BIAS_CACHED)
746+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_BIAS_CACHED);
747+
if (Advice & UR_USM_ADVICE_FLAG_BIAS_UNCACHED)
748+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_BIAS_UNCACHED);
749+
if (Advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST)
750+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION);
751+
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST)
752+
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_PREFERRED_LOCATION);
753+
754+
ze_memory_advice_t ZeAdvice = static_cast<ze_memory_advice_t>(Value);
755+
756+
std::vector<ze_event_handle_t> ZeEventList;
757+
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
758+
SyncPointWaitList, ZeEventList));
759+
760+
if (NumSyncPointsInWaitList) {
761+
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
762+
(CommandBuffer->ZeCommandList, NumSyncPointsInWaitList,
763+
ZeEventList.data()));
764+
}
765+
766+
ur_event_handle_t LaunchEvent;
767+
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
768+
LaunchEvent->CommandType = UR_COMMAND_USM_ADVISE;
769+
770+
// Get sync point and register the event with it.
771+
*SyncPoint = CommandBuffer->GetNextSyncPoint();
772+
CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent);
773+
774+
ZE2UR_CALL(zeCommandListAppendMemAdvise,
775+
(CommandBuffer->ZeCommandList, CommandBuffer->Device->ZeDevice,
776+
Mem, Size, ZeAdvice));
777+
778+
// Level Zero does not have a completion "event" with the advise API,
779+
// so manually add command to signal our event.
780+
ZE2UR_CALL(zeCommandListAppendSignalEvent,
781+
(CommandBuffer->ZeCommandList, LaunchEvent->ZeEvent));
782+
783+
return UR_RESULT_SUCCESS;
784+
}
785+
686786
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
687787
ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue,
688788
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,

source/adapters/level_zero/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
336336
urCommandBufferAppendMemBufferWriteExp;
337337
pDdiTable->pfnAppendMemBufferWriteRectExp =
338338
urCommandBufferAppendMemBufferWriteRectExp;
339+
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
340+
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
339341
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
340342

341343
return retVal;

source/adapters/opencl/command_buffer.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
297297
return UR_RESULT_SUCCESS;
298298
}
299299

300+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
301+
ur_exp_command_buffer_handle_t hCommandBuffer, const void *mem, size_t size,
302+
ur_usm_migration_flags_t flags, uint32_t numSyncPointsInWaitList,
303+
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
304+
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
305+
(void)hCommandBuffer;
306+
(void)mem;
307+
(void)size;
308+
(void)flags;
309+
(void)numSyncPointsInWaitList;
310+
(void)pSyncPointWaitList;
311+
(void)pSyncPoint;
312+
313+
// Not implemented
314+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
315+
}
316+
317+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
318+
ur_exp_command_buffer_handle_t hCommandBuffer, const void *mem, size_t size,
319+
ur_usm_advice_flags_t advice, uint32_t numSyncPointsInWaitList,
320+
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
321+
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
322+
(void)hCommandBuffer;
323+
(void)mem;
324+
(void)size;
325+
(void)advice;
326+
(void)numSyncPointsInWaitList;
327+
(void)pSyncPointWaitList;
328+
(void)pSyncPoint;
329+
330+
// Not implemented
331+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
332+
}
333+
300334
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
301335
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
302336
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,

source/adapters/opencl/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
296296
urCommandBufferAppendMemBufferWriteExp;
297297
pDdiTable->pfnAppendMemBufferWriteRectExp =
298298
urCommandBufferAppendMemBufferWriteRectExp;
299+
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
300+
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
299301
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
300302

301303
return retVal;

0 commit comments

Comments
 (0)