Skip to content

[Offload] Refactor device information queries to use new tagging #147318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: users/RossBrunton/keylookup2
Choose a base branch
from

Conversation

RossBrunton
Copy link
Contributor

Instead using strings to look up device information (which is brittle
and slow), use the new tags that the plugins specify when building the
nodes.

@llvmbot llvmbot added the offload label Jul 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-offload

Author: Ross Brunton (RossBrunton)

Changes

Instead using strings to look up device information (which is brittle
and slow), use the new tags that the plugins specify when building the
nodes.


Full diff: https://github.com/llvm/llvm-project/pull/147318.diff

2 Files Affected:

  • (modified) offload/liboffload/src/Helpers.hpp (+6-13)
  • (modified) offload/liboffload/src/OffloadImpl.cpp (+48-63)
diff --git a/offload/liboffload/src/Helpers.hpp b/offload/liboffload/src/Helpers.hpp
index 8b85945508b98..62e55e500fac7 100644
--- a/offload/liboffload/src/Helpers.hpp
+++ b/offload/liboffload/src/Helpers.hpp
@@ -75,23 +75,16 @@ class InfoWriter {
   InfoWriter(InfoWriter &) = delete;
   ~InfoWriter() = default;
 
-  template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
-    if (Val)
-      return getInfo(Size, Target, SizeRet, *Val);
-    return Val.takeError();
+  template <typename T> llvm::Error write(T Val) {
+    return getInfo(Size, Target, SizeRet, Val);
   }
 
-  template <typename T>
-  llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
-    if (Val)
-      return getInfoArray(Elems, Size, Target, SizeRet, *Val);
-    return Val.takeError();
+  template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
+    return getInfoArray(Elems, Size, Target, SizeRet, Val);
   }
 
-  llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
-    if (Val)
-      return getInfoString(Size, Target, SizeRet, *Val);
-    return Val.takeError();
+  llvm::Error writeString(llvm::StringRef Val) {
+    return getInfoString(Size, Target, SizeRet, Val);
   }
 
 private:
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index f9da638436705..c84bf01460252 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -286,78 +286,63 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
     return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
   };
 
-  // Find the info if it exists under any of the given names
-  auto getInfoString =
-      [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
-    for (auto &Name : Names) {
-      if (auto Entry = Device->Info.get(Name)) {
-        if (!std::holds_alternative<std::string>((*Entry)->Value))
-          return makeError(ErrorCode::BACKEND_FAILURE,
-                           "plugin returned incorrect type");
-        return std::get<std::string>((*Entry)->Value).c_str();
-      }
-    }
-
-    return makeError(ErrorCode::UNIMPLEMENTED,
-                     "plugin did not provide a response for this information");
-  };
-
-  auto getInfoXyz =
-      [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
-    for (auto &Name : Names) {
-      if (auto Entry = Device->Info.get(Name)) {
-        auto Node = *Entry;
-        ol_dimensions_t Out{0, 0, 0};
-
-        auto getField = [&](StringRef Name, uint32_t &Dest) {
-          if (auto F = Node->get(Name)) {
-            if (!std::holds_alternative<size_t>((*F)->Value))
-              return makeError(
-                  ErrorCode::BACKEND_FAILURE,
-                  "plugin returned incorrect type for dimensions element");
-            Dest = std::get<size_t>((*F)->Value);
-          } else
-            return makeError(ErrorCode::BACKEND_FAILURE,
-                             "plugin didn't provide all values for dimensions");
-          return Plugin::success();
-        };
-
-        if (auto Res = getField("x", Out.x))
-          return Res;
-        if (auto Res = getField("y", Out.y))
-          return Res;
-        if (auto Res = getField("z", Out.z))
-          return Res;
-
-        return Out;
-      }
-    }
+  // These are not implemented by the plugin interface
+  if (PropName == OL_DEVICE_INFO_PLATFORM)
+    return Info.write<void *>(Device->Platform);
+  if (PropName == OL_DEVICE_INFO_TYPE)
+    return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
+  // TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is merged
+  if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
+    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+                              "getDeviceInfo enum '%i' is invalid", PropName);
 
+  auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
+  if (!EntryOpt)
     return makeError(ErrorCode::UNIMPLEMENTED,
                      "plugin did not provide a response for this information");
-  };
+  auto Entry = *EntryOpt;
 
   switch (PropName) {
-  case OL_DEVICE_INFO_PLATFORM:
-    return Info.write<void *>(Device->Platform);
-  case OL_DEVICE_INFO_TYPE:
-    return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
   case OL_DEVICE_INFO_NAME:
-    return Info.writeString(getInfoString({"Device Name"}));
   case OL_DEVICE_INFO_VENDOR:
-    return Info.writeString(getInfoString({"Vendor Name"}));
-  case OL_DEVICE_INFO_DRIVER_VERSION:
-    return Info.writeString(
-        getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
-  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
-    return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
-                                  "Maximum Block Dimensions" /*CUDA*/}));
-  default:
-    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
-                              "getDeviceInfo enum '%i' is invalid", PropName);
+  case OL_DEVICE_INFO_DRIVER_VERSION: {
+    // String values
+    if (!std::holds_alternative<std::string>(Entry->Value))
+      return makeError(ErrorCode::BACKEND_FAILURE,
+                       "plugin returned incorrect type");
+    return Info.writeString(std::get<std::string>(Entry->Value).c_str());
   }
 
-  return Error::success();
+  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
+    // {x, y, z} triples
+    ol_dimensions_t Out{0, 0, 0};
+
+    auto getField = [&](StringRef Name, uint32_t &Dest) {
+      if (auto F = Entry->get(Name)) {
+        if (!std::holds_alternative<size_t>((*F)->Value))
+          return makeError(
+              ErrorCode::BACKEND_FAILURE,
+              "plugin returned incorrect type for dimensions element");
+        Dest = std::get<size_t>((*F)->Value);
+      } else
+        return makeError(ErrorCode::BACKEND_FAILURE,
+                         "plugin didn't provide all values for dimensions");
+      return Plugin::success();
+    };
+
+    if (auto Res = getField("x", Out.x))
+      return Res;
+    if (auto Res = getField("y", Out.y))
+      return Res;
+    if (auto Res = getField("z", Out.z))
+      return Res;
+
+    return Info.write(Out);
+  }
+
+  default:
+    llvm_unreachable("Unimplemented device info");
+  }
 }
 
 Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,

Copy link

github-actions bot commented Jul 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@RossBrunton RossBrunton force-pushed the users/RossBrunton/keylookup3 branch from 4cce1ee to 0550a48 Compare July 8, 2025 09:12
@RossBrunton RossBrunton force-pushed the users/RossBrunton/keylookup2 branch from 8dc2544 to 00aab9f Compare July 9, 2025 10:37
Instead using strings to look up device information (which is brittle
and slow), use the new tags that the plugins specify when building the
nodes.
@RossBrunton RossBrunton force-pushed the users/RossBrunton/keylookup3 branch from 9dba00d to 7a3defb Compare July 9, 2025 10:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants