@@ -599,11 +599,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
599
599
void **GlobalVariablePointerRet // /< [out] Returns the pointer to the global
600
600
// /< variable if it is found in the program.
601
601
) {
602
- std::ignore = Device;
603
602
std::scoped_lock<ur_shared_mutex> lock (Program->Mutex );
604
603
604
+ ze_module_handle_t ZeModuleEntry{};
605
+ ZeModuleEntry = Program->ZeModule ;
606
+ if (!Program->ZeModuleMap .empty ()) {
607
+ auto It = Program->ZeModuleMap .find (Device->ZeDevice );
608
+ if (It != Program->ZeModuleMap .end ()) {
609
+ ZeModuleEntry = It->second ;
610
+ }
611
+ }
612
+
605
613
ze_result_t ZeResult =
606
- zeModuleGetGlobalPointer (Program-> ZeModule , GlobalVariableName,
614
+ zeModuleGetGlobalPointer (ZeModuleEntry , GlobalVariableName,
607
615
GlobalVariableSizeRet, GlobalVariablePointerRet);
608
616
609
617
if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) {
@@ -634,11 +642,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
634
642
case UR_PROGRAM_INFO_CONTEXT:
635
643
return ReturnValue (Program->Context );
636
644
case UR_PROGRAM_INFO_NUM_DEVICES:
637
- // TODO: return true number of devices this program exists for.
638
- return ReturnValue (uint32_t {1 });
645
+ if (!Program->ZeModuleMap .empty ())
646
+ return ReturnValue (
647
+ uint32_t {ur_cast<uint32_t >(Program->ZeModuleMap .size ())});
648
+ else
649
+ return ReturnValue (uint32_t {1 });
639
650
case UR_PROGRAM_INFO_DEVICES:
640
- // TODO: return all devices this program exists for.
641
- return ReturnValue (Program->Context ->Devices [0 ]);
651
+ if (!Program->ZeModuleMap .empty ()) {
652
+ std::vector<ur_device_handle_t > devices;
653
+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
654
+ auto It = Program->ZeModuleMap .find (ZeModulePair.first );
655
+ if (It != Program->ZeModuleMap .end ()) {
656
+ for (auto &Device : Program->Context ->Devices ) {
657
+ if (Device->ZeDevice == ZeModulePair.first ) {
658
+ devices.push_back (Device);
659
+ }
660
+ }
661
+ }
662
+ }
663
+ return ReturnValue (devices);
664
+ } else {
665
+ return ReturnValue (Program->Context ->Devices [0 ]);
666
+ }
642
667
case UR_PROGRAM_INFO_BINARY_SIZES: {
643
668
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
644
669
size_t SzBinary;
@@ -647,8 +672,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
647
672
Program->State == ur_program_handle_t_::Object) {
648
673
SzBinary = Program->CodeLength ;
649
674
} else if (Program->State == ur_program_handle_t_::Exe) {
650
- ZE2UR_CALL (zeModuleGetNativeBinary,
651
- (Program->ZeModule , &SzBinary, nullptr ));
675
+ if (!Program->ZeModuleMap .empty ()) {
676
+ std::vector<size_t > binarySizes;
677
+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
678
+ size_t binarySize = 0 ;
679
+ ZE2UR_CALL (zeModuleGetNativeBinary,
680
+ (ZeModulePair.second , &binarySize, nullptr ));
681
+ binarySizes.push_back (binarySize);
682
+ }
683
+ return ReturnValue (binarySizes);
684
+ } else {
685
+ ZE2UR_CALL (zeModuleGetNativeBinary,
686
+ (Program->ZeModule , &SzBinary, nullptr ));
687
+ return ReturnValue (SzBinary);
688
+ }
652
689
} else {
653
690
return UR_RESULT_ERROR_INVALID_PROGRAM;
654
691
}
@@ -657,9 +694,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
657
694
}
658
695
case UR_PROGRAM_INFO_BINARIES: {
659
696
// The caller sets "ParamValue" to an array of pointers, one for each
660
- // device. Since Level Zero supports only one device, there is only one
661
- // pointer. If the pointer is NULL, we don't do anything. Otherwise, we
662
- // copy the program's binary image to the buffer at that pointer.
697
+ // device.
663
698
uint8_t **PBinary = nullptr ;
664
699
if (ProgramInfo) {
665
700
PBinary = ur_cast<uint8_t **>(ProgramInfo);
@@ -668,6 +703,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
668
703
}
669
704
}
670
705
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
706
+ // If the caller is using a Program which is IL, Native or an object, then
707
+ // the program has not been built for multiple devices so a single IL is
708
+ // returned.
671
709
if (Program->State == ur_program_handle_t_::IL ||
672
710
Program->State == ur_program_handle_t_::Native ||
673
711
Program->State == ur_program_handle_t_::Object) {
@@ -677,13 +715,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
677
715
std::memcpy (PBinary[0 ], Program->Code .get (), Program->CodeLength );
678
716
}
679
717
} else if (Program->State == ur_program_handle_t_::Exe) {
718
+ // If the caller is using a Program which is a built binary, then
719
+ // the program returned will either be a single module if this is a native
720
+ // binary or the native binary for each device will be returned.
680
721
size_t SzBinary = 0 ;
681
722
uint8_t *NativeBinaryPtr = nullptr ;
682
723
if (PBinary) {
683
724
NativeBinaryPtr = PBinary[0 ];
684
725
}
685
- ZE2UR_CALL (zeModuleGetNativeBinary,
686
- (Program->ZeModule , &SzBinary, NativeBinaryPtr));
726
+ if (!Program->ZeModuleMap .empty ()) {
727
+ uint32_t deviceIndex = 0 ;
728
+ for (auto &ZeDeviceModule : Program->ZeModuleMap ) {
729
+ size_t binarySize = 0 ;
730
+ ZE2UR_CALL (
731
+ zeModuleGetNativeBinary,
732
+ (ZeDeviceModule.second , &binarySize, PBinary[deviceIndex++]));
733
+ SzBinary += binarySize;
734
+ }
735
+ } else {
736
+ ZE2UR_CALL (zeModuleGetNativeBinary,
737
+ (Program->ZeModule , &SzBinary, NativeBinaryPtr));
738
+ }
687
739
if (PropSizeRet)
688
740
*PropSizeRet = SzBinary;
689
741
} else {
@@ -693,15 +745,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
693
745
}
694
746
case UR_PROGRAM_INFO_NUM_KERNELS: {
695
747
std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
696
- uint32_t NumKernels;
748
+ uint32_t NumKernels = 0 ;
697
749
if (Program->State == ur_program_handle_t_::IL ||
698
750
Program->State == ur_program_handle_t_::Native ||
699
751
Program->State == ur_program_handle_t_::Object) {
700
752
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
701
753
} else if (Program->State == ur_program_handle_t_::Exe) {
702
- NumKernels = 0 ;
703
- ZE2UR_CALL (zeModuleGetKernelNames,
704
- (Program->ZeModule , &NumKernels, nullptr ));
754
+ if (!Program->ZeModuleMap .empty ()) {
755
+ ZE2UR_CALL (
756
+ zeModuleGetKernelNames,
757
+ (Program->ZeModuleMap .begin ()->second , &NumKernels, nullptr ));
758
+ } else {
759
+ ZE2UR_CALL (zeModuleGetKernelNames,
760
+ (Program->ZeModule , &NumKernels, nullptr ));
761
+ }
705
762
} else {
706
763
return UR_RESULT_ERROR_INVALID_PROGRAM;
707
764
}
@@ -717,11 +774,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
717
774
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
718
775
} else if (Program->State == ur_program_handle_t_::Exe) {
719
776
uint32_t Count = 0 ;
720
- ZE2UR_CALL (zeModuleGetKernelNames,
721
- (Program->ZeModule , &Count, nullptr ));
722
- std::unique_ptr<const char *[]> PNames (new const char *[Count]);
723
- ZE2UR_CALL (zeModuleGetKernelNames,
724
- (Program->ZeModule , &Count, PNames.get ()));
777
+ std::unique_ptr<const char *[]> PNames;
778
+ if (!Program->ZeModuleMap .empty ()) {
779
+ ZE2UR_CALL (zeModuleGetKernelNames,
780
+ (Program->ZeModuleMap .begin ()->second , &Count, nullptr ));
781
+ PNames = std::make_unique<const char *[]>(Count);
782
+ ZE2UR_CALL (
783
+ zeModuleGetKernelNames,
784
+ (Program->ZeModuleMap .begin ()->second , &Count, PNames.get ()));
785
+ } else {
786
+ ZE2UR_CALL (zeModuleGetKernelNames,
787
+ (Program->ZeModule , &Count, nullptr ));
788
+ PNames = std::make_unique<const char *[]>(Count);
789
+ ZE2UR_CALL (zeModuleGetKernelNames,
790
+ (Program->ZeModule , &Count, PNames.get ()));
791
+ }
725
792
for (uint32_t I = 0 ; I < Count; ++I) {
726
793
PINames += (I > 0 ? " ;" : " " );
727
794
PINames += PNames[I];
0 commit comments