Skip to content

Commit 2fdeeef

Browse files
authored
[Offload] Add global variable address/size queries (#147972)
Add two new symbol info types for getting the bounds of a global variable. As well as a number of tests for reading/writing to it.
1 parent 2c0d563 commit 2fdeeef

File tree

6 files changed

+175
-3
lines changed

6 files changed

+175
-3
lines changed

offload/liboffload/API/Symbol.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def : Enum {
3939
let desc = "Supported symbol info.";
4040
let is_typed = 1;
4141
let etors = [
42-
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
42+
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
43+
TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
44+
TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
4345
];
4446
}
4547

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,9 +753,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
753753
void *PropValue, size_t *PropSizeRet) {
754754
InfoWriter Info(PropSize, PropValue, PropSizeRet);
755755

756+
auto CheckKind = [&](ol_symbol_kind_t Required) {
757+
if (Symbol->Kind != Required) {
758+
std::string ErrBuffer;
759+
llvm::raw_string_ostream(ErrBuffer)
760+
<< PropName << ": Expected a symbol of Kind " << Required
761+
<< " but given a symbol of Kind " << Symbol->Kind;
762+
return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
763+
}
764+
return Plugin::success();
765+
};
766+
756767
switch (PropName) {
757768
case OL_SYMBOL_INFO_KIND:
758769
return Info.write<ol_symbol_kind_t>(Symbol->Kind);
770+
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
771+
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
772+
return Err;
773+
return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
774+
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
775+
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
776+
return Err;
777+
return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
759778
default:
760779
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
761780
"olGetSymbolInfo enum '%i' is invalid", PropName);

offload/tools/offload-tblgen/PrintGen.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
7474
if (Type == "char[]") {
7575
OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
7676
} else {
77-
OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
78-
Type);
77+
if (Type == "void *")
78+
OS << formatv(TAB_2 "void * const * const tptr = (void * "
79+
"const * const)ptr;\n");
80+
else
81+
OS << formatv(
82+
TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
7983
// TODO: Handle other cases here
8084
OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
8185
if (Type.ends_with("*")) {

offload/unittests/OffloadAPI/memory/olMemcpy.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,32 @@
1313
using olMemcpyTest = OffloadQueueTest;
1414
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
1515

16+
struct olMemcpyGlobalTest : OffloadGlobalTest {
17+
void SetUp() override {
18+
RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
19+
ASSERT_SUCCESS(
20+
olGetSymbol(Program, "read", OL_SYMBOL_KIND_KERNEL, &ReadKernel));
21+
ASSERT_SUCCESS(
22+
olGetSymbol(Program, "write", OL_SYMBOL_KIND_KERNEL, &WriteKernel));
23+
ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
24+
ASSERT_SUCCESS(olGetSymbolInfo(
25+
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
26+
27+
LaunchArgs.Dimensions = 1;
28+
LaunchArgs.GroupSize = {64, 1, 1};
29+
LaunchArgs.NumGroups = {1, 1, 1};
30+
31+
LaunchArgs.DynSharedMemory = 0;
32+
}
33+
34+
ol_kernel_launch_size_args_t LaunchArgs{};
35+
void *Addr;
36+
ol_symbol_handle_t ReadKernel;
37+
ol_symbol_handle_t WriteKernel;
38+
ol_queue_handle_t Queue;
39+
};
40+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
41+
1642
TEST_P(olMemcpyTest, SuccessHtoD) {
1743
constexpr size_t Size = 1024;
1844
void *Alloc;
@@ -105,3 +131,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
105131
ASSERT_SUCCESS(
106132
olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
107133
}
134+
135+
TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
136+
void *SourceMem;
137+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
138+
64 * sizeof(uint32_t), &SourceMem));
139+
uint32_t *SourceData = (uint32_t *)SourceMem;
140+
for (auto I = 0; I < 64; I++)
141+
SourceData[I] = I;
142+
143+
void *DestMem;
144+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
145+
64 * sizeof(uint32_t), &DestMem));
146+
147+
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
148+
64 * sizeof(uint32_t), nullptr));
149+
ASSERT_SUCCESS(olWaitQueue(Queue));
150+
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
151+
64 * sizeof(uint32_t), nullptr));
152+
ASSERT_SUCCESS(olWaitQueue(Queue));
153+
154+
uint32_t *DestData = (uint32_t *)DestMem;
155+
for (uint32_t I = 0; I < 64; I++)
156+
ASSERT_EQ(DestData[I], I);
157+
158+
ASSERT_SUCCESS(olMemFree(DestMem));
159+
ASSERT_SUCCESS(olMemFree(SourceMem));
160+
}
161+
162+
TEST_P(olMemcpyGlobalTest, SuccessWrite) {
163+
void *SourceMem;
164+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
165+
LaunchArgs.GroupSize.x * sizeof(uint32_t),
166+
&SourceMem));
167+
uint32_t *SourceData = (uint32_t *)SourceMem;
168+
for (auto I = 0; I < 64; I++)
169+
SourceData[I] = I;
170+
171+
void *DestMem;
172+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
173+
LaunchArgs.GroupSize.x * sizeof(uint32_t),
174+
&DestMem));
175+
struct {
176+
void *Mem;
177+
} Args{DestMem};
178+
179+
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
180+
64 * sizeof(uint32_t), nullptr));
181+
ASSERT_SUCCESS(olWaitQueue(Queue));
182+
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
183+
&LaunchArgs, nullptr));
184+
ASSERT_SUCCESS(olWaitQueue(Queue));
185+
186+
uint32_t *DestData = (uint32_t *)DestMem;
187+
for (uint32_t I = 0; I < 64; I++)
188+
ASSERT_EQ(DestData[I], I);
189+
190+
ASSERT_SUCCESS(olMemFree(DestMem));
191+
ASSERT_SUCCESS(olMemFree(SourceMem));
192+
}
193+
194+
TEST_P(olMemcpyGlobalTest, SuccessRead) {
195+
void *DestMem;
196+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
197+
LaunchArgs.GroupSize.x * sizeof(uint32_t),
198+
&DestMem));
199+
200+
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
201+
&LaunchArgs, nullptr));
202+
ASSERT_SUCCESS(olWaitQueue(Queue));
203+
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
204+
64 * sizeof(uint32_t), nullptr));
205+
ASSERT_SUCCESS(olWaitQueue(Queue));
206+
207+
uint32_t *DestData = (uint32_t *)DestMem;
208+
for (uint32_t I = 0; I < 64; I++)
209+
ASSERT_EQ(DestData[I], I * 2);
210+
211+
ASSERT_SUCCESS(olMemFree(DestMem));
212+
}

offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
3030
ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
3131
}
3232

33+
TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
34+
void *RetrievedAddr;
35+
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
36+
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
37+
sizeof(RetrievedAddr), &RetrievedAddr));
38+
}
39+
40+
TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
41+
void *RetrievedAddr = nullptr;
42+
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
43+
sizeof(RetrievedAddr), &RetrievedAddr));
44+
ASSERT_NE(RetrievedAddr, nullptr);
45+
}
46+
47+
TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
48+
size_t RetrievedSize;
49+
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
50+
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
51+
sizeof(RetrievedSize), &RetrievedSize));
52+
}
53+
54+
TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
55+
size_t RetrievedSize = 0;
56+
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
57+
sizeof(RetrievedSize), &RetrievedSize));
58+
ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
59+
}
60+
3361
TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
3462
ol_symbol_kind_t RetrievedKind;
3563
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
2828
ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
2929
}
3030

31+
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
32+
size_t Size = 0;
33+
ASSERT_SUCCESS(olGetSymbolInfoSize(
34+
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
35+
ASSERT_EQ(Size, sizeof(void *));
36+
}
37+
38+
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
39+
size_t Size = 0;
40+
ASSERT_SUCCESS(
41+
olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
42+
ASSERT_EQ(Size, sizeof(size_t));
43+
}
44+
3145
TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
3246
size_t Size = 0;
3347
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

0 commit comments

Comments
 (0)