@@ -710,7 +710,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
710
710
ZeKernelDesc.pKernelName = KernelName;
711
711
712
712
ze_kernel_handle_t ZeKernel;
713
- ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
713
+ auto ZeResult =
714
+ ZE_CALL_NOCHECK (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
715
+ // Gracefully handle the case that kernel create fails.
716
+ if (ZeResult != ZE_RESULT_SUCCESS) {
717
+ delete *RetKernel;
718
+ *RetKernel = nullptr ;
719
+ return ze2urResult (ZeResult);
720
+ }
714
721
715
722
auto ZeDevice = It.first ;
716
723
@@ -764,20 +771,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
764
771
PArgValue = nullptr ;
765
772
}
766
773
774
+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
775
+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
776
+ }
777
+
767
778
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
779
+ ze_result_t ZeResult = ZE_RESULT_SUCCESS;
768
780
if (Kernel->ZeKernelMap .empty ()) {
769
781
auto ZeKernel = Kernel->ZeKernel ;
770
- ZE2UR_CALL (zeKernelSetArgumentValue,
771
- (ZeKernel, ArgIndex, ArgSize, PArgValue));
782
+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
783
+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
772
784
} else {
773
785
for (auto It : Kernel->ZeKernelMap ) {
774
786
auto ZeKernel = It.second ;
775
- ZE2UR_CALL (zeKernelSetArgumentValue,
776
- (ZeKernel, ArgIndex, ArgSize, PArgValue));
787
+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
788
+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
777
789
}
778
790
}
779
791
780
- return UR_RESULT_SUCCESS;
792
+ if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
793
+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
794
+ }
795
+
796
+ return ze2urResult (ZeResult);
781
797
}
782
798
783
799
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal (
@@ -826,6 +842,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(
826
842
} catch (...) {
827
843
return UR_RESULT_ERROR_UNKNOWN;
828
844
}
845
+ case UR_KERNEL_INFO_NUM_REGS:
829
846
case UR_KERNEL_INFO_NUM_ARGS:
830
847
return ReturnValue (uint32_t {Kernel->ZeKernelProperties ->numKernelArgs });
831
848
case UR_KERNEL_INFO_REFERENCE_COUNT:
@@ -1076,6 +1093,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler(
1076
1093
) {
1077
1094
std::ignore = Properties;
1078
1095
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
1096
+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1097
+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1098
+ }
1079
1099
ZE2UR_CALL (zeKernelSetArgumentValue, (Kernel->ZeKernel , ArgIndex,
1080
1100
sizeof (void *), &ArgValue->ZeSampler ));
1081
1101
@@ -1095,6 +1115,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj(
1095
1115
// The ArgValue may be a NULL pointer in which case a NULL value is used for
1096
1116
// the kernel argument declared as a pointer to global or constant memory.
1097
1117
1118
+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1119
+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1120
+ }
1121
+
1098
1122
ur_mem_handle_t_ *UrMem = ur_cast<ur_mem_handle_t_ *>(ArgValue);
1099
1123
1100
1124
ur_mem_handle_t_::access_mode_t UrAccessMode = ur_mem_handle_t_::read_write;
0 commit comments