@@ -209,7 +209,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
209
209
device_access_mode_t accessMode)
210
210
: ur_mem_handle_t_(hContext, size, accessMode),
211
211
deviceAllocations (hContext->getPlatform ()->getNumDevices()),
212
- activeAllocationDevice(nullptr ), hostAllocations() {
212
+ activeAllocationDevice(nullptr ), mapToPtr(hostPtr), hostAllocations() {
213
213
if (hostPtr) {
214
214
auto initialDevice = hContext->getDevices ()[0 ];
215
215
UR_CALL_THROWS (migrateBufferTo (initialDevice, hostPtr, size));
@@ -246,12 +246,18 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
246
246
if (!activeAllocationDevice || !writeBackPtr)
247
247
return ;
248
248
249
- auto srcPtr = ur_cast<char *>(
250
- deviceAllocations[activeAllocationDevice->Id .value ()].get ());
249
+ auto srcPtr = getActiveDeviceAlloc ();
251
250
synchronousZeCopy (hContext, activeAllocationDevice, writeBackPtr, srcPtr,
252
251
getSize ());
253
252
}
254
253
254
+ void *ur_discrete_mem_handle_t ::getActiveDeviceAlloc(size_t offset) {
255
+ assert (activeAllocationDevice);
256
+ return ur_cast<char *>(
257
+ deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
258
+ offset;
259
+ }
260
+
255
261
void *ur_discrete_mem_handle_t ::getDevicePtr(
256
262
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
257
263
size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
@@ -272,10 +278,8 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
272
278
hDevice = activeAllocationDevice;
273
279
}
274
280
275
- char *ptr;
276
281
if (activeAllocationDevice == hDevice) {
277
- ptr = ur_cast<char *>(deviceAllocations[hDevice->Id .value ()].get ());
278
- return ptr + offset;
282
+ return getActiveDeviceAlloc (offset);
279
283
}
280
284
281
285
auto &p2pDevices = hContext->getP2PDevices (hDevice);
@@ -288,9 +292,7 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
288
292
}
289
293
290
294
// TODO: see if it's better to migrate the memory to the specified device
291
- return ur_cast<char *>(
292
- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
293
- offset;
295
+ return getActiveDeviceAlloc (offset);
294
296
}
295
297
296
298
void *ur_discrete_mem_handle_t ::mapHostPtr(
@@ -299,55 +301,60 @@ void *ur_discrete_mem_handle_t::mapHostPtr(
299
301
TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::mapHostPtr" );
300
302
// TODO: use async alloc?
301
303
302
- void *ptr;
303
- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
304
- hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
304
// Reuse the host pointer the buffer was created with (mapToPtr) when one
// exists; otherwise allocate host USM from the context's default pool.
+ void *ptr = mapToPtr;
305
+ if (!ptr) {
306
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
307
+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
308
+ }
305
309
306
- hostAllocations.emplace_back (ptr, size, offset, flags);
310
// NOTE(review): ownsAlloc = bool(mapToPtr) looks inverted. The pool
// allocation above is made only when mapToPtr is null, yet the deleter
// frees `p` only when mapToPtr is non-null — which would free the
// externally provided pointer through the pool and leak the pool
// allocation. Verify this should be bool(!mapToPtr).
+ usm_unique_ptr_t mappedPtr =
311
+ usm_unique_ptr_t (ptr, [ownsAlloc = bool (mapToPtr), this ](void *p) {
312
+ if (ownsAlloc) {
313
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->free (p));
314
+ }
315
+ });
316
+
317
+ hostAllocations.emplace_back (std::move (mappedPtr), size, offset, flags);
307
318
308
319
// For readable mappings, copy the current device contents into the host
// allocation before handing it out.
if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
309
- auto srcPtr =
310
- ur_cast<char *>(
311
- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
312
- offset;
313
- migrate (srcPtr, hostAllocations.back ().ptr , size);
320
+ auto srcPtr = getActiveDeviceAlloc (offset);
321
+ migrate (srcPtr, hostAllocations.back ().ptr .get (), size);
314
322
}
315
323
316
- return hostAllocations.back ().ptr ;
324
+ return hostAllocations.back ().ptr . get () ;
317
325
}
318
326
319
327
void ur_discrete_mem_handle_t::unmapHostPtr (
320
328
void *pMappedPtr,
321
329
std::function<void (void *src, void *dst, size_t )> migrate) {
322
330
TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::unmapHostPtr" );
323
331
324
- for (auto &hostAllocation : hostAllocations) {
325
- if (hostAllocation.ptr == pMappedPtr) {
326
- void *devicePtr = nullptr ;
327
- if (activeAllocationDevice) {
328
- devicePtr =
329
- ur_cast<char *>(
330
- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
331
- hostAllocation.offset ;
332
- } else if (!(hostAllocation.flags &
333
- UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
334
- devicePtr = ur_cast<char *>(getDevicePtr (
335
- hContext->getDevices ()[0 ], device_access_mode_t ::read_only,
336
- hostAllocation.offset , hostAllocation.size , migrate));
337
- }
332
+ auto hostAlloc =
333
+ std::find_if (hostAllocations.begin (), hostAllocations.end (),
334
+ [pMappedPtr](const host_allocation_desc_t &desc) {
335
+ return desc.ptr .get () == pMappedPtr;
336
+ });
338
337
339
- if (devicePtr ) {
340
- migrate (hostAllocation. ptr , devicePtr, hostAllocation. size ) ;
341
- }
338
+ if (hostAlloc == hostAllocations. end () ) {
339
+ throw UR_RESULT_ERROR_INVALID_ARGUMENT ;
340
+ }
342
341
343
- // TODO: use async free here?
344
- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->free (hostAllocation.ptr ));
345
- return ;
346
- }
342
+ bool shouldMigrateToDevice =
343
+ !(hostAlloc->flags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION);
344
+
345
+ if (!activeAllocationDevice && shouldMigrateToDevice) {
346
+ allocateOnDevice (hContext->getDevices ()[0 ], getSize ());
347
+ }
348
+
349
+ // TODO: tests require that memory is migrated even for
350
+ // UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
351
+ // allocation. is this correct?
352
+ if (activeAllocationDevice) {
353
+ migrate (hostAlloc->ptr .get (), getActiveDeviceAlloc (hostAlloc->offset ),
354
+ hostAlloc->size );
347
355
}
348
356
349
- // No mapping found
350
- throw UR_RESULT_ERROR_INVALID_ARGUMENT;
357
+ hostAllocations.erase (hostAlloc);
351
358
}
352
359
353
360
static bool useHostBuffer (ur_context_handle_t hContext) {
0 commit comments