Skip to content

Commit bfb3dac

Browse files
committed
[OpenCL] Implement urEnqueueUSMMemcpy2D and allow large fill patterns.
Normally OpenCL limits fill type operations to a max pattern size of 128, this patch includes a workaround to extend that.
1 parent 55d432c commit bfb3dac

File tree

2 files changed

+165
-28
lines changed

2 files changed

+165
-28
lines changed

source/adapters/opencl/enqueue.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,47 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
178178
size_t patternSize, size_t offset, size_t size,
179179
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
180180
ur_event_handle_t *phEvent) {
181+
// CL FillBuffer only allows pattern sizes up to the largest CL type:
182+
// long16/double16
183+
if (patternSize <= 128) {
184+
CL_RETURN_ON_FAILURE(
185+
clEnqueueFillBuffer(cl_adapter::cast<cl_command_queue>(hQueue),
186+
cl_adapter::cast<cl_mem>(hBuffer), pPattern,
187+
patternSize, offset, size, numEventsInWaitList,
188+
cl_adapter::cast<const cl_event *>(phEventWaitList),
189+
cl_adapter::cast<cl_event *>(phEvent)));
190+
return UR_RESULT_SUCCESS;
191+
}
192+
193+
auto NumValues = size / sizeof(uint64_t);
194+
auto HostBuffer = new uint64_t[NumValues];
195+
auto NumChunks = patternSize / sizeof(uint64_t);
196+
for (size_t i = 0; i < NumValues; i++) {
197+
HostBuffer[i] = static_cast<const uint64_t *>(pPattern)[i % NumChunks];
198+
}
181199

182-
CL_RETURN_ON_FAILURE(clEnqueueFillBuffer(
200+
cl_event WriteEvent = nullptr;
201+
auto ClErr = clEnqueueWriteBuffer(
183202
cl_adapter::cast<cl_command_queue>(hQueue),
184-
cl_adapter::cast<cl_mem>(hBuffer), pPattern, patternSize, offset, size,
203+
cl_adapter::cast<cl_mem>(hBuffer), false, offset, size, HostBuffer,
185204
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
186-
cl_adapter::cast<cl_event *>(phEvent)));
205+
&WriteEvent);
206+
if (ClErr != CL_SUCCESS) {
207+
delete[] HostBuffer;
208+
CL_RETURN_ON_FAILURE(ClErr);
209+
}
210+
211+
auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
212+
delete[] static_cast<uint64_t *>(pUserData);
213+
};
214+
CL_RETURN_ON_FAILURE(
215+
clSetEventCallback(WriteEvent, CL_COMPLETE, DeleteCallback, HostBuffer));
216+
217+
if (phEvent) {
218+
*phEvent = cl_adapter::cast<ur_event_handle_t>(WriteEvent);
219+
} else {
220+
CL_RETURN_ON_FAILURE(clReleaseEvent(WriteEvent));
221+
}
187222

188223
return UR_RESULT_SUCCESS;
189224
}
@@ -350,9 +385,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(
350385
return mapCLErrorToUR(CLErr);
351386
}
352387

353-
clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
388+
cl_ext::clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
354389
ur_result_t RetVal =
355-
cl_ext::getExtFuncFromContext<clEnqueueReadHostPipeINTEL_fn>(
390+
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueReadHostPipeINTEL_fn>(
356391
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
357392
cl_ext::EnqueueReadHostPipeName, &FuncPtr);
358393

@@ -382,9 +417,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
382417
return mapCLErrorToUR(CLErr);
383418
}
384419

