@@ -286,78 +286,63 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
286
286
return Plugin::error (ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str ());
287
287
};
288
288
289
- // Find the info if it exists under any of the given names
290
- auto getInfoString =
291
- [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
292
- for (auto &Name : Names) {
293
- if (auto Entry = Device->Info .get (Name)) {
294
- if (!std::holds_alternative<std::string>((*Entry)->Value ))
295
- return makeError (ErrorCode::BACKEND_FAILURE,
296
- " plugin returned incorrect type" );
297
- return std::get<std::string>((*Entry)->Value ).c_str ();
298
- }
299
- }
300
-
301
- return makeError (ErrorCode::UNIMPLEMENTED,
302
- " plugin did not provide a response for this information" );
303
- };
304
-
305
- auto getInfoXyz =
306
- [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t > {
307
- for (auto &Name : Names) {
308
- if (auto Entry = Device->Info .get (Name)) {
309
- auto Node = *Entry;
310
- ol_dimensions_t Out{0 , 0 , 0 };
311
-
312
- auto getField = [&](StringRef Name, uint32_t &Dest) {
313
- if (auto F = Node->get (Name)) {
314
- if (!std::holds_alternative<size_t >((*F)->Value ))
315
- return makeError (
316
- ErrorCode::BACKEND_FAILURE,
317
- " plugin returned incorrect type for dimensions element" );
318
- Dest = std::get<size_t >((*F)->Value );
319
- } else
320
- return makeError (ErrorCode::BACKEND_FAILURE,
321
- " plugin didn't provide all values for dimensions" );
322
- return Plugin::success ();
323
- };
324
-
325
- if (auto Res = getField (" x" , Out.x ))
326
- return Res;
327
- if (auto Res = getField (" y" , Out.y ))
328
- return Res;
329
- if (auto Res = getField (" z" , Out.z ))
330
- return Res;
331
-
332
- return Out;
333
- }
334
- }
289
+ // These are not implemented by the plugin interface
290
+ if (PropName == OL_DEVICE_INFO_PLATFORM)
291
+ return Info.write <void *>(Device->Platform );
292
+ if (PropName == OL_DEVICE_INFO_TYPE)
293
+ return Info.write <ol_device_type_t >(OL_DEVICE_TYPE_GPU);
294
+ // TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is merged
295
+ if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
296
+ return createOffloadError (ErrorCode::INVALID_ENUMERATION,
297
+ " getDeviceInfo enum '%i' is invalid" , PropName);
335
298
299
+ auto EntryOpt = Device->Info .get (static_cast <DeviceInfo>(PropName));
300
+ if (!EntryOpt)
336
301
return makeError (ErrorCode::UNIMPLEMENTED,
337
302
" plugin did not provide a response for this information" );
338
- } ;
303
+ auto Entry = *EntryOpt ;
339
304
340
305
switch (PropName) {
341
- case OL_DEVICE_INFO_PLATFORM:
342
- return Info.write <void *>(Device->Platform );
343
- case OL_DEVICE_INFO_TYPE:
344
- return Info.write <ol_device_type_t >(OL_DEVICE_TYPE_GPU);
345
306
case OL_DEVICE_INFO_NAME:
346
- return Info.writeString (getInfoString ({" Device Name" }));
347
307
case OL_DEVICE_INFO_VENDOR:
348
- return Info.writeString (getInfoString ({" Vendor Name" }));
349
- case OL_DEVICE_INFO_DRIVER_VERSION:
350
- return Info.writeString (
351
- getInfoString ({" CUDA Driver Version" , " HSA Runtime Version" }));
352
- case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
353
- return Info.write (getInfoXyz ({" Workgroup Max Size per Dimension" /* AMD*/ ,
354
- " Maximum Block Dimensions" /* CUDA*/ }));
355
- default :
356
- return createOffloadError (ErrorCode::INVALID_ENUMERATION,
357
- " getDeviceInfo enum '%i' is invalid" , PropName);
308
+ case OL_DEVICE_INFO_DRIVER_VERSION: {
309
+ // String values
310
+ if (!std::holds_alternative<std::string>(Entry->Value ))
311
+ return makeError (ErrorCode::BACKEND_FAILURE,
312
+ " plugin returned incorrect type" );
313
+ return Info.writeString (std::get<std::string>(Entry->Value ).c_str ());
358
314
}
359
315
360
- return Error::success ();
316
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
317
+ // {x, y, z} triples
318
+ ol_dimensions_t Out{0 , 0 , 0 };
319
+
320
+ auto getField = [&](StringRef Name, uint32_t &Dest) {
321
+ if (auto F = Entry->get (Name)) {
322
+ if (!std::holds_alternative<size_t >((*F)->Value ))
323
+ return makeError (
324
+ ErrorCode::BACKEND_FAILURE,
325
+ " plugin returned incorrect type for dimensions element" );
326
+ Dest = std::get<size_t >((*F)->Value );
327
+ } else
328
+ return makeError (ErrorCode::BACKEND_FAILURE,
329
+ " plugin didn't provide all values for dimensions" );
330
+ return Plugin::success ();
331
+ };
332
+
333
+ if (auto Res = getField (" x" , Out.x ))
334
+ return Res;
335
+ if (auto Res = getField (" y" , Out.y ))
336
+ return Res;
337
+ if (auto Res = getField (" z" , Out.z ))
338
+ return Res;
339
+
340
+ return Info.write (Out);
341
+ }
342
+
343
+ default :
344
+ llvm_unreachable (" Unimplemented device info" );
345
+ }
361
346
}
362
347
363
348
Error olGetDeviceInfoImplDetailHost (ol_device_handle_t Device,
0 commit comments