Skip to content

Commit 92ae661

Browse files
joppermsommerlukas
andauthored
[SYCL][RTC] Add minimum amount of post-link functionality to extract symbol table and properties (#16109)
This PR implements symbol table and property set extraction from the runtime-compiled module, based on the skeleton used by `sycl-post-link`, assuming that no splitting of the module is required and/or requested. We add the necessary plumbing to pass this information to the runtime, which then reuses existing `sycl-jit` infrastructure to create the necessary data structures for device binary images. We don't use the properties *yet* during the creating of the kernel bundles, but the necessary information to feed the images into the program manager is there. --------- Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com> Co-authored-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent e127a2e commit 92ae661

File tree

16 files changed

+304
-41
lines changed

16 files changed

+304
-41
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <cstdint>
1818
#include <cstring>
1919
#include <functional>
20+
#include <string_view>
2021
#include <type_traits>
2122

2223
namespace jit_compiler {
@@ -350,11 +351,60 @@ struct SYCLKernelInfo {
350351
: Name{KernelName}, Args{NumArgs}, Attributes{}, NDR{}, BinaryInfo{} {}
351352
};
352353

354+
// RTC-related datastructures
355+
// TODO: Consider moving into separate header.
356+
353357
struct InMemoryFile {
354358
const char *Path;
355359
const char *Contents;
356360
};
357361

362+
using RTCBundleBinaryInfo = SYCLKernelBinaryInfo;
363+
using FrozenSymbolTable = DynArray<sycl::detail::string>;
364+
365+
// Note: `FrozenPropertyValue` and `FrozenPropertySet` constructors take
366+
// `std::string_view` arguments instead of `const char *` because they will be
367+
// created from `llvm::SmallString`s, which don't contain the trailing '\0'
368+
// byte. Hence obtaining a C-string would cause an additional copy.
369+
370+
struct FrozenPropertyValue {
371+
sycl::detail::string Name;
372+
bool IsUIntValue;
373+
uint32_t UIntValue;
374+
DynArray<uint8_t> Bytes;
375+
376+
FrozenPropertyValue() = default;
377+
FrozenPropertyValue(FrozenPropertyValue &&) = default;
378+
FrozenPropertyValue &operator=(FrozenPropertyValue &&) = default;
379+
380+
FrozenPropertyValue(std::string_view Name, uint32_t Value)
381+
: Name{Name}, IsUIntValue{true}, UIntValue{Value}, Bytes{0} {}
382+
FrozenPropertyValue(std::string_view Name, const uint8_t *Ptr, size_t Size)
383+
: Name{Name}, IsUIntValue{false}, Bytes{Size} {
384+
std::memcpy(Bytes.begin(), Ptr, Size);
385+
}
386+
};
387+
388+
struct FrozenPropertySet {
389+
sycl::detail::string Name;
390+
DynArray<FrozenPropertyValue> Values;
391+
392+
FrozenPropertySet() = default;
393+
FrozenPropertySet(FrozenPropertySet &&) = default;
394+
FrozenPropertySet &operator=(FrozenPropertySet &&) = default;
395+
396+
FrozenPropertySet(std::string_view Name, size_t Size)
397+
: Name{Name}, Values{Size} {}
398+
};
399+
400+
using FrozenPropertyRegistry = DynArray<FrozenPropertySet>;
401+
402+
struct RTCBundleInfo {
403+
RTCBundleBinaryInfo BinaryInfo;
404+
FrozenSymbolTable SymbolTable;
405+
FrozenPropertyRegistry Properties;
406+
};
407+
358408
} // namespace jit_compiler
359409

360410
#endif // SYCL_FUSION_COMMON_KERNEL_H

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_llvm_library(sycl-jit
3131
Target
3232
TargetParser
3333
MC
34+
SYCLLowerIR
3435
${LLVM_TARGETS_TO_BUILD}
3536

3637
LINK_LIBS

sycl-jit/jit-compiler/include/KernelFusion.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ class JITResult {
5656
sycl::detail::string ErrorMessage;
5757
};
5858

59+
class RTCResult {
60+
public:
61+
explicit RTCResult(const char *ErrorMessage)
62+
: Failed{true}, BundleInfo{}, ErrorMessage{ErrorMessage} {}
63+
64+
explicit RTCResult(RTCBundleInfo &&BundleInfo)
65+
: Failed{false}, BundleInfo{std::move(BundleInfo)}, ErrorMessage{} {}
66+
67+
bool failed() const { return Failed; }
68+
69+
const char *getErrorMessage() const {
70+
assert(failed() && "No error message present");
71+
return ErrorMessage.c_str();
72+
}
73+
74+
const RTCBundleInfo &getBundleInfo() const {
75+
assert(!failed() && "No bundle info");
76+
return BundleInfo;
77+
}
78+
79+
private:
80+
bool Failed;
81+
RTCBundleInfo BundleInfo;
82+
sycl::detail::string ErrorMessage;
83+
};
84+
5985
extern "C" {
6086

6187
#ifdef __clang__
@@ -77,7 +103,7 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
77103
const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo,
78104
View<unsigned char> SpecConstBlob);
79105

80-
KF_EXPORT_SYMBOL JITResult compileSYCL(InMemoryFile SourceFile,
106+
KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
81107
View<InMemoryFile> IncludeFiles,
82108
View<const char *> UserArgs);
83109

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ using namespace jit_compiler;
2525
using FusedFunction = helper::FusionHelper::FusedFunction;
2626
using FusedFunctionList = std::vector<FusedFunction>;
2727

28-
static JITResult errorToFusionResult(llvm::Error &&Err,
29-
const std::string &Msg) {
28+
template <typename ResultType>
29+
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
3030
std::stringstream ErrMsg;
3131
ErrMsg << Msg << "\nDetailed information:\n";
3232
llvm::handleAllErrors(std::move(Err),
@@ -35,7 +35,7 @@ static JITResult errorToFusionResult(llvm::Error &&Err,
3535
// compiled without exception support.
3636
ErrMsg << "\t" << StrErr.getMessage() << "\n";
3737
});
38-
return JITResult{ErrMsg.str().c_str()};
38+
return ResultType{ErrMsg.str().c_str()};
3939
}
4040

4141
static std::vector<jit_compiler::NDRange>
@@ -95,7 +95,7 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
9595
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
9696
ModuleInfo.kernels());
9797
if (auto Error = ModOrError.takeError()) {
98-
return errorToFusionResult(std::move(Error), "Failed to load kernels");
98+
return errorTo<JITResult>(std::move(Error), "Failed to load kernels");
9999
}
100100
std::unique_ptr<llvm::Module> NewMod = std::move(*ModOrError);
101101
if (!fusion::FusionPipeline::runMaterializerPasses(
@@ -107,8 +107,8 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
107107
SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor(KernelName);
108108
if (auto Error = translation::KernelTranslator::translateKernel(
109109
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat)) {
110-
return errorToFusionResult(std::move(Error),
111-
"Translation to output format failed");
110+
return errorTo<JITResult>(std::move(Error),
111+
"Translation to output format failed");
112112
}
113113

114114
return JITResult{MaterializerKernelInfo};
@@ -133,7 +133,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
133133
llvm::Expected<jit_compiler::FusedNDRange> FusedNDR =
134134
jit_compiler::FusedNDRange::get(NDRanges);
135135
if (llvm::Error Err = FusedNDR.takeError()) {
136-
return errorToFusionResult(std::move(Err), "Illegal ND-range combination");
136+
return errorTo<JITResult>(std::move(Err), "Illegal ND-range combination");
137137
}
138138

139139
if (!isTargetFormatSupported(TargetFormat)) {
@@ -180,7 +180,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
180180
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
181181
ModuleInfo.kernels());
182182
if (auto Error = ModOrError.takeError()) {
183-
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
183+
return errorTo<JITResult>(std::move(Error), "SPIR-V translation failed");
184184
}
185185
std::unique_ptr<llvm::Module> LLVMMod = std::move(*ModOrError);
186186

@@ -197,8 +197,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
197197
llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
198198
helper::FusionHelper::addFusedKernel(LLVMMod.get(), FusedKernelList);
199199
if (auto Error = NewModOrError.takeError()) {
200-
return errorToFusionResult(std::move(Error),
201-
"Insertion of fused kernel stub failed");
200+
return errorTo<JITResult>(std::move(Error),
201+
"Insertion of fused kernel stub failed");
202202
}
203203
std::unique_ptr<llvm::Module> NewMod = std::move(*NewModOrError);
204204

@@ -221,8 +221,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
221221

222222
if (auto Error = translation::KernelTranslator::translateKernel(
223223
FusedKernelInfo, *NewMod, JITCtx, TargetFormat)) {
224-
return errorToFusionResult(std::move(Error),
225-
"Translation to output format failed");
224+
return errorTo<JITResult>(std::move(Error),
225+
"Translation to output format failed");
226226
}
227227

228228
FusedKernelInfo.NDR = FusedNDR->getNDR();
@@ -234,37 +234,47 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
234234
return JITResult{FusedKernelInfo};
235235
}
236236

237-
extern "C" KF_EXPORT_SYMBOL JITResult
237+
extern "C" KF_EXPORT_SYMBOL RTCResult
238238
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
239239
View<const char *> UserArgs) {
240240
auto UserArgListOrErr = parseUserArgs(UserArgs);
241241
if (!UserArgListOrErr) {
242-
return errorToFusionResult(UserArgListOrErr.takeError(),
243-
"Parsing of user arguments failed");
242+
return errorTo<RTCResult>(UserArgListOrErr.takeError(),
243+
"Parsing of user arguments failed");
244244
}
245245
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);
246246

247247
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
248248
if (!ModuleOrErr) {
249-
return errorToFusionResult(ModuleOrErr.takeError(),
250-
"Device compilation failed");
249+
return errorTo<RTCResult>(ModuleOrErr.takeError(),
250+
"Device compilation failed");
251251
}
252252

253253
std::unique_ptr<llvm::LLVMContext> Context;
254254
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
255255
Context.reset(&Module->getContext());
256256

257257
if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
258-
return errorToFusionResult(std::move(Error), "Device linking failed");
258+
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
259259
}
260260

261-
SYCLKernelInfo Kernel;
262-
if (auto Error = translation::KernelTranslator::translateKernel(
263-
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV)) {
264-
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
261+
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
262+
if (!BundleInfoOrError) {
263+
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
264+
"Post-link phase failed");
265+
}
266+
auto BundleInfo = std::move(*BundleInfoOrError);
267+
268+
auto BinaryInfoOrError =
269+
translation::KernelTranslator::translateBundleToSPIRV(
270+
*Module, JITContext::getInstance());
271+
if (!BinaryInfoOrError) {
272+
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
273+
"SPIR-V translation failed");
265274
}
275+
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);
266276

267-
return JITResult{Kernel};
277+
return RTCResult{std::move(BundleInfo)};
268278
}
269279

270280
extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {

0 commit comments

Comments
 (0)