@@ -25,8 +25,8 @@ using namespace jit_compiler;
25
25
using FusedFunction = helper::FusionHelper::FusedFunction;
26
26
using FusedFunctionList = std::vector<FusedFunction>;
27
27
28
- static JITResult errorToFusionResult (llvm::Error &&Err,
29
- const std::string &Msg) {
28
+ template < typename ResultType>
29
+ static ResultType errorTo (llvm::Error &&Err, const std::string &Msg) {
30
30
std::stringstream ErrMsg;
31
31
ErrMsg << Msg << " \n Detailed information:\n " ;
32
32
llvm::handleAllErrors (std::move (Err),
@@ -35,7 +35,7 @@ static JITResult errorToFusionResult(llvm::Error &&Err,
35
35
// compiled without exception support.
36
36
ErrMsg << " \t " << StrErr.getMessage () << " \n " ;
37
37
});
38
- return JITResult {ErrMsg.str ().c_str ()};
38
+ return ResultType {ErrMsg.str ().c_str ()};
39
39
}
40
40
41
41
static std::vector<jit_compiler::NDRange>
@@ -95,7 +95,7 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
95
95
translation::KernelTranslator::loadKernels (*JITCtx.getLLVMContext (),
96
96
ModuleInfo.kernels ());
97
97
if (auto Error = ModOrError.takeError ()) {
98
- return errorToFusionResult (std::move (Error), " Failed to load kernels" );
98
+ return errorTo<JITResult> (std::move (Error), " Failed to load kernels" );
99
99
}
100
100
std::unique_ptr<llvm::Module> NewMod = std::move (*ModOrError);
101
101
if (!fusion::FusionPipeline::runMaterializerPasses (
@@ -107,8 +107,8 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
107
107
SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor (KernelName);
108
108
if (auto Error = translation::KernelTranslator::translateKernel (
109
109
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat)) {
110
- return errorToFusionResult (std::move (Error),
111
- " Translation to output format failed" );
110
+ return errorTo<JITResult> (std::move (Error),
111
+ " Translation to output format failed" );
112
112
}
113
113
114
114
return JITResult{MaterializerKernelInfo};
@@ -133,7 +133,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
133
133
llvm::Expected<jit_compiler::FusedNDRange> FusedNDR =
134
134
jit_compiler::FusedNDRange::get (NDRanges);
135
135
if (llvm::Error Err = FusedNDR.takeError ()) {
136
- return errorToFusionResult (std::move (Err), " Illegal ND-range combination" );
136
+ return errorTo<JITResult> (std::move (Err), " Illegal ND-range combination" );
137
137
}
138
138
139
139
if (!isTargetFormatSupported (TargetFormat)) {
@@ -180,7 +180,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
180
180
translation::KernelTranslator::loadKernels (*JITCtx.getLLVMContext (),
181
181
ModuleInfo.kernels ());
182
182
if (auto Error = ModOrError.takeError ()) {
183
- return errorToFusionResult (std::move (Error), " SPIR-V translation failed" );
183
+ return errorTo<JITResult> (std::move (Error), " SPIR-V translation failed" );
184
184
}
185
185
std::unique_ptr<llvm::Module> LLVMMod = std::move (*ModOrError);
186
186
@@ -197,8 +197,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
197
197
llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
198
198
helper::FusionHelper::addFusedKernel (LLVMMod.get (), FusedKernelList);
199
199
if (auto Error = NewModOrError.takeError ()) {
200
- return errorToFusionResult (std::move (Error),
201
- " Insertion of fused kernel stub failed" );
200
+ return errorTo<JITResult> (std::move (Error),
201
+ " Insertion of fused kernel stub failed" );
202
202
}
203
203
std::unique_ptr<llvm::Module> NewMod = std::move (*NewModOrError);
204
204
@@ -221,8 +221,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
221
221
222
222
if (auto Error = translation::KernelTranslator::translateKernel (
223
223
FusedKernelInfo, *NewMod, JITCtx, TargetFormat)) {
224
- return errorToFusionResult (std::move (Error),
225
- " Translation to output format failed" );
224
+ return errorTo<JITResult> (std::move (Error),
225
+ " Translation to output format failed" );
226
226
}
227
227
228
228
FusedKernelInfo.NDR = FusedNDR->getNDR ();
@@ -234,37 +234,47 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
234
234
return JITResult{FusedKernelInfo};
235
235
}
236
236
237
- extern " C" KF_EXPORT_SYMBOL JITResult
237
+ extern " C" KF_EXPORT_SYMBOL RTCResult
238
238
compileSYCL (InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
239
239
View<const char *> UserArgs) {
240
240
auto UserArgListOrErr = parseUserArgs (UserArgs);
241
241
if (!UserArgListOrErr) {
242
- return errorToFusionResult (UserArgListOrErr.takeError (),
243
- " Parsing of user arguments failed" );
242
+ return errorTo<RTCResult> (UserArgListOrErr.takeError (),
243
+ " Parsing of user arguments failed" );
244
244
}
245
245
llvm::opt::InputArgList UserArgList = std::move (*UserArgListOrErr);
246
246
247
247
auto ModuleOrErr = compileDeviceCode (SourceFile, IncludeFiles, UserArgList);
248
248
if (!ModuleOrErr) {
249
- return errorToFusionResult (ModuleOrErr.takeError (),
250
- " Device compilation failed" );
249
+ return errorTo<RTCResult> (ModuleOrErr.takeError (),
250
+ " Device compilation failed" );
251
251
}
252
252
253
253
std::unique_ptr<llvm::LLVMContext> Context;
254
254
std::unique_ptr<llvm::Module> Module = std::move (*ModuleOrErr);
255
255
Context.reset (&Module->getContext ());
256
256
257
257
if (auto Error = linkDeviceLibraries (*Module, UserArgList)) {
258
- return errorToFusionResult (std::move (Error), " Device linking failed" );
258
+ return errorTo<RTCResult> (std::move (Error), " Device linking failed" );
259
259
}
260
260
261
- SYCLKernelInfo Kernel;
262
- if (auto Error = translation::KernelTranslator::translateKernel (
263
- Kernel, *Module, JITContext::getInstance (), BinaryFormat::SPIRV)) {
264
- return errorToFusionResult (std::move (Error), " SPIR-V translation failed" );
261
+ auto BundleInfoOrError = performPostLink (*Module, UserArgList);
262
+ if (!BundleInfoOrError) {
263
+ return errorTo<RTCResult>(BundleInfoOrError.takeError (),
264
+ " Post-link phase failed" );
265
+ }
266
+ auto BundleInfo = std::move (*BundleInfoOrError);
267
+
268
+ auto BinaryInfoOrError =
269
+ translation::KernelTranslator::translateBundleToSPIRV (
270
+ *Module, JITContext::getInstance ());
271
+ if (!BinaryInfoOrError) {
272
+ return errorTo<RTCResult>(BinaryInfoOrError.takeError (),
273
+ " SPIR-V translation failed" );
265
274
}
275
+ BundleInfo.BinaryInfo = std::move (*BinaryInfoOrError);
266
276
267
- return JITResult{Kernel };
277
+ return RTCResult{ std::move (BundleInfo) };
268
278
}
269
279
270
280
extern " C" KF_EXPORT_SYMBOL void resetJITConfiguration () {
0 commit comments