diff --git a/include/ur_api.h b/include/ur_api.h index d7621bda32..87f060816c 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -8279,18 +8279,20 @@ typedef enum ur_map_flag_t { #define UR_MAP_FLAGS_MASK 0xfffffff8 /////////////////////////////////////////////////////////////////////////////// -/// @brief Map flags +/// @brief USM migration flags, indicating the direction data is migrated in typedef uint32_t ur_usm_migration_flags_t; typedef enum ur_usm_migration_flag_t { - /// Default migration TODO: Add more enums! - UR_USM_MIGRATION_FLAG_DEFAULT = UR_BIT(0), + /// Migrate data from host to device + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE = UR_BIT(0), + /// Migrate data from device to host + UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST = UR_BIT(1), /// @cond UR_USM_MIGRATION_FLAG_FORCE_UINT32 = 0x7fffffff /// @endcond } ur_usm_migration_flag_t; /// @brief Bit Mask for validating ur_usm_migration_flags_t -#define UR_USM_MIGRATION_FLAGS_MASK 0xfffffffe +#define UR_USM_MIGRATION_FLAGS_MASK 0xfffffffc /////////////////////////////////////////////////////////////////////////////// /// @brief Enqueue a command to map a region of the buffer object into the host @@ -8549,7 +8551,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -10909,7 +10911,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 5c5f573477..100c243a52 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -10300,8 +10300,11 @@ inline ur_result_t printFlag(std::ostream &os, uint32_t flag) { inline std::ostream &operator<<(std::ostream &os, enum ur_usm_migration_flag_t value) { switch (value) { - case UR_USM_MIGRATION_FLAG_DEFAULT: - os << "UR_USM_MIGRATION_FLAG_DEFAULT"; + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + os << "UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE"; + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + os << "UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST"; break; default: os << "unknown enumerator"; @@ -10319,15 +10322,26 @@ inline ur_result_t printFlag(std::ostream &os, uint32_t val = flag; bool first = true; - if ((val & UR_USM_MIGRATION_FLAG_DEFAULT) == - (uint32_t)UR_USM_MIGRATION_FLAG_DEFAULT) { - val ^= (uint32_t)UR_USM_MIGRATION_FLAG_DEFAULT; + if ((val & UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) == + (uint32_t)UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + val ^= (uint32_t)UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE; + if (!first) { + os << " | "; + } else { + first = false; + } + os << UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE; + } + + if ((val & UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST) == + (uint32_t)UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST) { + val ^= (uint32_t)UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST; if (!first) { os << " | "; } else { first = false; } - os << UR_USM_MIGRATION_FLAG_DEFAULT; + os << UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST; } if (val != 0) { std::bitset<32> bits(val); diff --git a/scripts/core/enqueue.yml b/scripts/core/enqueue.yml index cc2597962d..a5b85adfb0 100644 --- a/scripts/core/enqueue.yml +++ b/scripts/core/enqueue.yml @@ -849,13 +849,16 @@ etors: value: "$X_BIT(2)" --- #-------------------------------------------------------------------------- type: enum -desc: "Map flags" -class: $xDevice +desc: "USM migration flags, indicating the direction data is migrated in" +class: $xEnqueue name: $x_usm_migration_flags_t etors: - - name: DEFAULT - desc: "Default migration TODO: Add more enums! " + - name: HOST_TO_DEVICE + desc: "Migrate data from host to device" value: "$X_BIT(0)" + - name: DEVICE_TO_HOST + desc: "Migrate data from device to host" + value: "$X_BIT(1)" --- #-------------------------------------------------------------------------- type: function desc: "Enqueue a command to map a region of the buffer object into the host address space and return a pointer to the mapped region" @@ -1144,7 +1147,7 @@ params: desc: "[in] size in bytes to be fetched" - type: $x_usm_migration_flags_t name: flags - desc: "[in] USM prefetch flags" + desc: "[in] USM migration flags" - type: uint32_t name: numEventsInWaitList desc: "[in] size of the event wait list" diff --git a/scripts/core/exp-command-buffer.yml b/scripts/core/exp-command-buffer.yml index 218c626423..4e43437f82 100644 --- a/scripts/core/exp-command-buffer.yml +++ b/scripts/core/exp-command-buffer.yml @@ -1024,7 +1024,7 @@ params: desc: "[in] size in bytes to be fetched." - type: $x_usm_migration_flags_t name: flags - desc: "[in] USM prefetch flags" + desc: "[in] USM migration flags" - type: uint32_t name: numSyncPointsInWaitList desc: "[in] The number of sync points in the provided dependency list." diff --git a/source/adapters/cuda/enqueue.cpp b/source/adapters/cuda/enqueue.cpp index f0334e6312..3c4580981d 100644 --- a/source/adapters/cuda/enqueue.cpp +++ b/source/adapters/cuda/enqueue.cpp @@ -1598,13 +1598,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, const void *pMem, size_t size, ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - std::ignore = flags; + ur_device_handle_t Device = hQueue->getDevice(); + CUdevice TargetDevice; + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + TargetDevice = Device->get(); + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + TargetDevice = CU_DEVICE_CPU; + break; + default: + setErrorMessage("Invalid USM migration flag", + UR_RESULT_ERROR_INVALID_ENUMERATION); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } size_t PointerRangeSize = 0; UR_CHECK_ERROR(cuPointerGetAttribute( &PointerRangeSize, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)pMem)); UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE); - ur_device_handle_t Device = hQueue->getDevice(); // Certain cuda devices and Windows do not have support for some Unified // Memory features. cuMemPrefetchAsync requires concurrent memory access @@ -1641,7 +1653,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( UR_CHECK_ERROR(EventPtr->start()); } UR_CHECK_ERROR( - cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream)); + cuMemPrefetchAsync((CUdeviceptr)pMem, size, TargetDevice, CuStream)); if (phEvent) { UR_CHECK_ERROR(EventPtr->record()); *phEvent = EventPtr.release(); diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index 849369de4b..6924bf523b 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -1365,10 +1365,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, const void *pMem, size_t size, ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - std::ignore = flags; void *HIPDevicePtr = const_cast(pMem); ur_device_handle_t Device = hQueue->getDevice(); + hipDevice_t TargetDevice; + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + TargetDevice = Device->get(); + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + // HIP doesn't have a constant for host like CUDA does; -1 is used instead + // https://github.com/ROCm/HIP/blob/3d60bd3a6415c2/docs/how-to/unified_memory.rst#L376 + TargetDevice = -1; + break; + default: + setErrorMessage("Invalid USM migration flag", + UR_RESULT_ERROR_INVALID_ENUMERATION); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5, // so we can't perform this check for such cases. @@ -1428,8 +1442,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( return UR_RESULT_ERROR_ADAPTER_SPECIFIC; } - UR_CHECK_ERROR( - hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream)); + UR_CHECK_ERROR(hipMemPrefetchAsync(pMem, size, TargetDevice, HIPStream)); releaseEvent(); } catch (ur_result_t Err) { return Err; diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index 879ee0f1cc..1b7980ac7c 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -1304,13 +1304,18 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp( std::ignore = EventWaitList; std::ignore = Event; std::ignore = Command; - std::ignore = Flags; if (CommandBuffer->IsInOrderCmdList) { - // Add the prefetch command to the command-buffer. - // Note that L0 does not handle migration flags. - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, - (CommandBuffer->ZeComputeCommandList, Mem, Size)); + if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + // Add the prefetch command to the command-buffer. + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, + (CommandBuffer->ZeComputeCommandList, Mem, Size)); + } else { + // L0 currently does not handle migration flags -- All other migration + // behavior is ignored: + logger::warning("USM migration from device to host is not currently " + "supported by level zero."); + } } else { std::vector ZeEventList; ze_event_handle_t ZeLaunchEvent = nullptr; @@ -1324,10 +1329,16 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp( ZeEventList.data())); } - // Add the prefetch command to the command-buffer. - // Note that L0 does not handle migration flags. - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, - (CommandBuffer->ZeComputeCommandList, Mem, Size)); + if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + // Add the prefetch command to the command buffer. + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, + (CommandBuffer->ZeComputeCommandList, Mem, Size)); + } else { + // L0 currently does not handle migration flags -- All other migration + // behavior is ignored: + logger::warning("USM migration from device to host is not currently " + "supported by level zero."); + } // Level Zero does not have a completion "event" with the prefetch API, // so manually add command to signal our event. diff --git a/source/adapters/level_zero/memory.cpp b/source/adapters/level_zero/memory.cpp index 4a5cb787dc..c5eac60c79 100644 --- a/source/adapters/level_zero/memory.cpp +++ b/source/adapters/level_zero/memory.cpp @@ -1290,7 +1290,6 @@ ur_result_t urEnqueueUSMPrefetch( /// [in,out][optional] return an event object that identifies this /// particular command instance. ur_event_handle_t *OutEvent) { - std::ignore = Flags; // Lock automatically releases when this goes out of scope. std::scoped_lock lock(Queue->Mutex); @@ -1330,8 +1329,14 @@ ur_result_t urEnqueueUSMPrefetch( ZE2UR_CALL(zeCommandListAppendWaitOnEvents, (ZeCommandList, WaitList.Length, WaitList.ZeEventList)); } - // TODO: figure out how to translate "flags" - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, (ZeCommandList, Mem, Size)); + + if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, (ZeCommandList, Mem, Size)); + } else { + // L0 does not suppot migrating from device to host yet: skip procedure + logger::warning("urEnqueueUSMPrefetch: Prefetch from device to host not yet" + " supported by level zero"); + } // TODO: Level Zero does not have a completion "event" with the prefetch API, // so manually add command to signal our event. diff --git a/source/adapters/level_zero/v2/queue_immediate_in_order.cpp b/source/adapters/level_zero/v2/queue_immediate_in_order.cpp index d33ac12f7e..956964737b 100644 --- a/source/adapters/level_zero/v2/queue_immediate_in_order.cpp +++ b/source/adapters/level_zero/v2/queue_immediate_in_order.cpp @@ -629,8 +629,6 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMPrefetch( ur_event_handle_t *phEvent) { TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::enqueueUSMPrefetch"); - std::ignore = flags; - std::scoped_lock lock(this->Mutex); auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_PREFETCH); @@ -643,12 +641,19 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMPrefetch( zeCommandListAppendWaitOnEvents, (commandListManager.getZeCommandList(), numWaitEvents, pWaitEvents)); } - // TODO: figure out how to translate "flags" - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, - (commandListManager.getZeCommandList(), pMem, size)); - if (zeSignalEvent) { - ZE2UR_CALL(zeCommandListAppendSignalEvent, - (commandListManager.getZeCommandList(), zeSignalEvent)); + + if (flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, + (commandListManager.getZeCommandList(), pMem, size)); + if (zeSignalEvent) { + ZE2UR_CALL(zeCommandListAppendSignalEvent, + (commandListManager.getZeCommandList(), zeSignalEvent)); + } + } else { + // L0 does not suppot migrating from device to host yet: skip procedure + setErrorMessage("Prefetch from device to host not yet supported by level " + "zero.", + UR_RESULT_SUCCESS); } return UR_RESULT_SUCCESS; diff --git a/source/adapters/mock/ur_mockddi.cpp b/source/adapters/mock/ur_mockddi.cpp index 6d5034d07b..296874d342 100644 --- a/source/adapters/mock/ur_mockddi.cpp +++ b/source/adapters/mock/ur_mockddi.cpp @@ -6720,7 +6720,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -9584,7 +9584,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/source/adapters/opencl/common.cpp b/source/adapters/opencl/common.cpp index 33da43a182..140e6f4aac 100644 --- a/source/adapters/opencl/common.cpp +++ b/source/adapters/opencl/common.cpp @@ -13,15 +13,15 @@ namespace cl_adapter { /* Global variables for urAdapterGetLastError() */ -thread_local int32_t ErrorMessageCode = 0; +thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS; thread_local char ErrorMessage[MaxMessageSize]{}; -[[maybe_unused]] void setErrorMessage(const char *Message, int32_t ErrorCode) { +[[maybe_unused]] void setErrorMessage(const char *Message, + ur_result_t ErrorCode) { assert(strlen(Message) < cl_adapter::MaxMessageSize); // Copy at most MaxMessageSize - 1 bytes to ensure the resultant string is // always null terminated. strncpy(cl_adapter::ErrorMessage, Message, MaxMessageSize - 1); - ErrorMessageCode = ErrorCode; } } // namespace cl_adapter diff --git a/source/adapters/opencl/common.hpp b/source/adapters/opencl/common.hpp index 6857220dc2..165f2d2878 100644 --- a/source/adapters/opencl/common.hpp +++ b/source/adapters/opencl/common.hpp @@ -150,7 +150,7 @@ inline const OpenCLVersion V3_0(3, 0); namespace cl_adapter { constexpr size_t MaxMessageSize = 256; -extern thread_local int32_t ErrorMessageCode; +extern thread_local ur_result_t ErrorMessageCode; extern thread_local char ErrorMessage[MaxMessageSize]; // Utility function for setting a message and warning @@ -203,6 +203,7 @@ CONSTFIX char EnqueueWriteGlobalVariableName[] = "clEnqueueWriteGlobalVariableINTEL"; CONSTFIX char EnqueueReadGlobalVariableName[] = "clEnqueueReadGlobalVariableINTEL"; +CONSTFIX char EnqueueMigrateMemName[] = "clEnqueueMigrateMemINTEL"; // Names of host pipe functions queried from OpenCL CONSTFIX char EnqueueReadHostPipeName[] = "clEnqueueReadHostPipeINTEL"; CONSTFIX char EnqueueWriteHostPipeName[] = "clEnqueueWriteHostPipeINTEL"; diff --git a/source/adapters/opencl/extension_functions.def b/source/adapters/opencl/extension_functions.def index 3f5e3ea917..e947407cbc 100644 --- a/source/adapters/opencl/extension_functions.def +++ b/source/adapters/opencl/extension_functions.def @@ -8,6 +8,7 @@ CL_EXTENSION_FUNC(clMemBlockingFreeINTEL) CL_EXTENSION_FUNC(clSetKernelArgMemPointerINTEL) CL_EXTENSION_FUNC(clEnqueueMemFillINTEL) CL_EXTENSION_FUNC(clEnqueueMemcpyINTEL) +CL_EXTENSION_FUNC(clEnqueueMigrateMemINTEL) CL_EXTENSION_FUNC(clGetMemAllocInfoINTEL) CL_EXTENSION_FUNC(clEnqueueWriteGlobalVariable) CL_EXTENSION_FUNC(clEnqueueReadGlobalVariable) diff --git a/source/adapters/opencl/usm.cpp b/source/adapters/opencl/usm.cpp index 7961cb76ff..ef59eb1dcf 100644 --- a/source/adapters/opencl/usm.cpp +++ b/source/adapters/opencl/usm.cpp @@ -488,43 +488,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, [[maybe_unused]] const void *pMem, - [[maybe_unused]] size_t size, - [[maybe_unused]] ur_usm_migration_flags_t flags, + [[maybe_unused]] size_t size, ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - return mapCLErrorToUR(clEnqueueMarkerWithWaitList( - cl_adapter::cast(hQueue), numEventsInWaitList, - cl_adapter::cast(phEventWaitList), - cl_adapter::cast(phEvent))); - - /* - // Use this once impls support it. // Have to look up the context from the kernel cl_context CLContext; - cl_int CLErr = - clGetCommandQueueInfo(cl_adapter::cast(hQueue), - CL_QUEUE_CONTEXT, sizeof(cl_context), - &CLContext, nullptr); + cl_int CLErr = clGetCommandQueueInfo( + cl_adapter::cast(hQueue), CL_QUEUE_CONTEXT, + sizeof(cl_context), &CLContext, nullptr); if (CLErr != CL_SUCCESS) { - return map_cl_error_to_ur(CLErr); + return mapCLErrorToUR(CLErr); } clEnqueueMigrateMemINTEL_fn FuncPtr; - ur_result_t Err = cl_ext::getExtFuncFromContext( - CLContext, "clEnqueueMigrateMemINTEL", &FuncPtr); + if (cl_ext::getExtFuncFromContext( + CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMigrateMemINTELCache, + cl_ext::EnqueueMigrateMemName, &FuncPtr)) { + // Exit gracefully if unable to find USM function + cl_adapter::setErrorMessage("Prefetch hint ignored as current OpenCL versio" + "n does not support clEnqueueMigrateMemINTEL", + UR_RESULT_SUCCESS); + return UR_RESULT_SUCCESS; + } - ur_result_t RetVal; - if (Err != UR_RESULT_SUCCESS) { - RetVal = Err; - } else { - RetVal = map_cl_error_to_ur( - FuncPtr(cl_adapter::cast(hQueue), pMem, size, flags, - numEventsInWaitList, - reinterpret_cast(phEventWaitList), - reinterpret_cast(phEvent))); + cl_mem_migration_flags MigrationFlag; + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + MigrationFlag = CL_MIGRATE_MEM_OBJECT_CONTENT_UNDEFINED; + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + MigrationFlag = CL_MIGRATE_MEM_OBJECT_HOST; + break; + default: + cl_adapter::setErrorMessage("Invalid USM migration flag", + UR_RESULT_ERROR_INVALID_ENUMERATION); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + cl_int Result = FuncPtr(cl_adapter::cast(hQueue), pMem, + size, MigrationFlag, numEventsInWaitList, + reinterpret_cast(phEventWaitList), + reinterpret_cast(phEvent)); + + switch (Result) { + case CL_INVALID_VALUE: + cl_adapter::setErrorMessage("Prefetch hint ignored as current OpenCL " + "version does not support prefetching " + "(clEnqueueMigrateMemINTEL) from current " + "device to host (CL_MIGRATE_MEM_OBJECT_HOST)", + UR_RESULT_SUCCESS); + return UR_RESULT_SUCCESS; + default: + return mapCLErrorToUR(Result); } - */ } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMAdvise( diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index f4a7b7e60a..e6bdbc33af 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -5531,7 +5531,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -7993,7 +7993,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index eb2bd4c353..4dc9546f48 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -5922,7 +5922,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -8631,7 +8631,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index ec6081509f..5f91219b19 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -5627,7 +5627,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -8152,7 +8152,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 5761dab3a4..7d40ef7231 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -6243,7 +6243,7 @@ ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -8585,7 +8585,7 @@ ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 7023161cb1..33c576247c 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -5484,7 +5484,7 @@ ur_result_t UR_APICALL urEnqueueUSMPrefetch( const void *pMem, /// [in] size in bytes to be fetched size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] size of the event wait list uint32_t numEventsInWaitList, @@ -7504,7 +7504,7 @@ ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp b/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp index e5168057d4..3b7012785a 100644 --- a/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp +++ b/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp @@ -19,7 +19,8 @@ struct urEnqueueUSMPrefetchWithParamTest UUR_DEVICE_TEST_SUITE_WITH_PARAM( urEnqueueUSMPrefetchWithParamTest, - ::testing::Values(UR_USM_MIGRATION_FLAG_DEFAULT), + ::testing::Values(UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, + UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST), uur::deviceTestWithParamPrinter); TEST_P(urEnqueueUSMPrefetchWithParamTest, Success) { @@ -106,14 +107,14 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE(urEnqueueUSMPrefetchTest); TEST_P(urEnqueueUSMPrefetchTest, InvalidNullHandleQueue) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE, urEnqueueUSMPrefetch(nullptr, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } TEST_P(urEnqueueUSMPrefetchTest, InvalidNullPointerMem) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER, urEnqueueUSMPrefetch(queue, nullptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } @@ -127,23 +128,22 @@ TEST_P(urEnqueueUSMPrefetchTest, InvalidEnumeration) { TEST_P(urEnqueueUSMPrefetchTest, InvalidSizeZero) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE, urEnqueueUSMPrefetch(queue, ptr, 0, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } TEST_P(urEnqueueUSMPrefetchTest, InvalidSizeTooLarge) { UUR_KNOWN_FAILURE_ON(uur::LevelZero{}, uur::LevelZeroV2{}, uur::NativeCPU{}); - ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE, urEnqueueUSMPrefetch(queue, ptr, allocation_size * 2, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } TEST_P(urEnqueueUSMPrefetchTest, InvalidEventWaitList) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST, urEnqueueUSMPrefetch(queue, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 1, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 1, nullptr, nullptr)); ur_event_handle_t validEvent; @@ -151,14 +151,20 @@ TEST_P(urEnqueueUSMPrefetchTest, InvalidEventWaitList) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST, urEnqueueUSMPrefetch(queue, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, &validEvent, nullptr)); ur_event_handle_t inv_evt = nullptr; ASSERT_EQ_RESULT(urEnqueueUSMPrefetch(queue, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 1, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 1, &inv_evt, nullptr), UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST); ASSERT_SUCCESS(urEventRelease(validEvent)); } + +TEST_P(urEnqueueUSMPrefetchTest, InvalidMigrationFlag) { + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_ENUMERATION, + urEnqueueUSMPrefetch(queue, ptr, allocation_size, 23, 0, + nullptr, nullptr)); +} \ No newline at end of file diff --git a/test/conformance/exp_command_buffer/commands.cpp b/test/conformance/exp_command_buffer/commands.cpp index e35ebb9d02..326b158984 100644 --- a/test/conformance/exp_command_buffer/commands.cpp +++ b/test/conformance/exp_command_buffer/commands.cpp @@ -124,8 +124,13 @@ TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferFillExp) { TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMPrefetchExp) { ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( - cmd_buf_handle, device_ptrs[0], allocation_size, 0, 0, nullptr, 0, - nullptr, nullptr, nullptr, nullptr)); + cmd_buf_handle, device_ptrs[0], allocation_size, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 0, nullptr, nullptr, + nullptr, nullptr)); + ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( + cmd_buf_handle, device_ptrs[0], allocation_size, + UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST, 0, nullptr, 0, nullptr, nullptr, + nullptr, nullptr)); } TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMAdviseExp) { diff --git a/test/conformance/exp_command_buffer/update/event_sync.cpp b/test/conformance/exp_command_buffer/update/event_sync.cpp index 0303da1f42..fafd6aee75 100644 --- a/test/conformance/exp_command_buffer/update/event_sync.cpp +++ b/test/conformance/exp_command_buffer/update/event_sync.cpp @@ -723,8 +723,8 @@ TEST_P(CommandEventSyncUpdateTest, USMPrefetchExp) { // Test prefetch command waiting on queue event ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( updatable_cmd_buf_handle, device_ptrs[1], allocation_size, - 0 /* migration flags*/, 0, nullptr, 1, &external_events[0], nullptr, - &external_events[1], &command_handles[0])); + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 1, &external_events[0], + nullptr, &external_events[1], &command_handles[0])); ASSERT_NE(nullptr, command_handles[0]); ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,