385-
clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
420+
cl_ext::clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
386421
ur_result_t RetVal =
387-
cl_ext::getExtFuncFromContext<clEnqueueWriteHostPipeINTEL_fn>(
422+
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueWriteHostPipeINTEL_fn>(
388423
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
389424
cl_ext::EnqueueWriteHostPipeName, &FuncPtr);
390425

source/adapters/opencl/usm.cpp

Lines changed: 123 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
197197
ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
198198
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
199199
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
200-
201200
// Have to look up the context from the kernel
202201
cl_context CLContext;
203202
cl_int CLErr = clGetCommandQueueInfo(
@@ -207,20 +206,82 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
207206
return mapCLErrorToUR(CLErr);
208207
}
209208

210-
clEnqueueMemFillINTEL_fn FuncPtr = nullptr;
211-
ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemFillINTEL_fn>(
212-
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemFillINTELCache,
213-
cl_ext::EnqueueMemFillName, &FuncPtr);
209+
if (patternSize <= 128) {
210+
clEnqueueMemFillINTEL_fn EnqueueMemFill = nullptr;
211+
UR_RETURN_ON_FAILURE(
212+
cl_ext::getExtFuncFromContext<clEnqueueMemFillINTEL_fn>(
213+
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemFillINTELCache,
214+
cl_ext::EnqueueMemFillName, &EnqueueMemFill));
215+
216+
CL_RETURN_ON_FAILURE(
217+
EnqueueMemFill(cl_adapter::cast<cl_command_queue>(hQueue), ptr,
218+
pPattern, patternSize, size, numEventsInWaitList,
219+
cl_adapter::cast<const cl_event *>(phEventWaitList),
220+
cl_adapter::cast<cl_event *>(phEvent)));
221+
return UR_RESULT_SUCCESS;
222+
}
214223

215-
if (FuncPtr) {
216-
RetVal = mapCLErrorToUR(
217-
FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue), ptr, pPattern,
218-
patternSize, size, numEventsInWaitList,
219-
cl_adapter::cast<const cl_event *>(phEventWaitList),
220-
cl_adapter::cast<cl_event *>(phEvent)));
224+
// OpenCL only supports pattern sizes as large as the largest CL type
225+
// (double16/long16 - 128 bytes), anything larger we need to do on the host
226+
// side and copy it into the target allocation.
227+
clHostMemAllocINTEL_fn HostMemAlloc = nullptr;
228+
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
229+
CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache,
230+
cl_ext::HostMemAllocName, &HostMemAlloc));
231+
232+
clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr;
233+
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
234+
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
235+
cl_ext::EnqueueMemcpyName, &USMMemcpy));
236+
237+
clMemBlockingFreeINTEL_fn USMFree = nullptr;
238+
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
239+
CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache,
240+
cl_ext::MemBlockingFreeName, &USMFree));
241+
242+
cl_int ClErr = CL_SUCCESS;
243+
auto HostBuffer = static_cast<uint64_t *>(
244+
HostMemAlloc(CLContext, nullptr, size, 0, &ClErr));
245+
CL_RETURN_ON_FAILURE(ClErr);
246+
247+
auto NumValues = size / sizeof(uint64_t);
248+
auto NumChunks = patternSize / sizeof(uint64_t);
249+
for (size_t i = 0; i < NumValues; i++) {
250+
HostBuffer[i] = static_cast<const uint64_t *>(pPattern)[i % NumChunks];
221251
}
222252

223-
return RetVal;
253+
cl_event CopyEvent = nullptr;
254+
CL_RETURN_ON_FAILURE(USMMemcpy(
255+
cl_adapter::cast<cl_command_queue>(hQueue), false, ptr, HostBuffer, size,
256+
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
257+
&CopyEvent));
258+
259+
struct DeleteCallbackInfo {
260+
clMemBlockingFreeINTEL_fn USMFree;
261+
cl_context CLContext;
262+
void *HostBuffer;
263+
void execute() {
264+
USMFree(CLContext, HostBuffer);
265+
delete this;
266+
}
267+
};
268+
269+
auto Info = new DeleteCallbackInfo{USMFree, CLContext, HostBuffer};
270+
271+
auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
272+
static_cast<DeleteCallbackInfo *>(pUserData)->execute();
273+
};
274+
275+
CL_RETURN_ON_FAILURE(
276+
clSetEventCallback(CopyEvent, CL_COMPLETE, DeleteCallback, Info));
277+
278+
if (phEvent) {
279+
*phEvent = cl_adapter::cast<ur_event_handle_t>(CopyEvent);
280+
} else {
281+
CL_RETURN_ON_FAILURE(clReleaseEvent(CopyEvent));
282+
}
283+
284+
return UR_RESULT_SUCCESS;
224285
}
225286

