@@ -28,10 +28,8 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
28
28
}
29
29
30
30
if (!hostPtrImported) {
31
- // TODO: use UMF
32
- ZeStruct<ze_host_mem_alloc_desc_t > hostDesc;
33
- ZE2UR_CALL_THROWS (zeMemAllocHost, (hContext->getZeHandle (), &hostDesc, size,
34
- 0 , &this ->ptr ));
31
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
32
+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &this ->ptr ));
35
33
36
34
if (hostPtr) {
37
35
std::memcpy (this ->ptr , hostPtr, size);
@@ -40,9 +38,11 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
40
38
}
41
39
42
40
ur_host_mem_handle_t ::~ur_host_mem_handle_t () {
43
- // TODO: use UMF API here
44
41
if (ptr) {
45
- ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
42
+ auto ret = hContext->getDefaultUSMPool ()->free (ptr);
43
+ if (ret != UR_RESULT_SUCCESS) {
44
+ logger::error (" Failed to free host memory: {}" , ret);
45
+ }
46
46
}
47
47
}
48
48
@@ -51,55 +51,80 @@ void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
51
51
return ptr;
52
52
}
53
53
54
+ ur_result_t ur_device_mem_handle_t::migrateBufferTo (ur_device_handle_t hDevice,
55
+ void *src, size_t size) {
56
+ auto Id = hDevice->Id .value ();
57
+
58
+ if (!deviceAllocations[Id]) {
59
+ UR_CALL (hContext->getDefaultUSMPool ()->allocate (hContext, hDevice, nullptr ,
60
+ UR_USM_TYPE_DEVICE, size,
61
+ &deviceAllocations[Id]));
62
+ }
63
+
64
+ auto commandList = hContext->commandListCache .getImmediateCommandList (
65
+ hDevice->ZeDevice , true ,
66
+ hDevice
67
+ ->QueueGroup [ur_device_handle_t_::queue_group_info_t ::type::Compute]
68
+ .ZeOrdinal ,
69
+ ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
70
+ std::nullopt);
71
+
72
+ ZE2UR_CALL (zeCommandListAppendMemoryCopy,
73
+ (commandList.get (), deviceAllocations[Id], src, size, nullptr , 0 ,
74
+ nullptr ));
75
+
76
+ activeAllocationDevice = hDevice;
77
+
78
+ return UR_RESULT_SUCCESS;
79
+ }
80
+
54
81
ur_device_mem_handle_t ::ur_device_mem_handle_t (ur_context_handle_t hContext,
55
82
void *hostPtr, size_t size)
56
83
: ur_mem_handle_t_(hContext, size),
57
- deviceAllocations (hContext->getPlatform ()->getNumDevices()) {
58
- // Legacy adapter allocated the memory directly on a device (first on the
59
- // contxt) and if the buffer is used on another device, memory is migrated
60
- // (depending on an env var setting).
61
- //
62
- // TODO: port this behavior or figure out if it makes sense to keep the memory
63
- // in a host buffer (e.g. for smaller sizes).
84
+ deviceAllocations (hContext->getPlatform ()->getNumDevices()),
85
+ activeAllocationDevice(nullptr ) {
64
86
if (hostPtr) {
65
- buffer. assign ( reinterpret_cast < char *>(hostPtr),
66
- reinterpret_cast < char *>( hostPtr) + size);
87
+ auto initialDevice = hContext-> getDevices ()[ 0 ];
88
+ UR_CALL_THROWS ( migrateBufferTo (initialDevice, hostPtr, size) );
67
89
}
68
90
}
69
91
70
92
ur_device_mem_handle_t ::~ur_device_mem_handle_t () {
71
- // TODO: use UMF API here
72
93
for (auto &ptr : deviceAllocations) {
73
94
if (ptr) {
74
- ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
95
+ auto ret = hContext->getDefaultUSMPool ()->free (ptr);
96
+ if (ret != UR_RESULT_SUCCESS) {
97
+ logger::error (" Failed to free device memory: {}" , ret);
98
+ }
75
99
}
76
100
}
77
101
}
78
102
79
103
void *ur_device_mem_handle_t ::getPtr(ur_device_handle_t hDevice) {
80
104
std::lock_guard lock (this ->Mutex );
81
105
82
- auto &ptr = deviceAllocations[hDevice->Id .value ()];
83
- if (!ptr) {
84
- ZeStruct<ze_device_mem_alloc_desc_t > deviceDesc;
85
- ZE2UR_CALL_THROWS (zeMemAllocDevice, (hContext->getZeHandle (), &deviceDesc,
86
- size, 0 , hDevice->ZeDevice , &ptr));
87
-
88
- if (!buffer.empty ()) {
89
- auto commandList = hContext->commandListCache .getImmediateCommandList (
90
- hDevice->ZeDevice , true ,
91
- hDevice
92
- ->QueueGroup
93
- [ur_device_handle_t_::queue_group_info_t ::type::Compute]
94
- .ZeOrdinal ,
95
- ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
96
- std::nullopt);
97
- ZE2UR_CALL_THROWS (
98
- zeCommandListAppendMemoryCopy,
99
- (commandList.get (), ptr, buffer.data (), size, nullptr , 0 , nullptr ));
100
- }
106
+ if (!activeAllocationDevice) {
107
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
108
+ hContext, hDevice, nullptr , UR_USM_TYPE_DEVICE, getSize (),
109
+ &deviceAllocations[hDevice->Id .value ()]));
110
+ activeAllocationDevice = hDevice;
101
111
}
102
- return ptr;
112
+
113
+ if (activeAllocationDevice == hDevice) {
114
+ return deviceAllocations[hDevice->Id .value ()];
115
+ }
116
+
117
+ auto &p2pDevices = hContext->getP2PDevices (hDevice);
118
+ auto p2pAccessible = std::find (p2pDevices.begin (), p2pDevices.end (),
119
+ activeAllocationDevice) != p2pDevices.end ();
120
+
121
+ if (!p2pAccessible) {
122
+ // TODO: migrate buffer through the host
123
+ throw UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
124
+ }
125
+
126
+ // TODO: see if it's better to migrate the memory to the specified device
127
+ return deviceAllocations[activeAllocationDevice->Id .value ()];
103
128
}
104
129
105
130
namespace ur ::level_zero {
0 commit comments