Skip to content

Commit 1d78636

Browse files
authored
Merge pull request #938 from Bensuo/cmdbuf-fill-memset-l0
[EXP][CMDBUF] Implement Fill commands for L0 adapter
2 parents cf87428 + 3ee71a7 commit 1d78636

File tree

8 files changed

+231
-1
lines changed

8 files changed

+231
-1
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,91 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
9999
Params.Depth = 1;
100100
}
101101

102+
// Helper function for enqueuing memory fills
103+
static ur_result_t enqueueCommandBufferFillHelper(
104+
ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
105+
const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
106+
size_t Size, uint32_t NumSyncPointsInWaitList,
107+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
108+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
109+
ur_result_t Result = UR_RESULT_SUCCESS;
110+
std::vector<CUgraphNode> DepsList;
111+
UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
112+
SyncPointWaitList, DepsList),
113+
Result);
114+
115+
try {
116+
const size_t N = Size / PatternSize;
117+
auto Value = *static_cast<const uint32_t *>(Pattern);
118+
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
119+
? *static_cast<CUdeviceptr *>(DstDevice)
120+
: (CUdeviceptr)DstDevice;
121+
122+
if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
123+
// Create a new node
124+
CUgraphNode GraphNode;
125+
CUDA_MEMSET_NODE_PARAMS NodeParams = {};
126+
NodeParams.dst = DstPtr;
127+
NodeParams.elementSize = PatternSize;
128+
NodeParams.height = N;
129+
NodeParams.pitch = PatternSize;
130+
NodeParams.value = Value;
131+
NodeParams.width = 1;
132+
133+
UR_CHECK_ERROR(cuGraphAddMemsetNode(
134+
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
135+
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
136+
137+
// Get sync point and register the cuNode with it.
138+
*SyncPoint =
139+
CommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
140+
141+
} else {
142+
// CUDA has no memset functions that allow setting values more than 4
143+
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
144+
// fill, which can be more than 4 bytes. We must break up the pattern
145+
// into 4 byte values, and set the buffer using multiple strided calls.
146+
// This means that one cuGraphAddMemsetNode call is made for every 4 bytes
147+
// in the pattern.
148+
149+
size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
150+
151+
// we walk up the pattern in 4-byte steps, and call cuMemset for each
152+
// 4-byte chunk of the pattern.
153+
for (auto Step = 0u; Step < NumberOfSteps; ++Step) {
154+
// take 4 bytes of the pattern
155+
auto Value = *(static_cast<const uint32_t *>(Pattern) + Step);
156+
157+
// offset the pointer to the part of the buffer we want to write to
158+
auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t));
159+
160+
// Create a new node
161+
CUgraphNode GraphNode;
162+
// Update NodeParam
163+
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
164+
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
165+
NodeParamsStep.elementSize = 4;
166+
NodeParamsStep.height = N;
167+
NodeParamsStep.pitch = PatternSize;
168+
NodeParamsStep.value = Value;
169+
NodeParamsStep.width = 1;
170+
171+
UR_CHECK_ERROR(cuGraphAddMemsetNode(
172+
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
173+
DepsList.size(), &NodeParamsStep,
174+
CommandBuffer->Device->getContext()));
175+
176+
// Get sync point and register the cuNode with it.
177+
*SyncPoint = CommandBuffer->AddSyncPoint(
178+
std::make_shared<CUgraphNode>(GraphNode));
179+
}
180+
}
181+
} catch (ur_result_t Err) {
182+
Result = Err;
183+
}
184+
return Result;
185+
}
186+
102187
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
103188
ur_context_handle_t hContext, ur_device_handle_t hDevice,
104189
const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
@@ -596,6 +681,48 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
596681
return Result;
597682
}
598683

