Skip to content

Commit eee723f

Browse files
authored
[Offload] Replace GetKernel with GetSymbol with global support (#148221)
`olGetKernel` has been replaced by `olGetSymbol` which accepts a `Kind` parameter. As well as loading information about kernels, it can now also load information about global variables.
1 parent 38b9c66 commit eee723f

File tree

9 files changed

+169
-77
lines changed

9 files changed

+169
-77
lines changed

offload/liboffload/API/Kernel.td

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,10 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file contains Offload API definitions related to loading and launching
10-
// kernels
9+
// This file contains Offload API definitions related to launching kernels
1110
//
1211
//===----------------------------------------------------------------------===//
1312

14-
def : Function {
15-
let name = "olGetKernel";
16-
let desc = "Get a kernel from the function identified by `KernelName` in the given program.";
17-
let details = [
18-
"Symbol handles are owned by the program and do not need to be manually destroyed."
19-
];
20-
let params = [
21-
Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
22-
Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>,
23-
Param<"ol_symbol_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT>
24-
];
25-
let returns = [];
26-
}
27-
2813
def : Struct {
2914
let name = "ol_kernel_launch_size_args_t";
3015
let desc = "Size-related arguments for a kernel launch.";

offload/liboffload/API/Symbol.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,21 @@ def : Enum {
1515
let desc = "The kind of a symbol";
1616
let etors =[
1717
Etor<"KERNEL", "a kernel object">,
18+
Etor<"GLOBAL_VARIABLE", "a global variable">,
1819
];
1920
}
21+
22+
def : Function {
23+
let name = "olGetSymbol";
24+
let desc = "Get a symbol (kernel or global variable) identified by `Name` in the given program.";
25+
let details = [
26+
"Symbol handles are owned by the program and do not need to be manually destroyed."
27+
];
28+
let params = [
29+
Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
30+
Param<"const char*", "Name", "name of the symbol to look up", PARAM_IN>,
31+
Param<"ol_symbol_kind_t", "Kind", "symbol kind to look up", PARAM_IN>,
32+
Param<"ol_symbol_handle_t*", "Symbol", "output pointer for the symbol", PARAM_OUT>,
33+
];
34+
let returns = [];
35+
}

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ struct ol_program_impl_t {
9191
struct ol_symbol_impl_t {
9292
ol_symbol_impl_t(GenericKernelTy *Kernel)
9393
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
94-
std::variant<GenericKernelTy *> PluginImpl;
94+
ol_symbol_impl_t(GlobalTy &&Global)
95+
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
96+
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
9597
ol_symbol_kind_t Kind;
9698
};
9799

@@ -660,24 +662,6 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
660662
return olDestroy(Program);
661663
}
662664

663-
Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
664-
ol_symbol_handle_t *Kernel) {
665-
666-
auto &Device = Program->Image->getDevice();
667-
auto KernelImpl = Device.constructKernel(KernelName);
668-
if (!KernelImpl)
669-
return KernelImpl.takeError();
670-
671-
if (auto Err = KernelImpl->init(Device, *Program->Image))
672-
return Err;
673-
674-
*Kernel = Program->Symbols
675-
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
676-
.get();
677-
678-
return Error::success();
679-
}
680-
681665
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
682666
ol_symbol_handle_t Kernel, const void *ArgumentsData,
683667
size_t ArgumentsSize,
@@ -726,5 +710,43 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
726710
return Error::success();
727711
}
728712

713+
Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
714+
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
715+
auto &Device = Program->Image->getDevice();
716+
717+
switch (Kind) {
718+
case OL_SYMBOL_KIND_KERNEL: {
719+
auto KernelImpl = Device.constructKernel(Name);
720+
if (!KernelImpl)
721+
return KernelImpl.takeError();
722+
723+
if (auto Err = KernelImpl->init(Device, *Program->Image))
724+
return Err;
725+
726+
*Symbol =
727+
Program->Symbols
728+
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
729+
.get();
730+
return Error::success();
731+
}
732+
case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
733+
GlobalTy GlobalObj{Name};
734+
if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
735+
Device, *Program->Image, GlobalObj))
736+
return Res;
737+
738+
*Symbol = Program->Symbols
739+
.emplace_back(
740+
std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
741+
.get();
742+
743+
return Error::success();
744+
}
745+
default:
746+
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
747+
"getSymbol kind enum '%i' is invalid", Kind);
748+
}
749+
}
750+
729751
} // namespace offload
730752
} // namespace llvm

offload/unittests/OffloadAPI/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ add_offload_unittest("init"
1919
target_compile_definitions("init.unittests" PRIVATE DISABLE_WRAPPER)
2020

2121
add_offload_unittest("kernel"
22-
kernel/olGetKernel.cpp
2322
kernel/olLaunchKernel.cpp)
2423

2524
add_offload_unittest("memory"
@@ -41,3 +40,6 @@ add_offload_unittest("queue"
4140
queue/olDestroyQueue.cpp
4241
queue/olGetQueueInfo.cpp
4342
queue/olGetQueueInfoSize.cpp)
43+
44+
add_offload_unittest("symbol"
45+
symbol/olGetSymbol.cpp)

offload/unittests/OffloadAPI/common/Fixtures.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ struct OffloadProgramTest : OffloadDeviceTest {
113113
struct OffloadKernelTest : OffloadProgramTest {
114114
void SetUp() override {
115115
RETURN_ON_FATAL_FAILURE(OffloadProgramTest::SetUp());
116-
ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel));
116+
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel));
117117
}
118118

