This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 052a3ef

nicolasvasilache authored and ftynse committed
Generate PTX with LLVM trunk
This commit uses trunk clang, llvm-link, opt and llc to emit PTX.
1 parent c83c36d commit 052a3ef
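
For orientation, the llvm path added in tc/core/cuda/cuda_rtc.cc (llvmCompile, below) drives a four-stage clang++ → llvm-link → opt → llc pipeline. The following is a minimal Python sketch of the equivalent commands; LLVM_BIN, CUDA_HOME, the sm_60 architecture, the my_kernel entry point and the kernel file names are placeholder assumptions, not the exact values the code computes at runtime.

    import subprocess

    # Placeholder paths/arch; llvmCompile derives these from the build config
    # (TC_LLVM_BIN_DIR, TC_CUDA_TOOLKIT_ROOT_DIR) and the current device.
    LLVM_BIN = "/usr/local/llvm/bin"
    CUDA_HOME = "/usr/local/cuda"
    ARCH = "sm_60"

    def run(cmd):
        # shell=True so the libdevice glob below is expanded by the shell
        subprocess.run(cmd, shell=True, check=True)

    # 1. CUDA C++ -> device-only LLVM IR
    run(f"{LLVM_BIN}/clang++ -x cuda kernel.cu --cuda-device-only "
        f"--cuda-gpu-arch={ARCH} --cuda-path={CUDA_HOME} "
        f"-std=c++11 -O3 -ffast-math -DNVRTC_CUB=1 -nocudalib "
        f"-S -emit-llvm -o kernel-clang.ll")
    # 2. Link against NVIDIA's libdevice bitcode
    run(f"{LLVM_BIN}/llvm-link kernel-clang.ll "
        f"{CUDA_HOME}/nvvm/libdevice/libdevice.*.bc -S -o kernel-link.ll")
    # 3. Internalize everything except the kernel entry point, then optimize
    run(f"{LLVM_BIN}/opt -internalize -internalize-public-api-list=my_kernel "
        f"-nvvm-reflect -O3 kernel-link.ll -S -o kernel-opt.ll")
    # 4. LLVM IR -> PTX
    run(f"{LLVM_BIN}/llc -mcpu={ARCH} kernel-opt.ll -o kernel.s")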

10 files changed, +167 −26 lines changed


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -258,6 +258,8 @@ include(cmake/GetGitRevisionDescription.cmake)
 ################################################################################
 # Variables for tc_config.h.in
 set(TC_DIR ${TC_DIR})
+execute_process(COMMAND ${CLANG_PREFIX}/bin/llvm-config --bindir OUTPUT_VARIABLE LLVM_BIN_DIR OUTPUT_STRIP_TRAILING_WHITESPACE)
+set(TC_LLVM_BIN_DIR ${LLVM_BIN_DIR})
 if (WITH_CUDA)
   # CUDA-specific variables for tc_config.h.in
   set(TC_WITH_CUDA 1)

python/tests/test_tc.py

Lines changed: 17 additions & 1 deletion
@@ -59,7 +59,7 @@ def test_mapping_options(self):
             .outerScheduleFusionStrategy("Preserve3Coincident"))

     #
-    # Simple TC example with explicit 'naive' compilation
+    # Simple TC example with explicit 'naive' compilation with nvrtc (default)
     #
     def test_tc(self):
         A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
@@ -72,6 +72,22 @@ def test_tc(self):
         C = add(A, B)
         tc.assert_almost_equal(C, torch.add(A, B), A, B)

+    #
+    # Simple TC example with explicit 'naive' compilation with llvm
+    #
+    def test_tc_llvm(self):
+        A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
+        tc.cuda_compiler('llvm')
+        add = tc.compile(
+            "def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }",
+            "add",
+            'naive',
+            A, B,
+        )
+        tc.cuda_compiler('nvrtc')
+        C = add(A, B)
+        tc.assert_almost_equal(C, torch.add(A, B), A, B)
+
     #
     # Simple TC example without fallback but with tuning starting from
     # MappingOptions('naive')
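
The new test flips the compiler back to 'nvrtc' by hand because cuda_compiler mutates a global flag. A small hypothetical helper (not part of the tc API) could scope the switch instead, for example:

    import contextlib
    import tensor_comprehensions as tc

    @contextlib.contextmanager
    def cuda_compiler_scope(name, default='nvrtc'):
        # Hypothetical wrapper: there is no getter for the flag, so the
        # previous value must be supplied (or assumed to be the default).
        tc.cuda_compiler(name)
        try:
            yield
        finally:
            tc.cuda_compiler(default)

    # usage:
    # with cuda_compiler_scope('llvm'):
    #     add = tc.compile(tc_str, "add", 'naive', A, B)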

tc/core/cuda/cuda_libraries.h

Lines changed: 3 additions & 2 deletions
@@ -43,8 +43,9 @@ constexpr auto defines = R"C(

 constexpr auto warpSyncFunctions = R"C(
 // Before CUDA 9, syncwarp is a noop since warps are always synchronized.
-#if __CUDACC_VER_MAJOR__ < 9
-__device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
+#if (!defined(__clang__) && __CUDACC_VER_MAJOR__ < 9) || \
+    ( defined(__clang__) && CUDA_VERSION < 9000)
+inline __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
 #endif
 )C";

tc/core/cuda/cuda_rtc.cc

Lines changed: 112 additions & 22 deletions
@@ -13,6 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
 #include <sstream>
 #include <string>
 #include <vector>
@@ -60,30 +64,89 @@ void checkOrCreateContext() {
   }
 }

-std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
-    const std::string& name,
-    const std::string& source) {
-  std::unique_ptr<CudaRTCFunction> res(new CudaRTCFunction());
-  res->specializedName = name;
-  res->cleared_ = false;
-
-  if (FLAGS_debug_tc_mapper) {
-    LOG(INFO) << "NVRTC function source:\n" << source;
-  }
-  // Actually do the compiling.
-  nvrtcProgram prog;
-  TC_NVRTC_CHECK(
-      nvrtcCreateProgram(&prog, source.c_str(), nullptr, 0, nullptr, nullptr));
-
-  // Get the architecture of the current device.
-  int device, minor, major;
+namespace {
+static std::tuple<int, int, int> getCudaArchitecture() {
+  int device, major, minor;
   CUdevice deviceHandle;
   TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDevice(&device));
   TC_CUDA_DRIVERAPI_ENFORCE(cuDeviceGet(&deviceHandle, device));
   TC_CUDA_DRIVERAPI_ENFORCE(cuDeviceGetAttribute(
       &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, deviceHandle));
   TC_CUDA_DRIVERAPI_ENFORCE(cuDeviceGetAttribute(
       &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, deviceHandle));
+  return std::tuple<int, int, int>(device, major, minor);
+}
+
+static std::string llvmCompile(
+    const std::string& name,
+    const std::string& source) {
+  int device, major, minor;
+  std::tie(device, major, minor) = getCudaArchitecture();
+
+  std::string pat("/tmp/cudaXXXXXX");
+  std::vector<char> ifn(pat.begin(), pat.end());
+  TC_CHECK_GE(mkstemp(ifn.data()), 0); // string.c_str is const char*
+  std::string inputFileName(ifn.begin(), ifn.end());
+  // cstdio's std::remove to delete files
+  tc::ScopeGuard sgi([&]() { std::remove(inputFileName.c_str()); });
+  {
+    std::ofstream ostream(inputFileName, std::ios::binary);
+    ostream << source;
+  }
+
+  std::string arch = "sm_" + std::to_string(major) + std::to_string(minor);
+  std::string outputClangFile = inputFileName + "-clang.ll";
+  std::string outputLinkFile = inputFileName + "-link.ll";
+  std::string outputOptFile = inputFileName + "-opt.ll";
+  std::string outputPtxFile = inputFileName + ".s";
+  tc::ScopeGuard sgo([&]() {
+    // cstdio's std::remove to delete files
+    std::remove(outputClangFile.c_str());
+    std::remove(outputLinkFile.c_str());
+    std::remove(outputOptFile.c_str());
+    std::remove(outputPtxFile.c_str());
+  });
+
+  std::string cmdLlvmIr = std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) +
+      "/clang++ -x cuda " + inputFileName + " " + "--cuda-device-only " +
+      "--cuda-gpu-arch=" + arch + " " +
+      "--cuda-path=" + TC_STRINGIFY(TC_CUDA_TOOLKIT_ROOT_DIR) + " " + "-I" +
+      TC_STRINGIFY(TC_CUDA_INCLUDE_DIR) + " " + "-I" +
+      TC_STRINGIFY(TC_CUB_INCLUDE_DIR) + " " + tc::FLAGS_llvm_flags +
+      " -DNVRTC_CUB=1 " + "-nocudalib -S -emit-llvm " + "-o " +
+      outputClangFile;
+  TC_CHECK_EQ(std::system(cmdLlvmIr.c_str()), 0) << cmdLlvmIr;
+
+  std::string cmdLlvmLink = std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) +
+      "/llvm-link " + outputClangFile + " " +
+      TC_STRINGIFY(TC_CUDA_TOOLKIT_ROOT_DIR) +
+      "/nvvm/libdevice/libdevice.*.bc " + "-S -o " + outputLinkFile;
+  TC_CHECK_EQ(std::system(cmdLlvmLink.c_str()), 0) << cmdLlvmLink;
+
+  std::string cmdOpt = std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) + "/opt " +
+      "-internalize -internalize-public-api-list=" + name + " " +
+      "-nvvm-reflect -O3 " + outputLinkFile + " -S -o " + outputOptFile;
+  TC_CHECK_EQ(std::system(cmdOpt.c_str()), 0) << cmdOpt;
+
+  std::string cmdPtx = std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) +
+      "/llc -mcpu=" + arch + " " + outputOptFile + " -o " + outputPtxFile;
+  TC_CHECK_EQ(std::system(cmdPtx.c_str()), 0) << cmdPtx;
+
+  std::ifstream stream(outputPtxFile);
+  return std::string(
+      (std::istreambuf_iterator<char>(stream)),
+      std::istreambuf_iterator<char>());
+}
+
+static std::string nvrtcCompile(
+    const std::string& name,
+    const std::string& source) {
+  int device, major, minor;
+  std::tie(device, major, minor) = getCudaArchitecture();
+
+  nvrtcProgram prog;
+  TC_NVRTC_CHECK(
+      nvrtcCreateProgram(&prog, source.c_str(), nullptr, 0, nullptr, nullptr));

   std::stringstream arch_param;
   arch_param << "--gpu-architecture=compute_" << major << minor;
@@ -125,14 +188,38 @@ std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
   }
   size_t ptx_size;
   TC_NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
-  res->nvrtc_ptx = std::vector<char>(ptx_size);
-  TC_NVRTC_CHECK(nvrtcGetPTX(prog, res->nvrtc_ptx.data()));
+  std::vector<char> res(ptx_size);
+  TC_NVRTC_CHECK(nvrtcGetPTX(prog, res.data()));
   TC_NVRTC_CHECK(nvrtcDestroyProgram(&prog));
+  return std::string(res.begin(), res.end());
+}
+} // namespace
+
+std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
+    const std::string& name,
+    const std::string& source) {
+  std::unique_ptr<CudaRTCFunction> res(new CudaRTCFunction());
+  res->specializedName = name;
+  res->cleared_ = false;
+  if (FLAGS_debug_tc_mapper) {
+    LOG(INFO) << "NVRTC function source:\n" << source;
+  }
+  if (FLAGS_cuda_compiler == "nvrtc") {
+    res->ptx = nvrtcCompile(name, source);
+  } else if (FLAGS_cuda_compiler == "llvm") {
+    res->ptx = llvmCompile(name, source);
+  } else if (FLAGS_cuda_compiler == "nvcc") {
+    CHECK(false) << "NYI";
+    // res->ptx = llvmCompile(name, source);
+  } else {
+    CHECK(false) << "Unknown CUDA compiler: " << FLAGS_cuda_compiler;
+  }
   if (FLAGS_dump_ptx) {
-    LOG(INFO) << "PTX:\n" << std::string(res->nvrtc_ptx.data());
+    LOG(INFO) << "PTX:\n" << res->ptx;
   }
   return res;
 }
+
 namespace {

 template <typename T>
@@ -164,8 +251,11 @@ Duration CudaRTCFunction::Launch(
   // This call to cudaDeviceSynchronize implicitly creates a new context if
   // one is not bound to the current CPU.
   checkOrCreateContext();
-  TC_CUDA_DRIVERAPI_ENFORCE(
-      cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0));
+  auto res = cuModuleLoadData(&module, ptx.c_str());
+  if (res != CUDA_SUCCESS) {
+    LOG(ERROR) << "Invalid PTX: " << ptx;
+  }
+  TC_CUDA_DRIVERAPI_ENFORCE(res);
   perGpuModule_.emplace(dev, module);
   TC_CUDA_DRIVERAPI_ENFORCE(
       cuModuleGetFunction(&function, module, specializedName.c_str()));
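
To inspect the PTX either backend emits from Python, the existing dump_ptx flag can be combined with the new compiler switch; a short usage sketch (the TC definition mirrors the add example from the tests):

    import torch
    import tensor_comprehensions as tc

    tc.dump_ptx(True)         # log the generated PTX
    tc.cuda_compiler('llvm')  # or 'nvrtc' (the default)
    A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
    add = tc.compile(
        "def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }",
        "add",
        'naive',
        A, B,
    )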

tc/core/cuda/cuda_rtc.h

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ class CudaRTCFunction {
   mutable std::unordered_map<size_t, CUmodule> perGpuModule_;
   mutable std::unordered_map<size_t, CUfunction> perGpuKernel_;
   std::string specializedName;
-  std::vector<char> nvrtc_ptx;
+  std::string ptx;
   bool cleared_;
 };

tc/core/flags.cc

Lines changed: 10 additions & 0 deletions
@@ -38,6 +38,16 @@ DEFINE_bool(
 DEFINE_bool(dump_cuda, false, "Print the generated source");
 DEFINE_bool(dump_ptx, false, "Dump the generated PTX");

+// PTX generation
+DEFINE_string(
+    cuda_compiler,
+    "nvrtc",
+    "which compiler to use to emit ptx: nvrtc, llvm, nvcc (default [nvrtc])");
+DEFINE_string(
+    llvm_flags,
+    "-std=c++11 -O3 -ffast-math",
+    "compiler flags to set when llvm is used");
+
 // CPU codegen options
 DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization");
 DEFINE_bool(llvm_dump_after_opt, false, "Print IR after optimization");

tc/core/flags.h

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,10 @@ DECLARE_bool(debug_tuner);
 DECLARE_bool(dump_cuda);
 DECLARE_bool(dump_ptx);

+// ptx generation
+DECLARE_string(cuda_compiler);
+DECLARE_string(llvm_flags);
+
 // llvm codegen
 DECLARE_bool(llvm_dump_before_opt);
 DECLARE_bool(llvm_dump_after_opt);

tc/tc_config.h.in

Lines changed: 1 addition & 0 deletions
@@ -24,4 +24,5 @@
 #define TC_CUDA_TOOLKIT_ROOT_DIR @TC_CUDA_TOOLKIT_ROOT_DIR@
 #define TC_CUDA_INCLUDE_DIR @TC_CUDA_INCLUDE_DIR@
 #define TC_CUB_INCLUDE_DIR @TC_CUB_INCLUDE_DIR@
+#define TC_LLVM_BIN_DIR @TC_LLVM_BIN_DIR@
 // clang-format on

tensor_comprehensions/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,8 @@
 from tensor_comprehensions.tclib import debug_tuner
 from tensor_comprehensions.tclib import dump_cuda
 from tensor_comprehensions.tclib import dump_ptx
+from tensor_comprehensions.tclib import cuda_compiler
+from tensor_comprehensions.tclib import llvm_flags

 from tensor_comprehensions.tclib import CompilationCache
 from tensor_comprehensions.tclib import MappingOptions
@@ -606,6 +608,8 @@ def make_autograd(forward_fun: Callable[[Iterable[torch.Tensor]], Iterable[torch
     'debug_tuner',
     'dump_cuda',
     'dump_ptx',
+    'cuda_compiler',
+    'llvm_flags',
     # Functions exposed by the tclib
     'compile',
     'autotune',

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 13 additions & 0 deletions
@@ -439,6 +439,19 @@ PYBIND11_MODULE(tclib, m) {
       });
   m.def("dump_cuda", [](bool dump_cuda) { tc::FLAGS_dump_cuda = dump_cuda; });
   m.def("dump_ptx", [](bool dump_ptx) { tc::FLAGS_dump_ptx = dump_ptx; });
+  m.def(
+      "cuda_compiler",
+      [](const std::string& cuda_compiler) {
+        tc::FLAGS_cuda_compiler = cuda_compiler;
+      },
+      gflags::DescribeOneFlag(
+          gflags::GetCommandLineFlagInfoOrDie("cuda_compiler"))
+          .c_str());
+  m.def(
+      "llvm_flags",
+      [](const std::string& llvm_flags) { tc::FLAGS_llvm_flags = llvm_flags; },
+      gflags::DescribeOneFlag(gflags::GetCommandLineFlagInfoOrDie("llvm_flags"))
+          .c_str());

   // Access the names of the defs in a TC string
   m.def("parse_defs", [](const std::string& tc) {
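
These bindings simply assign the corresponding tc::FLAGS_* globals, so from Python the equivalent of passing --cuda_compiler=llvm --llvm_flags=... to a gflags-based binary is:

    import tensor_comprehensions as tc

    tc.cuda_compiler('llvm')                     # sets tc::FLAGS_cuda_compiler
    tc.llvm_flags('-std=c++11 -O3 -ffast-math')  # sets tc::FLAGS_llvm_flags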
