Skip to content

[Offload] Refactor device information queries to use new tagging #147318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: users/RossBrunton/keylookup2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions offload/liboffload/src/Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,16 @@ class InfoWriter {
InfoWriter(InfoWriter &) = delete;
~InfoWriter() = default;

template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
if (Val)
return getInfo(Size, Target, SizeRet, *Val);
return Val.takeError();
template <typename T> llvm::Error write(T Val) {
return getInfo(Size, Target, SizeRet, Val);
}

template <typename T>
llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
if (Val)
return getInfoArray(Elems, Size, Target, SizeRet, *Val);
return Val.takeError();
template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
return getInfoArray(Elems, Size, Target, SizeRet, Val);
}

llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
if (Val)
return getInfoString(Size, Target, SizeRet, *Val);
return Val.takeError();
llvm::Error writeString(llvm::StringRef Val) {
return getInfoString(Size, Target, SizeRet, Val);
}

private:
Expand Down
112 changes: 49 additions & 63 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,78 +286,64 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
};

// Find the info if it exists under any of the given names
auto getInfoString =
[&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
for (auto &Name : Names) {
if (auto Entry = Device->Info.get(Name)) {
if (!std::holds_alternative<std::string>((*Entry)->Value))
return makeError(ErrorCode::BACKEND_FAILURE,
"plugin returned incorrect type");
return std::get<std::string>((*Entry)->Value).c_str();
}
}

return makeError(ErrorCode::UNIMPLEMENTED,
"plugin did not provide a response for this information");
};

auto getInfoXyz =
[&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
for (auto &Name : Names) {
if (auto Entry = Device->Info.get(Name)) {
auto Node = *Entry;
ol_dimensions_t Out{0, 0, 0};

auto getField = [&](StringRef Name, uint32_t &Dest) {
if (auto F = Node->get(Name)) {
if (!std::holds_alternative<size_t>((*F)->Value))
return makeError(
ErrorCode::BACKEND_FAILURE,
"plugin returned incorrect type for dimensions element");
Dest = std::get<size_t>((*F)->Value);
} else
return makeError(ErrorCode::BACKEND_FAILURE,
"plugin didn't provide all values for dimensions");
return Plugin::success();
};

if (auto Res = getField("x", Out.x))
return Res;
if (auto Res = getField("y", Out.y))
return Res;
if (auto Res = getField("z", Out.z))
return Res;

return Out;
}
}
// These are not implemented by the plugin interface
if (PropName == OL_DEVICE_INFO_PLATFORM)
return Info.write<void *>(Device->Platform);
if (PropName == OL_DEVICE_INFO_TYPE)
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
// TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is
// merged
if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);

auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
if (!EntryOpt)
return makeError(ErrorCode::UNIMPLEMENTED,
"plugin did not provide a response for this information");
};
auto Entry = *EntryOpt;

switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
return Info.write<void *>(Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
return Info.writeString(getInfoString({"Device Name"}));
case OL_DEVICE_INFO_VENDOR:
return Info.writeString(getInfoString({"Vendor Name"}));
case OL_DEVICE_INFO_DRIVER_VERSION:
return Info.writeString(
getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
"Maximum Block Dimensions" /*CUDA*/}));
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
case OL_DEVICE_INFO_DRIVER_VERSION: {
// String values
if (!std::holds_alternative<std::string>(Entry->Value))
return makeError(ErrorCode::BACKEND_FAILURE,
"plugin returned incorrect type");
return Info.writeString(std::get<std::string>(Entry->Value).c_str());
}

return Error::success();
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
// {x, y, z} triples
ol_dimensions_t Out{0, 0, 0};

auto getField = [&](StringRef Name, uint32_t &Dest) {
if (auto F = Entry->get(Name)) {
if (!std::holds_alternative<size_t>((*F)->Value))
return makeError(
ErrorCode::BACKEND_FAILURE,
"plugin returned incorrect type for dimensions element");
Dest = std::get<size_t>((*F)->Value);
} else
return makeError(ErrorCode::BACKEND_FAILURE,
"plugin didn't provide all values for dimensions");
return Plugin::success();
};

if (auto Res = getField("x", Out.x))
return Res;
if (auto Res = getField("y", Out.y))
return Res;
if (auto Res = getField("z", Out.z))
return Res;

return Info.write(Out);
}

default:
llvm_unreachable("Unimplemented device info");
}
}

Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
Expand Down
Loading