Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ecd85a1

Browse files
nicolasvasilacheftynse
authored andcommitted
Generate PTX with NVCC
This commit allows using NVCC to emit PTX.
1 parent 052a3ef commit ecd85a1

File tree

6 files changed

+66
-2
lines changed

6 files changed

+66
-2
lines changed

python/tests/test_tc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ def test_tc_llvm(self):
8484
'naive',
8585
A, B,
8686
)
87+
# Reset the cuda compiler back to nvrtc
88+
tc.cuda_compiler('nvrtc')
89+
C = add(A, B)
90+
tc.assert_almost_equal(C, torch.add(A, B), A, B)
91+
92+
#
93+
# Simple TC example with explicit 'naive' compilation with nvcc
94+
#
95+
def test_tc_nvcc(self):
96+
A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
97+
tc.cuda_compiler('nvcc')
98+
add = tc.compile(
99+
"def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }",
100+
"add",
101+
'naive',
102+
A, B,
103+
)
104+
# Reset the cuda compiler back to nvrtc
87105
tc.cuda_compiler('nvrtc')
88106
C = add(A, B)
89107
tc.assert_almost_equal(C, torch.add(A, B), A, B)

tc/core/cuda/cuda_rtc.cc

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,41 @@ static std::string llvmCompile(
138138
std::istreambuf_iterator<char>());
139139
}
140140

141+
static std::string nvccCompile(
142+
const std::string& name,
143+
const std::string& source) {
144+
int device, major, minor;
145+
std::tie(device, major, minor) = getCudaArchitecture();
146+
147+
std::string pat("/tmp/cudaXXXXXX");
148+
std::vector<char> ifn(pat.begin(), pat.end());
149+
TC_CHECK_GE(mkstemp(ifn.data()), 0); // string.c_str is const char*
150+
std::string inputFileName(ifn.begin(), ifn.end());
151+
// cstdio's std::remove to delete files
152+
tc::ScopeGuard sgi([&]() { std::remove(inputFileName.c_str()); });
153+
{
154+
std::ofstream ostream(inputFileName, std::ios::binary);
155+
ostream << source;
156+
}
157+
158+
std::string arch = "sm_" + std::to_string(major) + std::to_string(minor);
159+
std::string outputPtxFile = inputFileName + ".ptx";
160+
// cstdio's std::remove to delete files
161+
tc::ScopeGuard sgo([&]() { std::remove(outputPtxFile.c_str()); });
162+
163+
std::string cmdPtx = std::string(TC_STRINGIFY(TC_CUDA_TOOLKIT_ROOT_DIR)) +
164+
"/bin/nvcc -x cu " + inputFileName + " --gpu-architecture=" + arch + " " +
165+
"--ptx " + "-I" + TC_STRINGIFY(TC_CUDA_INCLUDE_DIR) + " " + "-I" +
166+
TC_STRINGIFY(TC_CUB_INCLUDE_DIR) + " " + tc::FLAGS_nvcc_flags + " -o " +
167+
outputPtxFile;
168+
TC_CHECK_EQ(std::system(cmdPtx.c_str()), 0) << cmdPtx;
169+
170+
std::ifstream stream(outputPtxFile);
171+
return std::string(
172+
(std::istreambuf_iterator<char>(stream)),
173+
std::istreambuf_iterator<char>());
174+
}
175+
141176
static std::string nvrtcCompile(
142177
const std::string& name,
143178
const std::string& source) {
@@ -209,8 +244,7 @@ std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
209244
} else if (FLAGS_cuda_compiler == "llvm") {
210245
res->ptx = llvmCompile(name, source);
211246
} else if (FLAGS_cuda_compiler == "nvcc") {
212-
CHECK(false) << "NYI";
213-
// res->ptx = llvmCompile(name, source);
247+
res->ptx = nvccCompile(name, source);
214248
} else {
215249
CHECK(false) << "Unknown CUDA compiler: " << FLAGS_cuda_compiler;
216250
}

tc/core/flags.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ DEFINE_string(
4747
llvm_flags,
4848
"-std=c++11 -O3 -ffast-math",
4949
"compiler flags to set when llvm is used");
50+
DEFINE_string(
51+
nvcc_flags,
52+
"-std=c++11 -ptx -DNVRTC_CUB=1 --use_fast_math",
53+
"compiler flags to set when nvcc is used");
5054

5155
// CPU codegen options
5256
DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization");

tc/core/flags.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ DECLARE_bool(dump_ptx);
3434
// ptx generation
3535
DECLARE_string(cuda_compiler);
3636
DECLARE_string(llvm_flags);
37+
DECLARE_string(nvcc_flags);
3738

3839
// llvm codegen
3940
DECLARE_bool(llvm_dump_before_opt);

tensor_comprehensions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensor_comprehensions.tclib import dump_ptx
3232
from tensor_comprehensions.tclib import cuda_compiler
3333
from tensor_comprehensions.tclib import llvm_flags
34+
from tensor_comprehensions.tclib import nvcc_flags
3435

3536
from tensor_comprehensions.tclib import CompilationCache
3637
from tensor_comprehensions.tclib import MappingOptions
@@ -610,6 +611,7 @@ def make_autograd(forward_fun: Callable[[Iterable[torch.Tensor]], Iterable[torch
610611
'dump_ptx',
611612
'cuda_compiler',
612613
'llvm_flags',
614+
'nvcc_flags',
613615
# Functions exposed by the tclib
614616
'compile',
615617
'autotune',

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,11 @@ PYBIND11_MODULE(tclib, m) {
452452
[](const std::string& llvm_flags) { tc::FLAGS_llvm_flags = llvm_flags; },
453453
gflags::DescribeOneFlag(gflags::GetCommandLineFlagInfoOrDie("llvm_flags"))
454454
.c_str());
455+
m.def(
456+
"nvcc_flags",
457+
[](const std::string& nvcc_flags) { tc::FLAGS_nvcc_flags = nvcc_flags; },
458+
gflags::DescribeOneFlag(gflags::GetCommandLineFlagInfoOrDie("nvcc_flags"))
459+
.c_str());
455460

456461
// Access the names of the defs in a TC string
457462
m.def("parse_defs", [](const std::string& tc) {

0 commit comments

Comments
 (0)