684+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
685+
ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
686+
const void *pPattern, size_t patternSize, size_t offset, size_t size,
687+
uint32_t numSyncPointsInWaitList,
688+
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
689+
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
690+
auto ArgsAreMultiplesOfPatternSize =
691+
(offset % patternSize == 0) || (size % patternSize == 0);
692+
693+
auto PatternIsValid = (pPattern != nullptr);
694+
695+
auto PatternSizeIsValid = ((patternSize & (patternSize - 1)) == 0) &&
696+
(patternSize > 0); // is a positive power of two
697+
UR_ASSERT(ArgsAreMultiplesOfPatternSize && PatternIsValid &&
698+
PatternSizeIsValid,
699+
UR_RESULT_ERROR_INVALID_SIZE);
700+
701+
auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;
702+
703+
return enqueueCommandBufferFillHelper(
704+
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
705+
size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
706+
}
707+
708+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
709+
ur_exp_command_buffer_handle_t hCommandBuffer, void *pPtr,
710+
const void *pPattern, size_t patternSize, size_t size,
711+
uint32_t numSyncPointsInWaitList,
712+
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
713+
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
714+
715+
auto PatternIsValid = (pPattern != nullptr);
716+
717+
auto PatternSizeIsValid = ((patternSize & (patternSize - 1)) == 0) &&
718+
(patternSize > 0); // is a positive power of two
719+
720+
UR_ASSERT(PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
721+
return enqueueCommandBufferFillHelper(
722+
hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
723+
numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
724+
}
725+
599726
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
600727
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
601728
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,

source/adapters/cuda/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
279279
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
280280
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
281281
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
282+
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
282283
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
283284
pDdiTable->pfnAppendMemBufferCopyRectExp =
284285
urCommandBufferAppendMemBufferCopyRectExp;
@@ -291,6 +292,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
291292
urCommandBufferAppendMemBufferWriteRectExp;
292293
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
293294
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
295+
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
294296
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
295297

296298
return retVal;

source/adapters/hip/command_buffer.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
137137
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
138138
}
139139

140+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
141+
ur_exp_command_buffer_handle_t, ur_mem_handle_t, const void *, size_t,
142+
size_t, size_t, uint32_t, const ur_exp_command_buffer_sync_point_t *,
143+
ur_exp_command_buffer_sync_point_t *) {
144+
detail::ur::die("Experimental Command-buffer feature is not "
145+
"implemented for HIP adapter.");
146+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
147+
}
148+
149+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
150+
ur_exp_command_buffer_handle_t, void *, const void *, size_t, size_t,
151+
uint32_t, const ur_exp_command_buffer_sync_point_t *,
152+
ur_exp_command_buffer_sync_point_t *) {
153+
detail::ur::die("Experimental Command-buffer feature is not "
154+
"implemented for HIP adapter.");
155+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
156+
}
157+
140158
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
141159
ur_exp_command_buffer_handle_t, ur_queue_handle_t, uint32_t,
142160
const ur_event_handle_t *, ur_event_handle_t *) {

source/adapters/hip/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
276276
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
277277
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
278278
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
279+
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
279280
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
280281
pDdiTable->pfnAppendMemBufferCopyRectExp =
281282
urCommandBufferAppendMemBufferCopyRectExp;
@@ -289,6 +290,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
289290
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
290291
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
291292
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
293+
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
292294

293295
return retVal;
294296
}

source/adapters/level_zero/command_buffer.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,48 @@ static ur_result_t enqueueCommandBufferMemCopyRectHelper(
379379
return UR_RESULT_SUCCESS;
380380
}
381381

