@@ -511,8 +511,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
511
511
512
512
UR_ASSERT (ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
513
513
UR_ASSERT (pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
514
- UR_ASSERT (size % patternSize == 0 || patternSize > size,
515
- UR_RESULT_ERROR_INVALID_SIZE);
514
+ UR_ASSERT (patternSize != 0 , UR_RESULT_ERROR_INVALID_SIZE)
515
+ UR_ASSERT (size != 0 , UR_RESULT_ERROR_INVALID_SIZE)
516
+ UR_ASSERT (patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
517
+ UR_ASSERT (size % patternSize == 0 , UR_RESULT_ERROR_INVALID_SIZE)
518
+ // TODO: add check for allocation size once the query is supported
516
519
517
520
switch (patternSize) {
518
521
case 1 :
@@ -522,33 +525,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
522
525
const auto pattern = *static_cast <const uint16_t *>(pPattern);
523
526
auto *start = reinterpret_cast <uint16_t *>(ptr);
524
527
auto *end =
525
- reinterpret_cast <uint16_t *>(reinterpret_cast <uint16_t *>(ptr) + size);
528
+ reinterpret_cast <uint16_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
526
529
std::fill (start, end, pattern);
527
530
break ;
528
531
}
529
532
case 4 : {
530
533
const auto pattern = *static_cast <const uint32_t *>(pPattern);
531
534
auto *start = reinterpret_cast <uint32_t *>(ptr);
532
535
auto *end =
533
- reinterpret_cast <uint32_t *>(reinterpret_cast <uint32_t *>(ptr) + size);
536
+ reinterpret_cast <uint32_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
534
537
std::fill (start, end, pattern);
535
538
break ;
536
539
}
537
540
case 8 : {
538
541
const auto pattern = *static_cast <const uint64_t *>(pPattern);
539
542
auto *start = reinterpret_cast <uint64_t *>(ptr);
540
543
auto *end =
541
- reinterpret_cast <uint64_t *>(reinterpret_cast <uint64_t *>(ptr) + size);
544
+ reinterpret_cast <uint64_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
542
545
std::fill (start, end, pattern);
543
546
break ;
544
547
}
545
- default :
546
- for (unsigned int step{0 }; step < size; ++ step) {
547
- auto *dest = reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(ptr) +
548
- step * patternSize );
548
+ default : {
549
+ for (unsigned int step{0 }; step < size; step += patternSize ) {
550
+ auto *dest =
551
+ reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(ptr) + step);
549
552
memcpy (dest, pPattern, patternSize);
550
553
}
551
554
}
555
+ }
552
556
return UR_RESULT_SUCCESS;
553
557
}
554
558
@@ -583,7 +587,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
583
587
std::ignore = phEventWaitList;
584
588
std::ignore = phEvent;
585
589
586
- DIE_NO_IMPLEMENTATION;
590
+ // TODO: properly implement USM prefetch
591
+ return UR_RESULT_SUCCESS;
587
592
}
588
593
589
594
UR_APIEXPORT ur_result_t UR_APICALL
@@ -595,7 +600,8 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
595
600
std::ignore = advice;
596
601
std::ignore = phEvent;
597
602
598
- DIE_NO_IMPLEMENTATION;
603
+ // TODO: properly implement USM advise
604
+ return UR_RESULT_SUCCESS;
599
605
}
600
606
601
607
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D (
0 commit comments