Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit cb8282e

Browse files
tc::ATenCompilationUnit<tc::CudaTcExecutor> -> ATenCudaCompilationUnit
1 parent 767ade4 commit cb8282e

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

tensor_comprehensions/pybinds/pybind_engine.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ namespace python {
3636

3737
namespace py = pybind11;
3838

39+
using ATenCudaCompilationUnit = tc::ATenCompilationUnit<tc::CudaTcExecutor>;
40+
3941
PYBIND11_MODULE(tc, m) {
4042
m.def(
4143
"global_debug_init", // exposing the debugging flags to people
@@ -60,17 +62,13 @@ PYBIND11_MODULE(tc, m) {
6062
std::cerr << "\n PyTorch installation is missing, binary will be useless \n"
6163
<< e.what() << std::endl;
6264
}
63-
py::class_<tc::ATenCompilationUnit<tc::CudaTcExecutor>>(
64-
m, "ATenCompilationUnit")
65+
py::class_<ATenCudaCompilationUnit>(m, "ATenCompilationUnit")
6566
.def(py::init<>())
66-
.def(
67-
"define",
68-
&tc::ATenCompilationUnit<tc::CudaTcExecutor>::define,
69-
"Define the TC language")
67+
.def("define", &ATenCudaCompilationUnit::define, "Define the TC language")
7068
.def(
7169
"compile",
7270
[dlpack](
73-
tc::ATenCompilationUnit<tc::CudaTcExecutor>& instance,
71+
ATenCudaCompilationUnit& instance,
7472
const std::string& name,
7573
py::list& inputs,
7674
const tc::MappingOptions& options) {
@@ -80,7 +78,7 @@ PYBIND11_MODULE(tc, m) {
8078
.def(
8179
"run",
8280
[dlpack](
83-
tc::ATenCompilationUnit<tc::CudaTcExecutor>& instance,
81+
ATenCudaCompilationUnit& instance,
8482
const std::string& name,
8583
py::list& inputs,
8684
py::list& outputs,
@@ -95,7 +93,7 @@ PYBIND11_MODULE(tc, m) {
9593
.def(
9694
"uncheckedRun",
9795
[dlpack](
98-
tc::ATenCompilationUnit<tc::CudaTcExecutor>& instance,
96+
ATenCudaCompilationUnit& instance,
9997
py::list& inputs,
10098
py::list& outputs,
10199
size_t handle) {
@@ -107,7 +105,7 @@ PYBIND11_MODULE(tc, m) {
107105
.def(
108106
"inject_cuda",
109107
[dlpack](
110-
tc::ATenCompilationUnit<tc::CudaTcExecutor>& instance,
108+
ATenCudaCompilationUnit& instance,
111109
const std::string& name,
112110
const std::string& injectedKernelName,
113111
const std::string& cudaSource,

0 commit comments

Comments
 (0)