382+
// Helper function for enqueuing memory fills
383+
static ur_result_t enqueueCommandBufferFillHelper(
384+
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
385+
void *Ptr, const void *Pattern, size_t PatternSize, size_t Size,
386+
uint32_t NumSyncPointsInWaitList,
387+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
388+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
389+
// Pattern size must be a power of two.
390+
UR_ASSERT((PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0),
391+
UR_RESULT_ERROR_INVALID_VALUE);
392+
393+
// Pattern size must fit the compute queue capabilities.
394+
UR_ASSERT(
395+
PatternSize <=
396+
CommandBuffer->Device
397+
->QueueGroup[ur_device_handle_t_::queue_group_info_t::Compute]
398+
.ZeProperties.maxMemoryFillPatternSize,
399+
UR_RESULT_ERROR_INVALID_VALUE);
400+
401+
std::vector<ze_event_handle_t> ZeEventList;
402+
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
403+
SyncPointWaitList, ZeEventList));
404+
405+
ur_event_handle_t LaunchEvent;
406+
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
407+
LaunchEvent->CommandType = CommandType;
408+
409+
// Get sync point and register the event with it.
410+
*SyncPoint = CommandBuffer->GetNextSyncPoint();
411+
CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent);
412+
413+
ZE2UR_CALL(zeCommandListAppendMemoryFill,
414+
(CommandBuffer->ZeCommandList, Ptr, Pattern, PatternSize, Size,
415+
LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data()));
416+
417+
urPrint("calling zeCommandListAppendMemoryFill() with"
418+
" ZeEvent %#lx\n",
419+
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
420+
421+
return UR_RESULT_SUCCESS;
422+
}
423+
382424
UR_APIEXPORT ur_result_t UR_APICALL
383425
urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
384426
const ur_exp_command_buffer_desc_t *CommandBufferDesc,
@@ -783,6 +825,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
783825
return UR_RESULT_SUCCESS;
784826
}
785827

828+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
829+
ur_exp_command_buffer_handle_t CommandBuffer, ur_mem_handle_t Buffer,
830+
const void *Pattern, size_t PatternSize, size_t Offset, size_t Size,
831+
uint32_t NumSyncPointsInWaitList,
832+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
833+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
834+
835+
std::scoped_lock<ur_shared_mutex> Lock(Buffer->Mutex);
836+
837+
char *ZeHandleDst = nullptr;
838+
_ur_buffer *UrBuffer = reinterpret_cast<_ur_buffer *>(Buffer);
839+
UR_CALL(UrBuffer->getZeHandle(ZeHandleDst, ur_mem_handle_t_::write_only,
840+
CommandBuffer->Device));
841+
842+
return enqueueCommandBufferFillHelper(
843+
UR_COMMAND_MEM_BUFFER_FILL, CommandBuffer, ZeHandleDst + Offset,
844+
Pattern, // It will be interpreted as an 8-bit value,
845+
PatternSize, // which is indicated with this pattern_size==1
846+
Size, NumSyncPointsInWaitList, SyncPointWaitList, SyncPoint);
847+
}
848+
849+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
850+
ur_exp_command_buffer_handle_t CommandBuffer, void *Ptr,
851+
const void *Pattern, size_t PatternSize, size_t Size,
852+
uint32_t NumSyncPointsInWaitList,
853+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
854+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
855+
856+
return enqueueCommandBufferFillHelper(
857+
UR_COMMAND_MEM_BUFFER_FILL, CommandBuffer, Ptr,
858+
Pattern, // It will be interpreted as an 8-bit value,
859+
PatternSize, // which is indicated with this pattern_size==1
860+
Size, NumSyncPointsInWaitList, SyncPointWaitList, SyncPoint);
861+
}
862+
786863
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
787864
ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue,
788865
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,

source/adapters/level_zero/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
326326
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
327327
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
328328
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
329+
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
329330
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
330331
pDdiTable->pfnAppendMemBufferCopyRectExp =
331332
urCommandBufferAppendMemBufferCopyRectExp;
@@ -338,6 +339,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
338339
urCommandBufferAppendMemBufferWriteRectExp;
339340
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
340341
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
342+
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
341343
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
342344

343345
return retVal;

source/adapters/opencl/command_buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
273273
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
274274
}
275275

276-
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
276+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
277277
ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
278278
const void *pPattern, size_t patternSize, size_t offset, size_t size,
279279
uint32_t numSyncPointsInWaitList,

source/adapters/opencl/ur_interface_loader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
286286
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
287287
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
288288
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
289+
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
289290
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
290291
pDdiTable->pfnAppendMemBufferCopyRectExp =
291292
urCommandBufferAppendMemBufferCopyRectExp;
@@ -298,6 +299,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
298299
urCommandBufferAppendMemBufferWriteRectExp;
299300
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
300301
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
302+
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
301303
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
302304

303305
return retVal;

0 commit comments

Comments
 (0)