Skip to content

Commit 6d97d98

Browse files
authored
[SYCL] RTC support for AMD and Nvidia GPU targets (#18918)
This patch extends RTC support to AMD and Nvidia GPU targets. Additionally, it reinstates the __SYCL_PROGRAM_METADATA_TAG_NEED_FINALIZATION tag and splits the sycl.cpp RTC file to exclude IMF from the body of the main test.
1 parent 205aa71 commit 6d97d98

33 files changed

+410
-148
lines changed

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ target_include_directories(sycl-jit
6060
${LLVM_MAIN_INCLUDE_DIR}
6161
${LLVM_SPIRV_INCLUDE_DIRS}
6262
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/include
63+
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/lib
6364
${CMAKE_BINARY_DIR}/tools/clang/include
6465
)
6566
target_include_directories(sycl-jit

sycl-jit/jit-compiler/include/RTC.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,11 @@ class RTCResult {
176176

177177
/// Calculates a BLAKE3 hash of the pre-processed source string described by
178178
/// \p SourceFile (considering any additional \p IncludeFiles) and the
179-
/// concatenation of the \p UserArgs.
179+
/// concatenation of the \p UserArgs, for a given \p Format.
180180
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
181181
View<InMemoryFile> IncludeFiles,
182-
View<const char *> UserArgs);
182+
View<const char *> UserArgs,
183+
BinaryFormat Format);
183184

184185
/// Compiles, links against device libraries, and finalizes the device code in
185186
/// the source string described by \p SourceFile, considering any additional \p
@@ -191,10 +192,14 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
191192
///
192193
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
193194
/// the frontend invocation is wrapped in bitcode format in the result object.
195+
///
196+
/// \p Format describes the desired binary format of the compilation, which
197+
/// corresponds to the backend that is being targeted.
194198
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
195199
View<InMemoryFile> IncludeFiles,
196200
View<const char *> UserArgs,
197-
View<char> CachedIR, bool SaveIR);
201+
View<char> CachedIR, bool SaveIR,
202+
BinaryFormat Format);
198203

199204
/// Requests that the JIT binary referenced by \p Address is deleted from the
200205
/// `JITContext`.

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 122 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88

99
#include "DeviceCompilation.h"
1010
#include "ESIMD.h"
11+
#include "JITBinaryInfo.h"
12+
#include "translation/Translation.h"
1113

14+
#include <Driver/ToolChains/AMDGPU.h>
15+
#include <Driver/ToolChains/Cuda.h>
16+
#include <Driver/ToolChains/LazyDetector.h>
1217
#include <clang/Basic/DiagnosticDriver.h>
1318
#include <clang/Basic/Version.h>
1419
#include <clang/CodeGen/CodeGenAction.h>
1520
#include <clang/Driver/Compilation.h>
21+
#include <clang/Driver/Driver.h>
1622
#include <clang/Driver/Options.h>
1723
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
1824
#include <clang/Frontend/CompilerInstance.h>
@@ -178,7 +184,8 @@ class RTCToolActionBase : public ToolAction {
178184
assert(!hasExecuted() && "Action should only be invoked on a single file");
179185

180186
// Create a compiler instance to handle the actual work.
181-
CompilerInstance Compiler(std::move(Invocation), std::move(PCHContainerOps));
187+
CompilerInstance Compiler(std::move(Invocation),
188+
std::move(PCHContainerOps));
182189
Compiler.setFileManager(Files);
183190
// Suppress summary with number of warnings and errors being printed to
184191
// stdout.
@@ -312,7 +319,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
312319
} // anonymous namespace
313320

314321
static void adjustArgs(const InputArgList &UserArgList,
315-
const std::string &DPCPPRoot,
322+
const std::string &DPCPPRoot, BinaryFormat Format,
316323
SmallVectorImpl<std::string> &CommandLine) {
317324
DerivedArgList DAL{UserArgList};
318325
const auto &OptTable = getDriverOptTable();
@@ -325,6 +332,23 @@ static void adjustArgs(const InputArgList &UserArgList,
325332
// unused argument warning.
326333
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));
327334

335+
if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
336+
auto [CPU, Features] =
337+
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
338+
(void)Features;
339+
if (Format == BinaryFormat::AMDGCN) {
340+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
341+
"amdgcn-amd-amdhsa");
342+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ),
343+
"amdgcn-amd-amdhsa");
344+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
345+
} else {
346+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
347+
"nvptx64-nvidia-cuda");
348+
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Xsycl_backend));
349+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_cuda_gpu_arch_EQ), CPU);
350+
}
351+
}
328352
ArgStringList ASL;
329353
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
330354
for_each(UserArgList,
@@ -361,10 +385,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
361385
});
362386
}
363387

