@@ -36,6 +36,8 @@ namespace python {
36
36
37
37
namespace py = pybind11;
38
38
39
+ using ATenCudaCompilationUnit = tc::ATenCompilationUnit<tc::CudaTcExecutor>;
40
+
39
41
PYBIND11_MODULE (tc, m) {
40
42
m.def (
41
43
" global_debug_init" , // exposing the debugging flags to people
@@ -60,17 +62,13 @@ PYBIND11_MODULE(tc, m) {
60
62
std::cerr << " \n PyTorch installation is missing, binary will be useless \n "
61
63
<< e.what () << std::endl;
62
64
}
63
- py::class_<tc::ATenCompilationUnit<tc::CudaTcExecutor>>(
64
- m, " ATenCompilationUnit" )
65
+ py::class_<ATenCudaCompilationUnit>(m, " ATenCompilationUnit" )
65
66
.def (py::init<>())
66
- .def (
67
- " define" ,
68
- &tc::ATenCompilationUnit<tc::CudaTcExecutor>::define,
69
- " Define the TC language" )
67
+ .def (" define" , &ATenCudaCompilationUnit::define, " Define the TC language" )
70
68
.def (
71
69
" compile" ,
72
70
[dlpack](
73
- tc::ATenCompilationUnit<tc::CudaTcExecutor> & instance,
71
+ ATenCudaCompilationUnit & instance,
74
72
const std::string& name,
75
73
py::list& inputs,
76
74
const tc::MappingOptions& options) {
@@ -80,7 +78,7 @@ PYBIND11_MODULE(tc, m) {
80
78
.def (
81
79
" run" ,
82
80
[dlpack](
83
- tc::ATenCompilationUnit<tc::CudaTcExecutor> & instance,
81
+ ATenCudaCompilationUnit & instance,
84
82
const std::string& name,
85
83
py::list& inputs,
86
84
py::list& outputs,
@@ -95,7 +93,7 @@ PYBIND11_MODULE(tc, m) {
95
93
.def (
96
94
" uncheckedRun" ,
97
95
[dlpack](
98
- tc::ATenCompilationUnit<tc::CudaTcExecutor> & instance,
96
+ ATenCudaCompilationUnit & instance,
99
97
py::list& inputs,
100
98
py::list& outputs,
101
99
size_t handle) {
@@ -107,7 +105,7 @@ PYBIND11_MODULE(tc, m) {
107
105
.def (
108
106
" inject_cuda" ,
109
107
[dlpack](
110
- tc::ATenCompilationUnit<tc::CudaTcExecutor> & instance,
108
+ ATenCudaCompilationUnit & instance,
111
109
const std::string& name,
112
110
const std::string& injectedKernelName,
113
111
const std::string& cudaSource,
0 commit comments