Skip to content

Commit fb3cbd1

Browse files
authored
Merge pull request #1569 from lbushi25/fix_usm_allocation
[L0] Fix usm allocation functions when alignment is not a power of 2
2 parents 6d676f2 + 2d02c21 commit fb3cbd1

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

source/adapters/level_zero/usm.cpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
306306
uint32_t Align = USMDesc ? USMDesc->align : 0;
307307
// L0 supports alignment up to 64KB and silently ignores higher values.
308308
// We flag alignment > 64KB as an invalid value.
309-
if (Align > 65536)
309+
// L0 spec says that alignment values that are not powers of 2 are invalid.
310+
if (Align > 65536 || (Align & (Align - 1)) != 0)
310311
return UR_RESULT_ERROR_INVALID_VALUE;
311312

312313
ur_platform_handle_t Plt = Context->getPlatform();
@@ -335,11 +336,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
335336
// find the allocator depending on context as we do for Shared and Device
336337
// allocations.
337338
umf_memory_pool_handle_t hPoolInternal = nullptr;
338-
if (!UseUSMAllocator ||
339-
// L0 spec says that allocation fails if Alignment != 2^n, in order to
340-
// keep the same behavior for the allocator, just call L0 API directly and
341-
// return the error code.
342-
((Align & (Align - 1)) != 0)) {
339+
if (!UseUSMAllocator) {
343340
hPoolInternal = Context->HostMemProxyPool.get();
344341
} else if (Pool) {
345342
hPoolInternal = Pool->HostMemPool.get();
@@ -379,7 +376,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
379376

380377
// L0 supports alignment up to 64KB and silently ignores higher values.
381378
// We flag alignment > 64KB as an invalid value.
382-
if (Alignment > 65536)
379+
// L0 spec says that alignment values that are not powers of 2 are invalid.
380+
if (Alignment > 65536 || (Alignment & (Alignment - 1)) != 0)
383381
return UR_RESULT_ERROR_INVALID_VALUE;
384382

385383
ur_platform_handle_t Plt = Device->Platform;
@@ -406,11 +404,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
406404
}
407405

408406
umf_memory_pool_handle_t hPoolInternal = nullptr;
409-
if (!UseUSMAllocator ||
410-
// L0 spec says that allocation fails if Alignment != 2^n, in order to
411-
// keep the same behavior for the allocator, just call L0 API directly and
412-
// return the error code.
413-
((Alignment & (Alignment - 1)) != 0)) {
407+
if (!UseUSMAllocator) {
414408
auto It = Context->DeviceMemProxyPools.find(Device->ZeDevice);
415409
if (It == Context->DeviceMemProxyPools.end())
416410
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -483,7 +477,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
483477

484478
// L0 supports alignment up to 64KB and silently ignores higher values.
485479
// We flag alignment > 64KB as an invalid value.
486-
if (Alignment > 65536)
480+
// L0 spec says that alignment values that are not powers of 2 are invalid.
481+
if (Alignment > 65536 || (Alignment & (Alignment - 1)) != 0)
487482
return UR_RESULT_ERROR_INVALID_VALUE;
488483

489484
ur_platform_handle_t Plt = Device->Platform;
@@ -506,11 +501,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
506501
}
507502

508503
umf_memory_pool_handle_t hPoolInternal = nullptr;
509-
if (!UseUSMAllocator ||
510-
// L0 spec says that allocation fails if Alignment != 2^n, in order to
511-
// keep the same behavior for the allocator, just call L0 API directly and
512-
// return the error code.
513-
((Alignment & (Alignment - 1)) != 0)) {
504+
if (!UseUSMAllocator) {
514505
auto &Allocator = (DeviceReadOnly ? Context->SharedReadOnlyMemProxyPools
515506
: Context->SharedMemProxyPools);
516507
auto It = Allocator.find(Device->ZeDevice);

0 commit comments

Comments
 (0)