@@ -197,7 +197,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
197
197
ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
198
198
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
199
199
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
200
-
201
200
// Have to look up the context from the kernel
202
201
cl_context CLContext;
203
202
cl_int CLErr = clGetCommandQueueInfo (
@@ -207,20 +206,82 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
207
206
return mapCLErrorToUR (CLErr);
208
207
}
209
208
210
- clEnqueueMemFillINTEL_fn FuncPtr = nullptr ;
211
- ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemFillINTEL_fn>(
212
- CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemFillINTELCache ,
213
- cl_ext::EnqueueMemFillName, &FuncPtr);
209
+ if (patternSize <= 128 ) {
210
+ clEnqueueMemFillINTEL_fn EnqueueMemFill = nullptr ;
211
+ UR_RETURN_ON_FAILURE (
212
+ cl_ext::getExtFuncFromContext<clEnqueueMemFillINTEL_fn>(
213
+ CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemFillINTELCache ,
214
+ cl_ext::EnqueueMemFillName, &EnqueueMemFill));
215
+
216
+ CL_RETURN_ON_FAILURE (
217
+ EnqueueMemFill (cl_adapter::cast<cl_command_queue>(hQueue), ptr,
218
+ pPattern, patternSize, size, numEventsInWaitList,
219
+ cl_adapter::cast<const cl_event *>(phEventWaitList),
220
+ cl_adapter::cast<cl_event *>(phEvent)));
221
+ return UR_RESULT_SUCCESS;
222
+ }
214
223
215
- if (FuncPtr) {
216
- RetVal = mapCLErrorToUR (
217
- FuncPtr (cl_adapter::cast<cl_command_queue>(hQueue), ptr, pPattern,
218
- patternSize, size, numEventsInWaitList,
219
- cl_adapter::cast<const cl_event *>(phEventWaitList),
220
- cl_adapter::cast<cl_event *>(phEvent)));
224
+ // OpenCL only supports pattern sizes as large as the largest CL type
225
+ // (double16/long16 - 128 bytes), anything larger we need to do on the host
226
+ // side and copy it into the target allocation.
227
+ clHostMemAllocINTEL_fn HostMemAlloc = nullptr ;
228
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
229
+ CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache ,
230
+ cl_ext::HostMemAllocName, &HostMemAlloc));
231
+
232
+ clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr ;
233
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
234
+ CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache ,
235
+ cl_ext::EnqueueMemcpyName, &USMMemcpy));
236
+
237
+ clMemBlockingFreeINTEL_fn USMFree = nullptr ;
238
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
239
+ CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache ,
240
+ cl_ext::MemBlockingFreeName, &USMFree));
241
+
242
+ cl_int ClErr = CL_SUCCESS;
243
+ auto HostBuffer = static_cast <uint64_t *>(
244
+ HostMemAlloc (CLContext, nullptr , size, 0 , &ClErr));
245
+ CL_RETURN_ON_FAILURE (ClErr);
246
+
247
+ auto NumValues = size / sizeof (uint64_t );
248
+ auto NumChunks = patternSize / sizeof (uint64_t );
249
+ for (size_t i = 0 ; i < NumValues; i++) {
250
+ HostBuffer[i] = static_cast <const uint64_t *>(pPattern)[i % NumChunks];
221
251
}
222
252
223
- return RetVal;
253
+ cl_event CopyEvent = nullptr ;
254
+ CL_RETURN_ON_FAILURE (USMMemcpy (
255
+ cl_adapter::cast<cl_command_queue>(hQueue), false , ptr, HostBuffer, size,
256
+ numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
257
+ &CopyEvent));
258
+
259
+ struct DeleteCallbackInfo {
260
+ clMemBlockingFreeINTEL_fn USMFree;
261
+ cl_context CLContext;
262
+ void *HostBuffer;
263
+ void execute () {
264
+ USMFree (CLContext, HostBuffer);
265
+ delete this ;
266
+ }
267
+ };
268
+
269
+ auto Info = new DeleteCallbackInfo{USMFree, CLContext, HostBuffer};
270
+
271
+ auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
272
+ static_cast <DeleteCallbackInfo *>(pUserData)->execute ();
273
+ };
274
+
275
+ CL_RETURN_ON_FAILURE (
276
+ clSetEventCallback (CopyEvent, CL_COMPLETE, DeleteCallback, Info));
277
+
278
+ if (phEvent) {
279
+ *phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
280
+ } else {
281
+ CL_RETURN_ON_FAILURE (clReleaseEvent (CopyEvent));
282
+ }
283
+
284
+ return UR_RESULT_SUCCESS;
224
285
}
225
286
226
287
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy (
@@ -343,18 +404,59 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
343
404
[[maybe_unused]] uint32_t numEventsInWaitList,
344
405
[[maybe_unused]] const ur_event_handle_t *phEventWaitList,
345
406
[[maybe_unused]] ur_event_handle_t *phEvent) {
346
- return UR_RESULT_ERROR_INVALID_OPERATION ;
407
+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE ;
347
408
}
348
409
349
410
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D (
350
- [[maybe_unused]] ur_queue_handle_t hQueue, [[maybe_unused]] bool blocking,
351
- [[maybe_unused]] void *pDst, [[maybe_unused]] size_t dstPitch,
352
- [[maybe_unused]] const void *pSrc, [[maybe_unused]] size_t srcPitch,
353
- [[maybe_unused]] size_t width, [[maybe_unused]] size_t height,
354
- [[maybe_unused]] uint32_t numEventsInWaitList,
355
- [[maybe_unused]] const ur_event_handle_t *phEventWaitList,
356
- [[maybe_unused]] ur_event_handle_t *phEvent) {
357
- return UR_RESULT_ERROR_INVALID_OPERATION;
411
+ ur_queue_handle_t hQueue, bool blocking, void *pDst, size_t dstPitch,
412
+ const void *pSrc, size_t srcPitch, size_t width, size_t height,
413
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
414
+ ur_event_handle_t *phEvent) {
415
+ cl_context CLContext;
416
+ CL_RETURN_ON_FAILURE (clGetCommandQueueInfo (
417
+ cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_CONTEXT,
418
+ sizeof (cl_context), &CLContext, nullptr ));
419
+
420
+ clEnqueueMemcpyINTEL_fn FuncPtr = nullptr ;
421
+ ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
422
+ CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache ,
423
+ cl_ext::EnqueueMemcpyName, &FuncPtr);
424
+
425
+ if (!FuncPtr) {
426
+ return RetVal;
427
+ }
428
+
429
+ std::vector<cl_event> Events;
430
+ for (size_t HeightIndex = 0 ; HeightIndex < height; HeightIndex++) {
431
+ cl_event Event = nullptr ;
432
+ auto ClResult =
433
+ FuncPtr (cl_adapter::cast<cl_command_queue>(hQueue), false ,
434
+ static_cast <uint8_t *>(pDst) + dstPitch * HeightIndex,
435
+ static_cast <const uint8_t *>(pSrc) + srcPitch * HeightIndex,
436
+ width, numEventsInWaitList,
437
+ cl_adapter::cast<const cl_event *>(phEventWaitList), &Event);
438
+ Events.push_back (Event);
439
+ if (ClResult != CL_SUCCESS) {
440
+ for (const auto &E : Events) {
441
+ clReleaseEvent (E);
442
+ }
443
+ CL_RETURN_ON_FAILURE (ClResult);
444
+ }
445
+ }
446
+ cl_int ClResult = CL_SUCCESS;
447
+ if (blocking) {
448
+ ClResult = clWaitForEvents (Events.size (), Events.data ());
449
+ }
450
+ if (phEvent && ClResult == CL_SUCCESS) {
451
+ ClResult = clEnqueueBarrierWithWaitList (
452
+ cl_adapter::cast<cl_command_queue>(hQueue), Events.size (),
453
+ Events.data (), cl_adapter::cast<cl_event *>(phEvent));
454
+ }
455
+ for (const auto &E : Events) {
456
+ clReleaseEvent (E);
457
+ }
458
+ CL_RETURN_ON_FAILURE (ClResult)
459
+ return UR_RESULT_SUCCESS;
358
460
}
359
461
360
462
UR_APIEXPORT ur_result_t UR_APICALL
0 commit comments