Skip to content

Commit 11e8b30

Browse files
authored
fix pybind11 errors in cuda extension example (#113)
1 parent 38ec45e commit 11e8b30

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

extension_cpp/csrc/cuda/muladd.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
#include <torch/extension.h>
1+
#include <ATen/Operators.h>
2+
#include <torch/all.h>
3+
#include <torch/library.h>
24

35
#include <cuda.h>
46
#include <cuda_runtime.h>
@@ -18,7 +20,7 @@ at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
1820
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
1921
at::Tensor a_contig = a.contiguous();
2022
at::Tensor b_contig = b.contiguous();
21-
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
23+
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
2224
const float* a_ptr = a_contig.data_ptr<float>();
2325
const float* b_ptr = b_contig.data_ptr<float>();
2426
float* result_ptr = result.data_ptr<float>();
@@ -41,7 +43,7 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
4143
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
4244
at::Tensor a_contig = a.contiguous();
4345
at::Tensor b_contig = b.contiguous();
44-
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
46+
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
4547
const float* a_ptr = a_contig.data_ptr<float>();
4648
const float* b_ptr = b_contig.data_ptr<float>();
4749
float* result_ptr = result.data_ptr<float>();

0 commit comments

Comments
 (0)