8
8
9
9
#include " DeviceCompilation.h"
10
10
#include " ESIMD.h"
11
+ #include " JITBinaryInfo.h"
12
+ #include " translation/Translation.h"
11
13
14
+ #include < Driver/ToolChains/AMDGPU.h>
15
+ #include < Driver/ToolChains/Cuda.h>
16
+ #include < Driver/ToolChains/LazyDetector.h>
12
17
#include < clang/Basic/DiagnosticDriver.h>
13
18
#include < clang/Basic/Version.h>
14
19
#include < clang/CodeGen/CodeGenAction.h>
15
20
#include < clang/Driver/Compilation.h>
21
+ #include < clang/Driver/Driver.h>
16
22
#include < clang/Driver/Options.h>
17
23
#include < clang/Frontend/ChainedDiagnosticConsumer.h>
18
24
#include < clang/Frontend/CompilerInstance.h>
@@ -178,7 +184,8 @@ class RTCToolActionBase : public ToolAction {
178
184
assert (!hasExecuted () && " Action should only be invoked on a single file" );
179
185
180
186
// Create a compiler instance to handle the actual work.
181
- CompilerInstance Compiler (std::move (Invocation), std::move (PCHContainerOps));
187
+ CompilerInstance Compiler (std::move (Invocation),
188
+ std::move (PCHContainerOps));
182
189
Compiler.setFileManager (Files);
183
190
// Suppress summary with number of warnings and errors being printed to
184
191
// stdout.
@@ -312,7 +319,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
312
319
} // anonymous namespace
313
320
314
321
static void adjustArgs (const InputArgList &UserArgList,
315
- const std::string &DPCPPRoot,
322
+ const std::string &DPCPPRoot, BinaryFormat Format,
316
323
SmallVectorImpl<std::string> &CommandLine) {
317
324
DerivedArgList DAL{UserArgList};
318
325
const auto &OptTable = getDriverOptTable ();
@@ -325,6 +332,23 @@ static void adjustArgs(const InputArgList &UserArgList,
325
332
// unused argument warning.
326
333
DAL.AddFlagArg (nullptr , OptTable.getOption (OPT_Qunused_arguments));
327
334
335
+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
336
+ auto [CPU, Features] =
337
+ Translator::getTargetCPUAndFeatureAttrs (nullptr , " " , Format);
338
+ (void )Features;
339
+ if (Format == BinaryFormat::AMDGCN) {
340
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_fsycl_targets_EQ),
341
+ " amdgcn-amd-amdhsa" );
342
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_Xsycl_backend_EQ),
343
+ " amdgcn-amd-amdhsa" );
344
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_offload_arch_EQ), CPU);
345
+ } else {
346
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_fsycl_targets_EQ),
347
+ " nvptx64-nvidia-cuda" );
348
+ DAL.AddFlagArg (nullptr , OptTable.getOption (OPT_Xsycl_backend));
349
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_cuda_gpu_arch_EQ), CPU);
350
+ }
351
+ }
328
352
ArgStringList ASL;
329
353
for_each (DAL, [&DAL, &ASL](Arg *A) { A->render (DAL, ASL); });
330
354
for_each (UserArgList,
@@ -361,10 +385,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
361
385
});
362
386
}
363
387
364
- Expected<std::string>
365
- jit_compiler::calculateHash (InMemoryFile SourceFile,
366
- View<InMemoryFile> IncludeFiles,
367
- const InputArgList &UserArgList) {
388
+ Expected<std::string> jit_compiler::calculateHash (
389
+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
390
+ const InputArgList &UserArgList, BinaryFormat Format) {
368
391
TimeTraceScope TTS{" calculateHash" };
369
392
370
393
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -373,7 +396,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
373
396
}
374
397
375
398
SmallVector<std::string> CommandLine;
376
- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
399
+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
377
400
378
401
FixedCompilationDatabase DB{" ." , CommandLine};
379
402
ClangTool Tool{DB, {SourceFile.Path }};
@@ -399,11 +422,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
399
422
return createStringError (" Calculating source hash failed" );
400
423
}
401
424
402
- Expected<ModuleUPtr>
403
- jit_compiler::compileDeviceCode (InMemoryFile SourceFile,
404
- View<InMemoryFile> IncludeFiles,
405
- const InputArgList &UserArgList,
406
- std::string &BuildLog, LLVMContext &Context) {
425
+ Expected<ModuleUPtr> jit_compiler::compileDeviceCode (
426
+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
427
+ const InputArgList &UserArgList, std::string &BuildLog,
428
+ LLVMContext &Context, BinaryFormat Format) {
407
429
TimeTraceScope TTS{" compileDeviceCode" };
408
430
409
431
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -412,7 +434,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
412
434
}
413
435
414
436
SmallVector<std::string> CommandLine;
415
- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
437
+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
416
438
417
439
FixedCompilationDatabase DB{" ." , CommandLine};
418
440
ClangTool Tool{DB, {SourceFile.Path }};
@@ -430,12 +452,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
430
452
return createStringError (BuildLog);
431
453
}
432
454
433
- // This function is a simplified copy of the device library selection process in
434
- // `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435
- // (no AoT, no third-party GPUs , no native CPU). Keep in sync!
455
+ // This function is a simplified copy of the device library selection process
456
+ // in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V, or
457
+ // GPU targets ( no AoT , no native CPU). Keep in sync!
436
458
static bool getDeviceLibraries (const ArgList &Args,
437
459
SmallVectorImpl<std::string> &LibraryList,
438
- DiagnosticsEngine &Diags) {
460
+ DiagnosticsEngine &Diags, BinaryFormat Format) {
461
+ // For CUDA/HIP we only need devicelib, early exit here.
462
+ if (Format == BinaryFormat::PTX) {
463
+ LibraryList.push_back (
464
+ Args.MakeArgString (" devicelib-nvptx64-nvidia-cuda.bc" ));
465
+ return false ;
466
+ } else if (Format == BinaryFormat::AMDGCN) {
467
+ LibraryList.push_back (Args.MakeArgString (" devicelib-amdgcn-amd-amdhsa.bc" ));
468
+ return false ;
469
+ }
470
+
439
471
struct DeviceLibOptInfo {
440
472
StringRef DeviceLibName;
441
473
StringRef DeviceLibOption;
@@ -540,7 +572,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
540
572
541
573
Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
542
574
const InputArgList &UserArgList,
543
- std::string &BuildLog) {
575
+ std::string &BuildLog,
576
+ BinaryFormat Format) {
544
577
TimeTraceScope TTS{" linkDeviceLibraries" };
545
578
546
579
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -555,11 +588,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
555
588
/* ShouldOwnClient=*/ false );
556
589
557
590
SmallVector<std::string> LibNames;
558
- bool FoundUnknownLib = getDeviceLibraries (UserArgList, LibNames, Diags);
591
+ const bool FoundUnknownLib =
592
+ getDeviceLibraries (UserArgList, LibNames, Diags, Format);
559
593
if (FoundUnknownLib) {
560
594
return createStringError (" Could not determine list of device libraries: %s" ,
561
595
BuildLog.c_str ());
562
596
}
597
+ const bool IsCudaHIP =
598
+ Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
599
+ if (IsCudaHIP) {
600
+ // Based on the OS and the format decide on the version of libspirv.
601
+ // NOTE: this will be problematic if cross-compiling between OSes.
602
+ std::string Libclc{" clc/" };
603
+ Libclc.append (
604
+ #ifdef _WIN32
605
+ " remangled-l32-signed_char.libspirv-"
606
+ #else
607
+ " remangled-l64-signed_char.libspirv-"
608
+ #endif
609
+ );
610
+ Libclc.append (Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda.bc"
611
+ : " amdgcn-amd-amdhsa.bc" );
612
+ LibNames.push_back (Libclc);
613
+ }
563
614
564
615
LLVMContext &Context = Module.getContext ();
565
616
for (const std::string &LibName : LibNames) {
@@ -577,6 +628,58 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
577
628
}
578
629
}
579
630
631
+ // For GPU targets we need to link against vendor provided libdevice.
632
+ if (IsCudaHIP) {
633
+ Triple T{Module.getTargetTriple ()};
634
+ Driver D{(Twine (DPCPPRoot) + " /bin/clang++" ).str (), T.getTriple (), Diags};
635
+ auto [CPU, Features] =
636
+ Translator::getTargetCPUAndFeatureAttrs (&Module, " " , Format);
637
+ (void )Features;
638
+ // Helper lambda to link modules.
639
+ auto LinkInLib = [&](const StringRef LibDevice) -> Error {
640
+ ModuleUPtr LibDeviceModule;
641
+ if (auto Error = loadBitcodeLibrary (LibDevice, Context)
642
+ .moveInto (LibDeviceModule)) {
643
+ return Error;
644
+ }
645
+ if (Linker::linkModules (Module, std::move (LibDeviceModule),
646
+ Linker::LinkOnlyNeeded)) {
647
+ return createStringError (" Unable to link libdevice: %s" ,
648
+ BuildLog.c_str ());
649
+ }
650
+ return Error::success ();
651
+ };
652
+ SmallVector<std::string, 12 > LibDeviceFiles;
653
+ if (Format == BinaryFormat::PTX) {
654
+ // For NVPTX we can get away with CudaInstallationDetector.
655
+ LazyDetector<CudaInstallationDetector> CudaInstallation{D, T,
656
+ UserArgList};
657
+ auto LibDevice = CudaInstallation->getLibDeviceFile (CPU);
658
+ if (LibDevice.empty ()) {
659
+ return createStringError (" Unable to find Cuda libdevice" );
660
+ }
661
+ LibDeviceFiles.push_back (LibDevice);
662
+ } else {
663
+ // AMDGPU requires entire toolchain in order to provide all common bitcode
664
+ // libraries.
665
+ clang::driver::toolchains::ROCMToolChain TC (D, T, UserArgList);
666
+ auto CommonDeviceLibs = TC.getCommonDeviceLibNames (
667
+ UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false );
668
+ if (CommonDeviceLibs.empty ()) {
669
+ return createStringError (" Unable to find ROCm common device libraries" );
670
+ }
671
+ for (auto &Lib : CommonDeviceLibs) {
672
+ LibDeviceFiles.push_back (Lib.Path );
673
+ }
674
+ }
675
+ for (auto &LibDeviceFile : LibDeviceFiles) {
676
+ // llvm::Error converts to false on success.
677
+ if (auto Error = LinkInLib (LibDeviceFile)) {
678
+ return Error;
679
+ }
680
+ }
681
+ }
682
+
580
683
return Error::success ();
581
684
}
582
685
0 commit comments