226287
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
@@ -343,18 +404,59 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
343404
[[maybe_unused]] uint32_t numEventsInWaitList,
344405
[[maybe_unused]] const ur_event_handle_t *phEventWaitList,
345406
[[maybe_unused]] ur_event_handle_t *phEvent) {
346-
return UR_RESULT_ERROR_INVALID_OPERATION;
407+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
347408
}
348409

349410
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
350-
[[maybe_unused]] ur_queue_handle_t hQueue, [[maybe_unused]] bool blocking,
351-
[[maybe_unused]] void *pDst, [[maybe_unused]] size_t dstPitch,
352-
[[maybe_unused]] const void *pSrc, [[maybe_unused]] size_t srcPitch,
353-
[[maybe_unused]] size_t width, [[maybe_unused]] size_t height,
354-
[[maybe_unused]] uint32_t numEventsInWaitList,
355-
[[maybe_unused]] const ur_event_handle_t *phEventWaitList,
356-
[[maybe_unused]] ur_event_handle_t *phEvent) {
357-
return UR_RESULT_ERROR_INVALID_OPERATION;
411+
ur_queue_handle_t hQueue, bool blocking, void *pDst, size_t dstPitch,
412+
const void *pSrc, size_t srcPitch, size_t width, size_t height,
413+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
414+
ur_event_handle_t *phEvent) {
415+
cl_context CLContext;
416+
CL_RETURN_ON_FAILURE(clGetCommandQueueInfo(
417+
cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_CONTEXT,
418+
sizeof(cl_context), &CLContext, nullptr));
419+
420+
clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
421+
ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
422+
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
423+
cl_ext::EnqueueMemcpyName, &FuncPtr);
424+
425+
if (!FuncPtr) {
426+
return RetVal;
427+
}
428+
429+
std::vector<cl_event> Events;
430+
for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) {
431+
cl_event Event = nullptr;
432+
auto ClResult =
433+
FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue), false,
434+
static_cast<uint8_t *>(pDst) + dstPitch * HeightIndex,
435+
static_cast<const uint8_t *>(pSrc) + srcPitch * HeightIndex,
436+
width, numEventsInWaitList,
437+
cl_adapter::cast<const cl_event *>(phEventWaitList), &Event);
438+
Events.push_back(Event);
439+
if (ClResult != CL_SUCCESS) {
440+
for (const auto &E : Events) {
441+
clReleaseEvent(E);
442+
}
443+
CL_RETURN_ON_FAILURE(ClResult);
444+
}
445+
}
446+
cl_int ClResult = CL_SUCCESS;
447+
if (blocking) {
448+
ClResult = clWaitForEvents(Events.size(), Events.data());
449+
}
450+
if (phEvent && ClResult == CL_SUCCESS) {
451+
ClResult = clEnqueueBarrierWithWaitList(
452+
cl_adapter::cast<cl_command_queue>(hQueue), Events.size(),
453+
Events.data(), cl_adapter::cast<cl_event *>(phEvent));
454+
}
455+
for (const auto &E : Events) {
456+
clReleaseEvent(E);
457+
}
458+
CL_RETURN_ON_FAILURE(ClResult)
459+
return UR_RESULT_SUCCESS;
358460
}
359461

360462
UR_APIEXPORT ur_result_t UR_APICALL

0 commit comments

Comments
 (0)