diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/csrc/cuda/muladd.cu
index 570e5d1..b7ae5f7 100644
--- a/extension_cpp/csrc/cuda/muladd.cu
+++ b/extension_cpp/csrc/cuda/muladd.cu
@@ -54,7 +54,7 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
 
 __global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx < numel) result[idx] = a[idx] * b[idx];
+  if (idx < numel) result[idx] = a[idx] + b[idx];
 }
 
 void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
diff --git a/test/test_extension.py b/test/test_extension.py
index 618f00b..3b7e39c 100644
--- a/test/test_extension.py
+++ b/test/test_extension.py
@@ -97,6 +97,13 @@ def _test_correctness(self, device):
         expected = torch.add(*args[:2])
         torch.testing.assert_close(result, expected)
 
+    def test_correctness_cpu(self):
+        self._test_correctness("cpu")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_correctness_cuda(self):
+        self._test_correctness("cuda")
+
     def _opcheck(self, device):
         # Use opcheck to check for incorrect usage of operator registration APIs
         samples = self.sample_inputs(device, requires_grad=True)
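
The regression the new tests guard against (add_kernel silently multiplying instead of adding) can also be checked interactively. A minimal sketch, assuming the extension is built and exposes an out-variant op as `extension_cpp.ops.myadd_out(a, b, out)`; the Python-side name is inferred from `myadd_out_cuda` above and should be treated as an assumption:

```python
import torch
import extension_cpp  # assumed package name for the built extension

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(1024, device=device)
b = torch.randn(1024, device=device)
out = torch.empty_like(a)

# Hypothetical binding inferred from myadd_out_cuda; writes a + b into out.
extension_cpp.ops.myadd_out(a, b, out)

# With the bug, out would equal a * b; after the fix this assertion passes.
torch.testing.assert_close(out, a + b)
```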