364-
Expected<std::string>
365-
jit_compiler::calculateHash(InMemoryFile SourceFile,
366-
View<InMemoryFile> IncludeFiles,
367-
const InputArgList &UserArgList) {
388+
Expected<std::string> jit_compiler::calculateHash(
389+
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
390+
const InputArgList &UserArgList, BinaryFormat Format) {
368391
TimeTraceScope TTS{"calculateHash"};
369392

370393
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -373,7 +396,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
373396
}
374397

375398
SmallVector<std::string> CommandLine;
376-
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
399+
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
377400

378401
FixedCompilationDatabase DB{".", CommandLine};
379402
ClangTool Tool{DB, {SourceFile.Path}};
@@ -399,11 +422,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
399422
return createStringError("Calculating source hash failed");
400423
}
401424

402-
Expected<ModuleUPtr>
403-
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
404-
View<InMemoryFile> IncludeFiles,
405-
const InputArgList &UserArgList,
406-
std::string &BuildLog, LLVMContext &Context) {
425+
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
426+
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
427+
const InputArgList &UserArgList, std::string &BuildLog,
428+
LLVMContext &Context, BinaryFormat Format) {
407429
TimeTraceScope TTS{"compileDeviceCode"};
408430

409431
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -412,7 +434,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
412434
}
413435

414436
SmallVector<std::string> CommandLine;
415-
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
437+
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
416438

417439
FixedCompilationDatabase DB{".", CommandLine};
418440
ClangTool Tool{DB, {SourceFile.Path}};
@@ -430,12 +452,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
430452
return createStringError(BuildLog);
431453
}
432454

433-
// This function is a simplified copy of the device library selection process in
434-
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435-
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
455+
// This function is a simplified copy of the device library selection process
456+
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
457+
// GPU target (no AoT, no native CPU). Keep in sync!
436458
static bool getDeviceLibraries(const ArgList &Args,
437459
SmallVectorImpl<std::string> &LibraryList,
438-
DiagnosticsEngine &Diags) {
460+
DiagnosticsEngine &Diags, BinaryFormat Format) {
461+
// For CUDA/HIP we only need devicelib, early exit here.
462+
if (Format == BinaryFormat::PTX) {
463+
LibraryList.push_back(
464+
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
465+
return false;
466+
} else if (Format == BinaryFormat::AMDGCN) {
467+
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
468+
return false;
469+
}
470+
439471
struct DeviceLibOptInfo {
440472
StringRef DeviceLibName;
441473
StringRef DeviceLibOption;
@@ -540,7 +572,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
540572

541573
Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
542574
const InputArgList &UserArgList,
543-
std::string &BuildLog) {
575+
std::string &BuildLog,
576+
BinaryFormat Format) {
544577
TimeTraceScope TTS{"linkDeviceLibraries"};
545578

546579
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -555,11 +588,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
555588
/* ShouldOwnClient=*/false);
556589

557590
SmallVector<std::string> LibNames;
558-
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
591+
const bool FoundUnknownLib =
592+
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
559593
if (FoundUnknownLib) {
560594
return createStringError("Could not determine list of device libraries: %s",
561595
BuildLog.c_str());
562596
}
597+
const bool IsCudaHIP =
598+
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
599+
if (IsCudaHIP) {
600+
// Based on the OS and the format decide on the version of libspirv.
601+
// NOTE: this will be problematic if cross-compiling between OSes.
602+
std::string Libclc{"clc/"};
603+
Libclc.append(
604+
#ifdef _WIN32
605+
"remangled-l32-signed_char.libspirv-"
606+
#else
607+
"remangled-l64-signed_char.libspirv-"
608+
#endif
609+
);
610+
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
611+
: "amdgcn-amd-amdhsa.bc");
612+
LibNames.push_back(Libclc);
613+
}
563614

564615
LLVMContext &Context = Module.getContext();
565616
for (const std::string &LibName : LibNames) {
@@ -577,6 +628,58 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
577628
}
578629
}
579630

631+
// For GPU targets we need to link against vendor provided libdevice.
632+
if (IsCudaHIP) {
633+
Triple T{Module.getTargetTriple()};
634+
Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags};
635+
auto [CPU, Features] =
636+
Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format);
637+
(void)Features;
638+
// Helper lambda to link modules.
639+
auto LinkInLib = [&](const StringRef LibDevice) -> Error {
640+
ModuleUPtr LibDeviceModule;
641+
if (auto Error = loadBitcodeLibrary(LibDevice, Context)
642+
.moveInto(LibDeviceModule)) {
643+
return Error;
644+
}
645+
if (Linker::linkModules(Module, std::move(LibDeviceModule),
646+
Linker::LinkOnlyNeeded)) {
647+
return createStringError("Unable to link libdevice: %s",
648+
BuildLog.c_str());
649+
}
650+
return Error::success();
651+
};
652+
SmallVector<std::string, 12> LibDeviceFiles;
653+
if (Format == BinaryFormat::PTX) {
654+
// For NVPTX we can get away with CudaInstallationDetector.
655+
LazyDetector<CudaInstallationDetector> CudaInstallation{D, T,
656+
UserArgList};
657+
auto LibDevice = CudaInstallation->getLibDeviceFile(CPU);
658+
if (LibDevice.empty()) {
659+
return createStringError("Unable to find Cuda libdevice");
660+
}
661+
LibDeviceFiles.push_back(LibDevice);
662+
} else {
663+
// AMDGPU requires the entire toolchain in order to provide all common bitcode
664+
// libraries.
665+
clang::driver::toolchains::ROCMToolChain TC(D, T, UserArgList);
666+
auto CommonDeviceLibs = TC.getCommonDeviceLibNames(
667+
UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false);
668+
if (CommonDeviceLibs.empty()) {
669+
return createStringError("Unable to find ROCm common device libraries");
670+
}
671+
for (auto &Lib : CommonDeviceLibs) {
672+
LibDeviceFiles.push_back(Lib.Path);
673+
}
674+
}
675+
for (auto &LibDeviceFile : LibDeviceFiles) {
676+
// llvm::Error converts to false on success.
677+
if (auto Error = LinkInLib(LibDeviceFile)) {
678+
return Error;
679+
}
680+
}
681+
}
682+
580683
return Error::success();
581684
}
582685

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include "JITBinaryInfo.h"
1112
#include "RTC.h"
1213

