Skip to content

Commit 725e7a4

Browse files
yanghailong-git and yanghl authored
Fix incorrect operation in add_kernel (used * instead of +); add tests for myadd_out (#117)
Co-authored-by: yanghl <yanghl@zetyun.com>
1 parent 11e8b30 commit 725e7a4

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

extension_cpp/csrc/cuda/muladd.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
54 54
55 55   __global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
56 56     int idx = blockIdx.x * blockDim.x + threadIdx.x;
57    -   if (idx < numel) result[idx] = a[idx] * b[idx];
   57 +   if (idx < numel) result[idx] = a[idx] + b[idx];
58 58   }
59 59
60 60   void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {

test/test_extension.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,13 @@ def _test_correctness(self, device):
9797
expected = torch.add(*args[:2])
9898
torch.testing.assert_close(result, expected)
9999

100+
def test_correctness_cpu(self):
101+
self._test_correctness("cpu")
102+
103+
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
104+
def test_correctness_cuda(self):
105+
self._test_correctness("cuda")
106+
100107
def _opcheck(self, device):
101108
# Use opcheck to check for incorrect usage of operator registration APIs
102109
samples = self.sample_inputs(device, requires_grad=True)

0 commit comments

Comments
 (0)