119119
void TearDown() override {

offload/unittests/OffloadAPI/device_code/global.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <gpuintrin.h>
22
#include <stdint.h>
33

4+
[[gnu::visibility("default")]]
45
uint32_t global[64];
56

67
__gpu_kernel void write() {

offload/unittests/OffloadAPI/kernel/olGetKernel.cpp

Lines changed: 0 additions & 38 deletions
This file was deleted.

offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ struct LaunchKernelTestBase : OffloadQueueTest {
4040
struct LaunchSingleKernelTestBase : LaunchKernelTestBase {
4141
void SetUpKernel(const char *kernel) {
4242
RETURN_ON_FATAL_FAILURE(SetUpProgram(kernel));
43-
ASSERT_SUCCESS(olGetKernel(Program, kernel, &Kernel));
43+
ASSERT_SUCCESS(
44+
olGetSymbol(Program, kernel, OL_SYMBOL_KIND_KERNEL, &Kernel));
4445
}
4546

4647
ol_symbol_handle_t Kernel = nullptr;
@@ -67,7 +68,8 @@ struct LaunchMultipleKernelTestBase : LaunchKernelTestBase {
6768
Kernels.resize(kernels.size());
6869
size_t I = 0;
6970
for (auto K : kernels)
70-
ASSERT_SUCCESS(olGetKernel(Program, K, &Kernels[I++]));
71+
ASSERT_SUCCESS(
72+
olGetSymbol(Program, K, OL_SYMBOL_KIND_KERNEL, &Kernels[I++]));
7173
}
7274

7375
std::vector<ol_symbol_handle_t> Kernels;
@@ -223,6 +225,15 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
223225
ASSERT_SUCCESS(olMemFree(Mem));
224226
}
225227

228+
TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
229+
ol_symbol_handle_t Global = nullptr;
230+
ASSERT_SUCCESS(
231+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
232+
ASSERT_ERROR(
233+
OL_ERRC_SYMBOL_KIND,
234+
olLaunchKernel(Queue, Device, Global, nullptr, 0, &LaunchArgs, nullptr));
235+
}
236+
226237
TEST_P(olLaunchKernelGlobalCtorTest, Success) {
227238
void *Mem;
228239
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//===------- Offload API tests - olGetSymbol ---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "../common/Fixtures.hpp"
10+
#include <OffloadAPI.h>
11+
#include <gtest/gtest.h>
12+
13+
using olGetSymbolKernelTest = OffloadProgramTest;
14+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetSymbolKernelTest);
15+
16+
struct olGetSymbolGlobalTest : OffloadQueueTest {
17+
void SetUp() override {
18+
RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp());
19+
ASSERT_TRUE(TestEnvironment::loadDeviceBinary("global", Device, DeviceBin));
20+
ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
21+
ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(),
22+
DeviceBin->getBufferSize(), &Program));
23+
}
24+
25+
void TearDown() override {
26+
if (Program) {
27+
olDestroyProgram(Program);
28+
}
29+
RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown());
30+
}
31+
32+
std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
33+
ol_program_handle_t Program = nullptr;
34+
ol_kernel_launch_size_args_t LaunchArgs{};
35+
};
36+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetSymbolGlobalTest);
37+
38+
TEST_P(olGetSymbolKernelTest, Success) {
39+
ol_symbol_handle_t Kernel = nullptr;
40+
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel));
41+
ASSERT_NE(Kernel, nullptr);
42+
}
43+
44+
TEST_P(olGetSymbolKernelTest, InvalidNullProgram) {
45+
ol_symbol_handle_t Kernel = nullptr;
46+
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
47+
olGetSymbol(nullptr, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel));
48+
}
49+
50+
TEST_P(olGetSymbolKernelTest, InvalidNullKernelPointer) {
51+
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
52+
olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, nullptr));
53+
}
54+
55+
TEST_P(olGetSymbolKernelTest, InvalidKernelName) {
56+
ol_symbol_handle_t Kernel = nullptr;
57+
ASSERT_ERROR(OL_ERRC_NOT_FOUND, olGetSymbol(Program, "invalid_kernel_name",
58+
OL_SYMBOL_KIND_KERNEL, &Kernel));
59+
}
60+
61+
TEST_P(olGetSymbolKernelTest, InvalidKind) {
62+
ol_symbol_handle_t Kernel = nullptr;
63+
ASSERT_ERROR(
64+
OL_ERRC_INVALID_ENUMERATION,
65+
olGetSymbol(Program, "foo", OL_SYMBOL_KIND_FORCE_UINT32, &Kernel));
66+
}
67+
68+
TEST_P(olGetSymbolGlobalTest, Success) {
69+
ol_symbol_handle_t Global = nullptr;
70+
ASSERT_SUCCESS(
71+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
72+
ASSERT_NE(Global, nullptr);
73+
}
74+
75+
TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) {
76+
ol_symbol_handle_t Global = nullptr;
77+
ASSERT_ERROR(
78+
OL_ERRC_INVALID_NULL_HANDLE,
79+
olGetSymbol(nullptr, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
80+
}
81+
82+
TEST_P(olGetSymbolGlobalTest, InvalidNullGlobalPointer) {
83+
ASSERT_ERROR(
84+
OL_ERRC_INVALID_NULL_POINTER,
85+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, nullptr));
86+
}
87+
88+
TEST_P(olGetSymbolGlobalTest, InvalidGlobalName) {
89+
ol_symbol_handle_t Global = nullptr;
90+
ASSERT_ERROR(OL_ERRC_NOT_FOUND,
91+
olGetSymbol(Program, "invalid_global",
92+
OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
93+
}

0 commit comments

Comments
 (0)