diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.cpp b/unified-runtime/source/adapters/level_zero/v2/memory.cpp index b1f3829dd6967..8f70a04dad210 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.cpp @@ -55,14 +55,11 @@ void ur_usm_handle_t::unmapHostPtr(void * /*pMappedPtr*/, ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t( ur_context_handle_t hContext, void *hostPtr, size_t size, - host_ptr_action_t hostPtrAction, device_access_mode_t accessMode) + device_access_mode_t accessMode) : ur_mem_buffer_t(hContext, size, accessMode) { - bool hostPtrImported = false; - if (hostPtrAction == host_ptr_action_t::import) { - hostPtrImported = - maybeImportUSM(hContext->getPlatform()->ZeDriverHandleExpTranslated, - hContext->getZeHandle(), hostPtr, size); - } + bool hostPtrImported = + maybeImportUSM(hContext->getPlatform()->ZeDriverHandleExpTranslated, + hContext->getZeHandle(), hostPtr, size); if (hostPtrImported) { this->ptr = usm_unique_ptr_t(hostPtr, [hContext](void *ptr) { @@ -201,8 +198,23 @@ ur_discrete_buffer_handle_t::ur_discrete_buffer_handle_t( device_access_mode_t accessMode) : ur_mem_buffer_t(hContext, size, accessMode), deviceAllocations(hContext->getPlatform()->getNumDevices()), - activeAllocationDevice(nullptr), mapToPtr(hostPtr), hostAllocations() { + activeAllocationDevice(nullptr), mapToPtr(nullptr, nullptr), + hostAllocations() { if (hostPtr) { + // Try importing the pointer to speed up memory copies for map/unmap + bool hostPtrImported = + maybeImportUSM(hContext->getPlatform()->ZeDriverHandleExpTranslated, + hContext->getZeHandle(), hostPtr, size); + + if (hostPtrImported) { + mapToPtr = usm_unique_ptr_t(hostPtr, [hContext](void *ptr) { + ZeUSMImport.doZeUSMRelease( + hContext->getPlatform()->ZeDriverHandleExpTranslated, ptr); + }); + } else { + mapToPtr = usm_unique_ptr_t(hostPtr, [](void *) {}); + } + auto initialDevice = hContext->getDevices()[0]; UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size)); } @@ -305,18 +317,18 @@ void *ur_discrete_buffer_handle_t::mapHostPtr(ur_map_flags_t flags, TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::mapHostPtr"); // TODO: use async alloc? - void *ptr = mapToPtr; + void *ptr = mapToPtr.get(); if (!ptr) { UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate( hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr)); } usm_unique_ptr_t mappedPtr = - usm_unique_ptr_t(ptr, [ownsAlloc = bool(mapToPtr), this](void *p) { + usm_unique_ptr_t(ptr, [ownsAlloc = !bool(mapToPtr), this](void *p) { if (ownsAlloc) { auto ret = hContext->getDefaultUSMPool()->free(p); if (ret != UR_RESULT_SUCCESS) { - UR_LOG(ERR, "Failed to mapped memory: {}", ret); + UR_LOG(ERR, "Failed to free mapped memory: {}", ret); } } }); @@ -541,16 +553,16 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext, // ignore the flag for now. } + if (flags & UR_MEM_FLAG_USE_HOST_POINTER) { + // To speed up copies, we always import the host ptr to USM memory + } + void *hostPtr = pProperties ? pProperties->pHost : nullptr; auto accessMode = ur_mem_buffer_t::getDeviceAccessMode(flags); if (useHostBuffer(hContext)) { - auto hostPtrAction = - flags & UR_MEM_FLAG_USE_HOST_POINTER - ? ur_integrated_buffer_handle_t::host_ptr_action_t::import - : ur_integrated_buffer_handle_t::host_ptr_action_t::copy; *phBuffer = ur_mem_handle_t_::create( - hContext, hostPtr, size, hostPtrAction, accessMode); + hContext, hostPtr, size, accessMode); } else { *phBuffer = ur_mem_handle_t_::create( hContext, hostPtr, size, accessMode); diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.hpp b/unified-runtime/source/adapters/level_zero/v2/memory.hpp index 9c0dc66ef72b4..ea9d75a2acbf7 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.hpp @@ -27,7 +27,7 @@ struct ur_mem_buffer_t : ur_object { enum class device_access_mode_t { read_write, read_only, write_only }; ur_mem_buffer_t(ur_context_handle_t hContext, size_t size, - device_access_mode_t accesMode); + device_access_mode_t accessMode); virtual ~ur_mem_buffer_t() = default; virtual ur_shared_mutex &getMutex(); @@ -89,14 +89,11 @@ struct ur_usm_handle_t : ur_mem_buffer_t { // For integrated devices the buffer has been allocated in host memory // and can be accessed by the device without copying. struct ur_integrated_buffer_handle_t : ur_mem_buffer_t { - enum class host_ptr_action_t { import, copy }; - ur_integrated_buffer_handle_t(ur_context_handle_t hContext, void *hostPtr, - size_t size, host_ptr_action_t useHostPtr, - device_access_mode_t accesMode); + size_t size, device_access_mode_t accessMode); ur_integrated_buffer_handle_t(ur_context_handle_t hContext, void *hostPtr, - size_t size, device_access_mode_t accesMode, + size_t size, device_access_mode_t accessMode, bool ownHostPtr); ~ur_integrated_buffer_handle_t(); @@ -133,13 +130,13 @@ struct ur_discrete_buffer_handle_t : ur_mem_buffer_t { // first device in the context. Otherwise, the buffer is allocated on // firt getDevicePtr call. ur_discrete_buffer_handle_t(ur_context_handle_t hContext, void *hostPtr, - size_t size, device_access_mode_t accesMode); + size_t size, device_access_mode_t accessMode); ~ur_discrete_buffer_handle_t(); // Create buffer on top of existing device memory. ur_discrete_buffer_handle_t(ur_context_handle_t hContext, ur_device_handle_t hDevice, void *devicePtr, - size_t size, device_access_mode_t accesMode, + size_t size, device_access_mode_t accessMode, void *writeBackMemory, bool ownDevicePtr); void *getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset, @@ -165,7 +162,7 @@ struct ur_discrete_buffer_handle_t : ur_mem_buffer_t { void *writeBackPtr = nullptr; // If not null, mapHostPtr should map memory to this ptr - void *mapToPtr = nullptr; + usm_unique_ptr_t mapToPtr; std::vector hostAllocations; @@ -177,7 +174,7 @@ struct ur_discrete_buffer_handle_t : ur_mem_buffer_t { struct ur_shared_buffer_handle_t : ur_mem_buffer_t { ur_shared_buffer_handle_t(ur_context_handle_t hContext, void *devicePtr, - size_t size, device_access_mode_t accesMode, + size_t size, device_access_mode_t accessMode, bool ownDevicePtr); void *getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset, @@ -195,7 +192,7 @@ struct ur_shared_buffer_handle_t : ur_mem_buffer_t { struct ur_mem_sub_buffer_t : ur_mem_buffer_t { ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset, size_t size, - device_access_mode_t accesMode); + device_access_mode_t accessMode); ~ur_mem_sub_buffer_t(); void *getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset,