@@ -58,6 +58,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
58
58
*Program // /< [out] pointer to handle of program object created.
59
59
) {
60
60
std::ignore = Properties;
61
+ UR_ASSERT (Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
62
+ UR_ASSERT (IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
61
63
try {
62
64
ur_program_handle_t_ *UrProgram =
63
65
new ur_program_handle_t_ (ur_program_handle_t_::IL, Context, IL, Length);
@@ -82,8 +84,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
82
84
ur_program_handle_t
83
85
*Program // /< [out] pointer to handle of Program object created.
84
86
) {
85
- std::ignore = Device;
86
- std::ignore = Properties;
87
87
// In OpenCL, clCreateProgramWithBinary() can be used to load any of the
88
88
// following: "program executable", "compiled program", or "library of
89
89
// compiled programs". In addition, the loaded program can be either
@@ -96,8 +96,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
96
96
// information to distinguish the cases.
97
97
98
98
try {
99
- ur_program_handle_t_ *UrProgram = new ur_program_handle_t_ (
100
- ur_program_handle_t_::Native, Context, Binary, Size);
99
+ ur_program_handle_t_ *UrProgram =
100
+ new ur_program_handle_t_ (ur_program_handle_t_::Native, Context, Device,
101
+ Properties, Binary, Size);
101
102
*Program = reinterpret_cast <ur_program_handle_t >(UrProgram);
102
103
} catch (const std::bad_alloc &) {
103
104
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -597,11 +598,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
597
598
void **GlobalVariablePointerRet // /< [out] Returns the pointer to the global
598
599
// /< variable if it is found in the program.
599
600
) {
600
- std::ignore = Device;
601
601
std::scoped_lock<ur_shared_mutex> lock (Program->Mutex );
602
602
603
+ ze_module_handle_t ZeModuleEntry{};
604
+ ZeModuleEntry = Program->ZeModule ;
605
+ if (!Program->ZeModuleMap .empty ()) {
606
+ auto It = Program->ZeModuleMap .find (Device->ZeDevice );
607
+ if (It != Program->ZeModuleMap .end ()) {
608
+ ZeModuleEntry = It->second ;
609
+ }
610
+ }
611
+
603
612
ze_result_t ZeResult =
604
- zeModuleGetGlobalPointer (Program-> ZeModule , GlobalVariableName,
613
+ zeModuleGetGlobalPointer (ZeModuleEntry , GlobalVariableName,
605
614
GlobalVariableSizeRet, GlobalVariablePointerRet);
606
615
607
616
if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) {
@@ -632,11 +641,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
632
641
case UR_PROGRAM_INFO_CONTEXT:
633
642
return ReturnValue (Program->Context );
634
643
case UR_PROGRAM_INFO_NUM_DEVICES:
635
- // TODO: return true number of devices this program exists for.
636
- return ReturnValue (uint32_t {1 });
644
+ if (!Program->ZeModuleMap .empty ())
645
+ return ReturnValue (
646
+ uint32_t {ur_cast<uint32_t >(Program->ZeModuleMap .size ())});
647
+ else
648
+ return ReturnValue (uint32_t {1 });
637
649
case UR_PROGRAM_INFO_DEVICES:
638
- // TODO: return all devices this program exists for.
639
- return ReturnValue (Program->Context ->Devices [0 ]);
650
+ if (!Program->ZeModuleMap .empty ()) {
651
+ std::vector<ur_device_handle_t > devices;
652
+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
653
+ auto It = Program->ZeModuleMap .find (ZeModulePair.first );
654
+ if (It != Program->ZeModuleMap .end ()) {
655
+ for (auto &Device : Program->Context ->Devices ) {
656
+ if (Device->ZeDevice == ZeModulePair.first ) {
657
+ devices.push_back (Device);
658
+ }
659
+ }
660
+ }
661
+ }
662
+ return ReturnValue (devices.data (), devices.size ());
663
+ } else {
664
+ return ReturnValue (Program->Context ->Devices [0 ]);
665
+ }
640
666
case UR_PROGRAM_INFO_BINARY_SIZES: {
641
667
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
642
668
size_t SzBinary;
@@ -645,8 +671,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
645
671
Program->State == ur_program_handle_t_::Object) {
646
672
SzBinary = Program->CodeLength ;
647
673
} else if (Program->State == ur_program_handle_t_::Exe) {
648
- ZE2UR_CALL (zeModuleGetNativeBinary,
649
- (Program->ZeModule , &SzBinary, nullptr ));
674
+ if (!Program->ZeModuleMap .empty ()) {
675
+ std::vector<size_t > binarySizes;
676
+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
677
+ size_t binarySize = 0 ;
678
+ ZE2UR_CALL (zeModuleGetNativeBinary,
679
+ (ZeModulePair.second , &binarySize, nullptr ));
680
+ binarySizes.push_back (binarySize);
681
+ }
682
+ return ReturnValue (binarySizes.data (), binarySizes.size ());
683
+ } else {
684
+ ZE2UR_CALL (zeModuleGetNativeBinary,
685
+ (Program->ZeModule , &SzBinary, nullptr ));
686
+ return ReturnValue (SzBinary);
687
+ }
650
688
} else {
651
689
return UR_RESULT_ERROR_INVALID_PROGRAM;
652
690
}
@@ -655,38 +693,71 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
655
693
}
656
694
case UR_PROGRAM_INFO_BINARIES: {
657
695
// The caller sets "ParamValue" to an array of pointers, one for each
658
- // device. Since Level Zero supports only one device, there is only one
659
- // pointer. If the pointer is NULL, we don't do anything. Otherwise, we
660
- // copy the program's binary image to the buffer at that pointer.
661
- uint8_t **PBinary = ur_cast<uint8_t **>(ProgramInfo);
662
- if (!PBinary[0 ])
663
- break ;
664
-
696
+ // device.
697
+ uint8_t **PBinary = nullptr ;
698
+ if (ProgramInfo) {
699
+ PBinary = ur_cast<uint8_t **>(ProgramInfo);
700
+ if (!PBinary[0 ]) {
701
+ break ;
702
+ }
703
+ }
665
704
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
705
+ // If the caller is using a Program which is IL, Native or an object, then
706
+ // the program has not been built for multiple devices so a single IL is
707
+ // returned.
666
708
if (Program->State == ur_program_handle_t_::IL ||
667
709
Program->State == ur_program_handle_t_::Native ||
668
710
Program->State == ur_program_handle_t_::Object) {
669
- std::memcpy (PBinary[0 ], Program->Code .get (), Program->CodeLength );
711
+ if (PropSizeRet)
712
+ *PropSizeRet = Program->CodeLength ;
713
+ if (PBinary) {
714
+ std::memcpy (PBinary[0 ], Program->Code .get (), Program->CodeLength );
715
+ }
670
716
} else if (Program->State == ur_program_handle_t_::Exe) {
717
+ // If the caller is using a Program which is a built binary, then
718
+ // the program returned will either be a single module if this is a native
719
+ // binary or the native binary for each device will be returned.
671
720
size_t SzBinary = 0 ;
672
- ZE2UR_CALL (zeModuleGetNativeBinary,
673
- (Program->ZeModule , &SzBinary, PBinary[0 ]));
721
+ uint8_t *NativeBinaryPtr = nullptr ;
722
+ if (PBinary) {
723
+ NativeBinaryPtr = PBinary[0 ];
724
+ }
725
+ if (!Program->ZeModuleMap .empty ()) {
726
+ uint32_t deviceIndex = 0 ;
727
+ for (auto &ZeDeviceModule : Program->ZeModuleMap ) {
728
+ size_t binarySize = 0 ;
729
+ ZE2UR_CALL (
730
+ zeModuleGetNativeBinary,
731
+ (ZeDeviceModule.second , &binarySize, PBinary[deviceIndex++]));
732
+ SzBinary += binarySize;
733
+ }
734
+ } else {
735
+ ZE2UR_CALL (zeModuleGetNativeBinary,
736
+ (Program->ZeModule , &SzBinary, NativeBinaryPtr));
737
+ }
738
+ if (PropSizeRet)
739
+ *PropSizeRet = SzBinary;
674
740
} else {
675
741
return UR_RESULT_ERROR_INVALID_PROGRAM;
676
742
}
677
743
break ;
678
744
}
679
745
case UR_PROGRAM_INFO_NUM_KERNELS: {
680
746
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
681
- uint32_t NumKernels;
747
+ uint32_t NumKernels = 0 ;
682
748
if (Program->State == ur_program_handle_t_::IL ||
683
749
Program->State == ur_program_handle_t_::Native ||
684
750
Program->State == ur_program_handle_t_::Object) {
685
751
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
686
752
} else if (Program->State == ur_program_handle_t_::Exe) {
687
- NumKernels = 0 ;
688
- ZE2UR_CALL (zeModuleGetKernelNames,
689
- (Program->ZeModule , &NumKernels, nullptr ));
753
+ if (!Program->ZeModuleMap .empty ()) {
754
+ ZE2UR_CALL (
755
+ zeModuleGetKernelNames,
756
+ (Program->ZeModuleMap .begin ()->second , &NumKernels, nullptr ));
757
+ } else {
758
+ ZE2UR_CALL (zeModuleGetKernelNames,
759
+ (Program->ZeModule , &NumKernels, nullptr ));
760
+ }
690
761
} else {
691
762
return UR_RESULT_ERROR_INVALID_PROGRAM;
692
763
}
@@ -702,11 +773,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
702
773
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
703
774
} else if (Program->State == ur_program_handle_t_::Exe) {
704
775
uint32_t Count = 0 ;
705
- ZE2UR_CALL (zeModuleGetKernelNames,
706
- (Program->ZeModule , &Count, nullptr ));
707
- std::unique_ptr<const char *[]> PNames (new const char *[Count]);
708
- ZE2UR_CALL (zeModuleGetKernelNames,
709
- (Program->ZeModule , &Count, PNames.get ()));
776
+ std::unique_ptr<const char *[]> PNames;
777
+ if (!Program->ZeModuleMap .empty ()) {
778
+ ZE2UR_CALL (zeModuleGetKernelNames,
779
+ (Program->ZeModuleMap .begin ()->second , &Count, nullptr ));
780
+ PNames = std::make_unique<const char *[]>(Count);
781
+ ZE2UR_CALL (
782
+ zeModuleGetKernelNames,
783
+ (Program->ZeModuleMap .begin ()->second , &Count, PNames.get ()));
784
+ } else {
785
+ ZE2UR_CALL (zeModuleGetKernelNames,
786
+ (Program->ZeModule , &Count, nullptr ));
787
+ PNames = std::make_unique<const char *[]>(Count);
788
+ ZE2UR_CALL (zeModuleGetKernelNames,
789
+ (Program->ZeModule , &Count, PNames.get ()));
790
+ }
710
791
for (uint32_t I = 0 ; I < Count; ++I) {
711
792
PINames += (I > 0 ? " ;" : " " );
712
793
PINames += PNames[I];
@@ -720,8 +801,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
720
801
} catch (...) {
721
802
return UR_RESULT_ERROR_UNKNOWN;
722
803
}
804
+ case UR_PROGRAM_INFO_SOURCE:
805
+ return ReturnValue (Program->Code .get ());
723
806
default :
724
- die ( " urProgramGetInfo: not implemented " ) ;
807
+ return UR_RESULT_ERROR_INVALID_ENUMERATION ;
725
808
}
726
809
727
810
return UR_RESULT_SUCCESS;
@@ -761,6 +844,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetBuildInfo(
761
844
// return for programs that were built outside and registered
762
845
// with urProgramRegister?
763
846
return ReturnValue (" " );
847
+ } else if (PropName == UR_PROGRAM_BUILD_INFO_STATUS) {
848
+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
764
849
} else if (PropName == UR_PROGRAM_BUILD_INFO_LOG) {
765
850
// Check first to see if the plugin code recorded an error message.
766
851
if (!Program->ErrorMessage .empty ()) {
@@ -852,6 +937,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
852
937
// /< program object created.
853
938
) {
854
939
std::ignore = Properties;
940
+ UR_ASSERT (Context && NativeProgram, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
941
+ UR_ASSERT (Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
855
942
auto ZeModule = ur_cast<ze_module_handle_t >(NativeProgram);
856
943
857
944
// We assume here that programs created from a native handle always
0 commit comments