Skip to content

Commit 83f7ad9

Browse files
Merge pull request #1860 from PietroGhg/pietro/fill
[NATIVECPU] Fix pointer arithmetic in USMfill
2 parents ab9baf5 + 8fb6824 commit 83f7ad9

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

source/adapters/native_cpu/enqueue.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
511511

512512
UR_ASSERT(ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
513513
UR_ASSERT(pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
514-
UR_ASSERT(size % patternSize == 0 || patternSize > size,
515-
UR_RESULT_ERROR_INVALID_SIZE);
514+
UR_ASSERT(patternSize != 0, UR_RESULT_ERROR_INVALID_SIZE)
515+
UR_ASSERT(size != 0, UR_RESULT_ERROR_INVALID_SIZE)
516+
UR_ASSERT(patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
517+
UR_ASSERT(size % patternSize == 0, UR_RESULT_ERROR_INVALID_SIZE)
518+
// TODO: add check for allocation size once the query is supported
516519

517520
switch (patternSize) {
518521
case 1:
@@ -522,33 +525,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
522525
const auto pattern = *static_cast<const uint16_t *>(pPattern);
523526
auto *start = reinterpret_cast<uint16_t *>(ptr);
524527
auto *end =
525-
reinterpret_cast<uint16_t *>(reinterpret_cast<uint16_t *>(ptr) + size);
528+
reinterpret_cast<uint16_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
526529
std::fill(start, end, pattern);
527530
break;
528531
}
529532
case 4: {
530533
const auto pattern = *static_cast<const uint32_t *>(pPattern);
531534
auto *start = reinterpret_cast<uint32_t *>(ptr);
532535
auto *end =
533-
reinterpret_cast<uint32_t *>(reinterpret_cast<uint32_t *>(ptr) + size);
536+
reinterpret_cast<uint32_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
534537
std::fill(start, end, pattern);
535538
break;
536539
}
537540
case 8: {
538541
const auto pattern = *static_cast<const uint64_t *>(pPattern);
539542
auto *start = reinterpret_cast<uint64_t *>(ptr);
540543
auto *end =
541-
reinterpret_cast<uint64_t *>(reinterpret_cast<uint64_t *>(ptr) + size);
544+
reinterpret_cast<uint64_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
542545
std::fill(start, end, pattern);
543546
break;
544547
}
545-
default:
546-
for (unsigned int step{0}; step < size; ++step) {
547-
auto *dest = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(ptr) +
548-
step * patternSize);
548+
default: {
549+
for (unsigned int step{0}; step < size; step += patternSize) {
550+
auto *dest =
551+
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(ptr) + step);
549552
memcpy(dest, pPattern, patternSize);
550553
}
551554
}
555+
}
552556
return UR_RESULT_SUCCESS;
553557
}
554558

@@ -583,7 +587,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
583587
std::ignore = phEventWaitList;
584588
std::ignore = phEvent;
585589

586-
DIE_NO_IMPLEMENTATION;
590+
// TODO: properly implement USM prefetch
591+
return UR_RESULT_SUCCESS;
587592
}
588593

589594
UR_APIEXPORT ur_result_t UR_APICALL
@@ -595,7 +600,8 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
595600
std::ignore = advice;
596601
std::ignore = phEvent;
597602

598-
DIE_NO_IMPLEMENTATION;
603+
// TODO: properly implement USM advise
604+
return UR_RESULT_SUCCESS;
599605
}
600606

601607
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(

0 commit comments

Comments
 (0)