1314
#include <llvm/ADT/SmallVector.h>
@@ -24,16 +25,17 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;
2425

2526
llvm::Expected<std::string>
2627
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
27-
const llvm::opt::InputArgList &UserArgList);
28+
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);
2829

2930
llvm::Expected<ModuleUPtr>
3031
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
3132
const llvm::opt::InputArgList &UserArgList,
32-
std::string &BuildLog, llvm::LLVMContext &Context);
33+
std::string &BuildLog, llvm::LLVMContext &Context,
34+
BinaryFormat Format);
3335

3436
llvm::Error linkDeviceLibraries(llvm::Module &Module,
3537
const llvm::opt::InputArgList &UserArgList,
36-
std::string &BuildLog);
38+
std::string &BuildLog, BinaryFormat Format);
3739

3840
using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
3941
llvm::Expected<PostLinkResult>

sycl-jit/jit-compiler/lib/rtc/RTC.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "RTC.h"
10+
#include "JITBinaryInfo.h"
1011
#include "helper/ErrorHelper.h"
1112
#include "rtc/DeviceCompilation.h"
1213
#include "translation/SPIRVLLVMTranslation.h"
@@ -26,7 +27,8 @@ using namespace jit_compiler;
2627

2728
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
2829
View<InMemoryFile> IncludeFiles,
29-
View<const char *> UserArgs) {
30+
View<const char *> UserArgs,
31+
BinaryFormat Format) {
3032
llvm::opt::InputArgList UserArgList;
3133
if (auto Error = parseUserArgs(UserArgs).moveInto(UserArgList)) {
3234
return errorTo<RTCHashResult>(std::move(Error),
@@ -36,8 +38,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
3638

3739
auto Start = std::chrono::high_resolution_clock::now();
3840
std::string Hash;
39-
if (auto Error =
40-
calculateHash(SourceFile, IncludeFiles, UserArgList).moveInto(Hash)) {
41+
if (auto Error = calculateHash(SourceFile, IncludeFiles, UserArgList, Format)
42+
.moveInto(Hash)) {
4143
return errorTo<RTCHashResult>(std::move(Error), "Hashing failed",
4244
/*IsHash=*/false);
4345
}
@@ -55,7 +57,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
5557
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
5658
View<InMemoryFile> IncludeFiles,
5759
View<const char *> UserArgs,
58-
View<char> CachedIR, bool SaveIR) {
60+
View<char> CachedIR, bool SaveIR,
61+
BinaryFormat Format) {
5962
llvm::LLVMContext Context;
6063
std::string BuildLog;
6164
configureDiagnostics(Context, BuildLog);
@@ -104,7 +107,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
104107
bool FromSource = !Module;
105108
if (FromSource) {
106109
if (auto Error = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
107-
BuildLog, Context)
110+
BuildLog, Context, Format)
108111
.moveInto(Module)) {
109112
return errorTo<RTCResult>(std::move(Error), "Device compilation failed");
110113
}
@@ -118,7 +121,8 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
118121
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
119122
}
120123

121-
if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
124+
if (auto Error =
125+
linkDeviceLibraries(*Module, UserArgList, BuildLog, Format)) {
122126
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
123127
}
124128

@@ -131,9 +135,9 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
131135

132136
for (auto [DevImgInfo, Module] :
133137
llvm::zip_equal(BundleInfo.DevImgInfos, Modules)) {
134-
if (auto Error = Translator::translate(*Module, JITContext::getInstance(),
135-
BinaryFormat::SPIRV)
136-
.moveInto(DevImgInfo.BinaryInfo)) {
138+
if (auto Error =
139+
Translator::translate(*Module, JITContext::getInstance(), Format)
140+
.moveInto(DevImgInfo.BinaryInfo)) {
137141
return errorTo<RTCResult>(std::move(Error), "SPIR-V translation failed");
138142
}
139143
}

0 commit comments

Comments
 (0)