Skip to content

Commit 4cce1ee

Browse files
committed
[Offload] Refactor device information queries to use new tagging
Instead using strings to look up device information (which is brittle and slow), use the new tags that the plugins specify when building the nodes.
1 parent 9b79557 commit 4cce1ee

File tree

2 files changed

+54
-76
lines changed

2 files changed

+54
-76
lines changed

offload/liboffload/src/Helpers.hpp

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,16 @@ class InfoWriter {
7575
InfoWriter(InfoWriter &) = delete;
7676
~InfoWriter() = default;
7777

78-
template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
79-
if (Val)
80-
return getInfo(Size, Target, SizeRet, *Val);
81-
return Val.takeError();
78+
template <typename T> llvm::Error write(T Val) {
79+
return getInfo(Size, Target, SizeRet, Val);
8280
}
8381

84-
template <typename T>
85-
llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
86-
if (Val)
87-
return getInfoArray(Elems, Size, Target, SizeRet, *Val);
88-
return Val.takeError();
82+
template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
83+
return getInfoArray(Elems, Size, Target, SizeRet, Val);
8984
}
9085

91-
llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
92-
if (Val)
93-
return getInfoString(Size, Target, SizeRet, *Val);
94-
return Val.takeError();
86+
llvm::Error writeString(llvm::StringRef Val) {
87+
return getInfoString(Size, Target, SizeRet, Val);
9588
}
9689

9790
private:

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 48 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -286,78 +286,63 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
286286
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
287287
};
288288

289-
// Find the info if it exists under any of the given names
290-
auto getInfoString =
291-
[&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
292-
for (auto &Name : Names) {
293-
if (auto Entry = Device->Info.get(Name)) {
294-
if (!std::holds_alternative<std::string>((*Entry)->Value))
295-
return makeError(ErrorCode::BACKEND_FAILURE,
296-
"plugin returned incorrect type");
297-
return std::get<std::string>((*Entry)->Value).c_str();
298-
}
299-
}
300-
301-
return makeError(ErrorCode::UNIMPLEMENTED,
302-
"plugin did not provide a response for this information");
303-
};
304-
305-
auto getInfoXyz =
306-
[&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
307-
for (auto &Name : Names) {
308-
if (auto Entry = Device->Info.get(Name)) {
309-
auto Node = *Entry;
310-
ol_dimensions_t Out{0, 0, 0};
311-
312-
auto getField = [&](StringRef Name, uint32_t &Dest) {
313-
if (auto F = Node->get(Name)) {
314-
if (!std::holds_alternative<size_t>((*F)->Value))
315-
return makeError(
316-
ErrorCode::BACKEND_FAILURE,
317-
"plugin returned incorrect type for dimensions element");
318-
Dest = std::get<size_t>((*F)->Value);
319-
} else
320-
return makeError(ErrorCode::BACKEND_FAILURE,
321-
"plugin didn't provide all values for dimensions");
322-
return Plugin::success();
323-
};
324-
325-
if (auto Res = getField("x", Out.x))
326-
return Res;
327-
if (auto Res = getField("y", Out.y))
328-
return Res;
329-
if (auto Res = getField("z", Out.z))
330-
return Res;
331-
332-
return Out;
333-
}
334-
}
289+
// These are not implemented by the plugin interface
290+
if (PropName == OL_DEVICE_INFO_PLATFORM)
291+
return Info.write<void *>(Device->Platform);
292+
if (PropName == OL_DEVICE_INFO_TYPE)
293+
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
294+
// TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is merged
295+
if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
296+
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
297+
"getDeviceInfo enum '%i' is invalid", PropName);
335298

299+
auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
300+
if (!EntryOpt)
336301
return makeError(ErrorCode::UNIMPLEMENTED,
337302
"plugin did not provide a response for this information");
338-
};
303+
auto Entry = *EntryOpt;
339304

340305
switch (PropName) {
341-
case OL_DEVICE_INFO_PLATFORM:
342-
return Info.write<void *>(Device->Platform);
343-
case OL_DEVICE_INFO_TYPE:
344-
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
345306
case OL_DEVICE_INFO_NAME:
346-
return Info.writeString(getInfoString({"Device Name"}));
347307
case OL_DEVICE_INFO_VENDOR:
348-
return Info.writeString(getInfoString({"Vendor Name"}));
349-
case OL_DEVICE_INFO_DRIVER_VERSION:
350-
return Info.writeString(
351-
getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
352-
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
353-
return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
354-
"Maximum Block Dimensions" /*CUDA*/}));
355-
default:
356-
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
357-
"getDeviceInfo enum '%i' is invalid", PropName);
308+
case OL_DEVICE_INFO_DRIVER_VERSION: {
309+
// String values
310+
if (!std::holds_alternative<std::string>(Entry->Value))
311+
return makeError(ErrorCode::BACKEND_FAILURE,
312+
"plugin returned incorrect type");
313+
return Info.writeString(std::get<std::string>(Entry->Value).c_str());
358314
}
359315

360-
return Error::success();
316+
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
317+
// {x, y, z} triples
318+
ol_dimensions_t Out{0, 0, 0};
319+
320+
auto getField = [&](StringRef Name, uint32_t &Dest) {
321+
if (auto F = Entry->get(Name)) {
322+
if (!std::holds_alternative<size_t>((*F)->Value))
323+
return makeError(
324+
ErrorCode::BACKEND_FAILURE,
325+
"plugin returned incorrect type for dimensions element");
326+
Dest = std::get<size_t>((*F)->Value);
327+
} else
328+
return makeError(ErrorCode::BACKEND_FAILURE,
329+
"plugin didn't provide all values for dimensions");
330+
return Plugin::success();
331+
};
332+
333+
if (auto Res = getField("x", Out.x))
334+
return Res;
335+
if (auto Res = getField("y", Out.y))
336+
return Res;
337+
if (auto Res = getField("z", Out.z))
338+
return Res;
339+
340+
return Info.write(Out);
341+
}
342+
343+
default:
344+
llvm_unreachable("Unimplemented device info");
345+
}
361346
}
362347

363348
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,

0 commit comments

Comments
 (0)