diff --git a/offload/liboffload/src/Helpers.hpp b/offload/liboffload/src/Helpers.hpp index 8b85945508b98..62e55e500fac7 100644 --- a/offload/liboffload/src/Helpers.hpp +++ b/offload/liboffload/src/Helpers.hpp @@ -75,23 +75,16 @@ class InfoWriter { InfoWriter(InfoWriter &) = delete; ~InfoWriter() = default; - template llvm::Error write(llvm::Expected &&Val) { - if (Val) - return getInfo(Size, Target, SizeRet, *Val); - return Val.takeError(); + template llvm::Error write(T Val) { + return getInfo(Size, Target, SizeRet, Val); } - template - llvm::Error writeArray(llvm::Expected &&Val, size_t Elems) { - if (Val) - return getInfoArray(Elems, Size, Target, SizeRet, *Val); - return Val.takeError(); + template llvm::Error writeArray(T Val, size_t Elems) { + return getInfoArray(Elems, Size, Target, SizeRet, Val); } - llvm::Error writeString(llvm::Expected &&Val) { - if (Val) - return getInfoString(Size, Target, SizeRet, *Val); - return Val.takeError(); + llvm::Error writeString(llvm::StringRef Val) { + return getInfoString(Size, Target, SizeRet, Val); } private: diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index f9da638436705..4ca32d2e0f8a5 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -286,78 +286,64 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); }; - // Find the info if it exists under any of the given names - auto getInfoString = - [&](std::vector Names) -> llvm::Expected { - for (auto &Name : Names) { - if (auto Entry = Device->Info.get(Name)) { - if (!std::holds_alternative((*Entry)->Value)) - return makeError(ErrorCode::BACKEND_FAILURE, - "plugin returned incorrect type"); - return std::get((*Entry)->Value).c_str(); - } - } - - return makeError(ErrorCode::UNIMPLEMENTED, - "plugin did not provide a response for this information"); - }; - - auto getInfoXyz = - [&](std::vector Names) -> llvm::Expected { - for (auto &Name : Names) { - if (auto Entry = Device->Info.get(Name)) { - auto Node = *Entry; - ol_dimensions_t Out{0, 0, 0}; - - auto getField = [&](StringRef Name, uint32_t &Dest) { - if (auto F = Node->get(Name)) { - if (!std::holds_alternative((*F)->Value)) - return makeError( - ErrorCode::BACKEND_FAILURE, - "plugin returned incorrect type for dimensions element"); - Dest = std::get((*F)->Value); - } else - return makeError(ErrorCode::BACKEND_FAILURE, - "plugin didn't provide all values for dimensions"); - return Plugin::success(); - }; - - if (auto Res = getField("x", Out.x)) - return Res; - if (auto Res = getField("y", Out.y)) - return Res; - if (auto Res = getField("z", Out.z)) - return Res; - - return Out; - } - } + // These are not implemented by the plugin interface + if (PropName == OL_DEVICE_INFO_PLATFORM) + return Info.write(Device->Platform); + if (PropName == OL_DEVICE_INFO_TYPE) + return Info.write(OL_DEVICE_TYPE_GPU); + // TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is + // merged + if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE) + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "getDeviceInfo enum '%i' is invalid", PropName); + auto EntryOpt = Device->Info.get(static_cast(PropName)); + if (!EntryOpt) return makeError(ErrorCode::UNIMPLEMENTED, "plugin did not provide a response for this information"); - }; + auto Entry = *EntryOpt; switch (PropName) { - case OL_DEVICE_INFO_PLATFORM: - return Info.write(Device->Platform); - case OL_DEVICE_INFO_TYPE: - return Info.write(OL_DEVICE_TYPE_GPU); case OL_DEVICE_INFO_NAME: - return Info.writeString(getInfoString({"Device Name"})); case OL_DEVICE_INFO_VENDOR: - return Info.writeString(getInfoString({"Vendor Name"})); - case OL_DEVICE_INFO_DRIVER_VERSION: - return Info.writeString( - getInfoString({"CUDA Driver Version", "HSA Runtime Version"})); - case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: - return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/, - "Maximum Block Dimensions" /*CUDA*/})); - default: - return createOffloadError(ErrorCode::INVALID_ENUMERATION, - "getDeviceInfo enum '%i' is invalid", PropName); + case OL_DEVICE_INFO_DRIVER_VERSION: { + // String values + if (!std::holds_alternative(Entry->Value)) + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type"); + return Info.writeString(std::get(Entry->Value).c_str()); } - return Error::success(); + case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: { + // {x, y, z} triples + ol_dimensions_t Out{0, 0, 0}; + + auto getField = [&](StringRef Name, uint32_t &Dest) { + if (auto F = Entry->get(Name)) { + if (!std::holds_alternative((*F)->Value)) + return makeError( + ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type for dimensions element"); + Dest = std::get((*F)->Value); + } else + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin didn't provide all values for dimensions"); + return Plugin::success(); + }; + + if (auto Res = getField("x", Out.x)) + return Res; + if (auto Res = getField("y", Out.y)) + return Res; + if (auto Res = getField("z", Out.z)) + return Res; + + return Info.write(Out); + } + + default: + llvm_unreachable("Unimplemented device info"); + } } Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,