Skip to content

Commit cd4b111

Browse files
nrspruit authored and omarahmed1111 committed
Fix multi device module/kernel access
Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
1 parent fa3a6a9 commit cd4b111

File tree

1 file changed

+89
-22
lines changed

1 file changed

+89
-22
lines changed

source/adapters/level_zero/program.cpp

Lines changed: 89 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -599,11 +599,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
599599
void **GlobalVariablePointerRet ///< [out] Returns the pointer to the global
600600
///< variable if it is found in the program.
601601
) {
602-
std::ignore = Device;
603602
std::scoped_lock<ur_shared_mutex> lock(Program->Mutex);
604603

604+
ze_module_handle_t ZeModuleEntry{};
605+
ZeModuleEntry = Program->ZeModule;
606+
if (!Program->ZeModuleMap.empty()) {
607+
auto It = Program->ZeModuleMap.find(Device->ZeDevice);
608+
if (It != Program->ZeModuleMap.end()) {
609+
ZeModuleEntry = It->second;
610+
}
611+
}
612+
605613
ze_result_t ZeResult =
606-
zeModuleGetGlobalPointer(Program->ZeModule, GlobalVariableName,
614+
zeModuleGetGlobalPointer(ZeModuleEntry, GlobalVariableName,
607615
GlobalVariableSizeRet, GlobalVariablePointerRet);
608616

609617
if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) {
@@ -634,11 +642,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
634642
case UR_PROGRAM_INFO_CONTEXT:
635643
return ReturnValue(Program->Context);
636644
case UR_PROGRAM_INFO_NUM_DEVICES:
637-
// TODO: return true number of devices this program exists for.
638-
return ReturnValue(uint32_t{1});
645+
if (!Program->ZeModuleMap.empty())
646+
return ReturnValue(
647+
uint32_t{ur_cast<uint32_t>(Program->ZeModuleMap.size())});
648+
else
649+
return ReturnValue(uint32_t{1});
639650
case UR_PROGRAM_INFO_DEVICES:
640-
// TODO: return all devices this program exists for.
641-
return ReturnValue(Program->Context->Devices[0]);
651+
if (!Program->ZeModuleMap.empty()) {
652+
std::vector<ur_device_handle_t> devices;
653+
for (auto &ZeModulePair : Program->ZeModuleMap) {
654+
auto It = Program->ZeModuleMap.find(ZeModulePair.first);
655+
if (It != Program->ZeModuleMap.end()) {
656+
for (auto &Device : Program->Context->Devices) {
657+
if (Device->ZeDevice == ZeModulePair.first) {
658+
devices.push_back(Device);
659+
}
660+
}
661+
}
662+
}
663+
return ReturnValue(devices);
664+
} else {
665+
return ReturnValue(Program->Context->Devices[0]);
666+
}
642667
case UR_PROGRAM_INFO_BINARY_SIZES: {
643668
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
644669
size_t SzBinary;
@@ -647,8 +672,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
647672
Program->State == ur_program_handle_t_::Object) {
648673
SzBinary = Program->CodeLength;
649674
} else if (Program->State == ur_program_handle_t_::Exe) {
650-
ZE2UR_CALL(zeModuleGetNativeBinary,
651-
(Program->ZeModule, &SzBinary, nullptr));
675+
if (!Program->ZeModuleMap.empty()) {
676+
std::vector<size_t> binarySizes;
677+
for (auto &ZeModulePair : Program->ZeModuleMap) {
678+
size_t binarySize = 0;
679+
ZE2UR_CALL(zeModuleGetNativeBinary,
680+
(ZeModulePair.second, &binarySize, nullptr));
681+
binarySizes.push_back(binarySize);
682+
}
683+
return ReturnValue(binarySizes);
684+
} else {
685+
ZE2UR_CALL(zeModuleGetNativeBinary,
686+
(Program->ZeModule, &SzBinary, nullptr));
687+
return ReturnValue(SzBinary);
688+
}
652689
} else {
653690
return UR_RESULT_ERROR_INVALID_PROGRAM;
654691
}
@@ -657,9 +694,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
657694
}
658695
case UR_PROGRAM_INFO_BINARIES: {
659696
// The caller sets "ParamValue" to an array of pointers, one for each
660-
// device. Since Level Zero supports only one device, there is only one
661-
// pointer. If the pointer is NULL, we don't do anything. Otherwise, we
662-
// copy the program's binary image to the buffer at that pointer.
697+
// device.
663698
uint8_t **PBinary = nullptr;
664699
if (ProgramInfo) {
665700
PBinary = ur_cast<uint8_t **>(ProgramInfo);
@@ -668,6 +703,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
668703
}
669704
}
670705
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
706+
// If the caller is using a Program which is IL, Native or an object, then
707+
// the program has not been built for multiple devices so a single IL is
708+
// returned.
671709
if (Program->State == ur_program_handle_t_::IL ||
672710
Program->State == ur_program_handle_t_::Native ||
673711
Program->State == ur_program_handle_t_::Object) {
@@ -677,13 +715,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
677715
std::memcpy(PBinary[0], Program->Code.get(), Program->CodeLength);
678716
}
679717
} else if (Program->State == ur_program_handle_t_::Exe) {
718+
// If the caller is using a Program which is a built binary, then
719+
// the program returned will either be a single module if this is a native
720+
// binary or the native binary for each device will be returned.
680721
size_t SzBinary = 0;
681722
uint8_t *NativeBinaryPtr = nullptr;
682723
if (PBinary) {
683724
NativeBinaryPtr = PBinary[0];
684725
}
685-
ZE2UR_CALL(zeModuleGetNativeBinary,
686-
(Program->ZeModule, &SzBinary, NativeBinaryPtr));
726+
if (!Program->ZeModuleMap.empty()) {
727+
uint32_t deviceIndex = 0;
728+
for (auto &ZeDeviceModule : Program->ZeModuleMap) {
729+
size_t binarySize = 0;
730+
ZE2UR_CALL(
731+
zeModuleGetNativeBinary,
732+
(ZeDeviceModule.second, &binarySize, PBinary[deviceIndex++]));
733+
SzBinary += binarySize;
734+
}
735+
} else {
736+
ZE2UR_CALL(zeModuleGetNativeBinary,
737+
(Program->ZeModule, &SzBinary, NativeBinaryPtr));
738+
}
687739
if (PropSizeRet)
688740
*PropSizeRet = SzBinary;
689741
} else {
@@ -693,15 +745,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
693745
}
694746
case UR_PROGRAM_INFO_NUM_KERNELS: {
695747
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
696-
uint32_t NumKernels;
748+
uint32_t NumKernels = 0;
697749
if (Program->State == ur_program_handle_t_::IL ||
698750
Program->State == ur_program_handle_t_::Native ||
699751
Program->State == ur_program_handle_t_::Object) {
700752
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
701753
} else if (Program->State == ur_program_handle_t_::Exe) {
702-
NumKernels = 0;
703-
ZE2UR_CALL(zeModuleGetKernelNames,
704-
(Program->ZeModule, &NumKernels, nullptr));
754+
if (!Program->ZeModuleMap.empty()) {
755+
ZE2UR_CALL(
756+
zeModuleGetKernelNames,
757+
(Program->ZeModuleMap.begin()->second, &NumKernels, nullptr));
758+
} else {
759+
ZE2UR_CALL(zeModuleGetKernelNames,
760+
(Program->ZeModule, &NumKernels, nullptr));
761+
}
705762
} else {
706763
return UR_RESULT_ERROR_INVALID_PROGRAM;
707764
}
@@ -717,11 +774,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
717774
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
718775
} else if (Program->State == ur_program_handle_t_::Exe) {
719776
uint32_t Count = 0;
720-
ZE2UR_CALL(zeModuleGetKernelNames,
721-
(Program->ZeModule, &Count, nullptr));
722-
std::unique_ptr<const char *[]> PNames(new const char *[Count]);
723-
ZE2UR_CALL(zeModuleGetKernelNames,
724-
(Program->ZeModule, &Count, PNames.get()));
777+
std::unique_ptr<const char *[]> PNames;
778+
if (!Program->ZeModuleMap.empty()) {
779+
ZE2UR_CALL(zeModuleGetKernelNames,
780+
(Program->ZeModuleMap.begin()->second, &Count, nullptr));
781+
PNames = std::make_unique<const char *[]>(Count);
782+
ZE2UR_CALL(
783+
zeModuleGetKernelNames,
784+
(Program->ZeModuleMap.begin()->second, &Count, PNames.get()));
785+
} else {
786+
ZE2UR_CALL(zeModuleGetKernelNames,
787+
(Program->ZeModule, &Count, nullptr));
788+
PNames = std::make_unique<const char *[]>(Count);
789+
ZE2UR_CALL(zeModuleGetKernelNames,
790+
(Program->ZeModule, &Count, PNames.get()));
791+
}
725792
for (uint32_t I = 0; I < Count; ++I) {
726793
PINames += (I > 0 ? ";" : "");
727794
PINames += PNames[I];

0 commit comments

Comments
 (0)