diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/csrc/cuda/muladd.cu
index b7ae5f7..769513b 100644
--- a/extension_cpp/csrc/cuda/muladd.cu
+++ b/extension_cpp/csrc/cuda/muladd.cu
@@ -4,6 +4,7 @@
 
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <ATen/cuda/CUDAContext.h>
 
 namespace extension_cpp {
 
@@ -26,7 +27,8 @@ at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
   float* result_ptr = result.data_ptr<float>();
   int numel = a_contig.numel();
-  muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
 
 
@@ -48,7 +50,8 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
   const float* b_ptr = b_contig.data_ptr<float>();
   float* result_ptr = result.data_ptr<float>();
   int numel = a_contig.numel();
-  mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
 
@@ -73,7 +76,8 @@ void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
   const float* b_ptr = b_contig.data_ptr<float>();
   float* result_ptr = out.data_ptr<float>();
   int numel = a_contig.numel();
-  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }
 
 
diff --git a/test/test_extension.py b/test/test_extension.py
index 3b7e39c..f17d7da 100644
--- a/test/test_extension.py
+++ b/test/test_extension.py
@@ -6,6 +6,7 @@
 from torch import Tensor
 from typing import Tuple
 import torch.nn.functional as F
+import torch.nn as nn
 
 
 def reference_muladd(a, b, c):
@@ -119,5 +120,54 @@ def test_opcheck_cuda(self):
         self._opcheck("cuda")
 
 
+class TestTorchCompileStreamSync(TestCase):
+    """Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile"""
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_compile_with_linear_layer(self):
+        """Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)"""
+
+        class Model(nn.Module):
+            def __init__(self, size):
+                super().__init__()
+                self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32)
+
+            def forward(self, x):
+                return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0)
+
+        # Test sizes that previously failed
+        for size in [1000, 5000, 10000]:
+            with self.subTest(size=size):
+                torch.manual_seed(42)
+                model = Model(size)
+                x = torch.randn((1, size), device="cuda", dtype=torch.float32)
+
+                with torch.no_grad():
+                    expected = model(x)
+                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+                    actual = compiled_model(x)
+
+                self.assertEqual(actual, expected)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_compile_custom_only(self):
+        """Test custom operations alone with torch.compile"""
+
+        def model(x):
+            return extension_cpp.ops.mymuladd(x, x, 1.0)
+
+        for size in [1000, 5000, 10000]:
+            with self.subTest(size=size):
+                torch.manual_seed(42)
+                x = torch.randn((size,), device="cuda", dtype=torch.float32)
+
+                with torch.no_grad():
+                    expected = model(x)
+                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+                    actual = compiled_model(x)
+
+                self.assertEqual(actual, expected)
+
+
 if __name__ == "__main__":
     unittest.main()
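
Note (illustration, not part of the patch): torch.compile with mode="reduce-overhead" records the compiled region with CUDA graphs on a non-default stream, and more generally may run work on whatever stream PyTorch has made current. A kernel launched with a bare <<<grid, block>>> goes to the legacy default stream, so it is not ordered with that surrounding work and can read inputs before they are ready. Passing at::cuda::getCurrentCUDAStream() as the fourth launch parameter, as the patch does, keeps the kernel on the stream PyTorch expects. A minimal sketch of the pattern follows; scale_kernel / scale_cuda are hypothetical names for illustration, not part of extension_cpp.

    // Sketch: launch a custom elementwise kernel on PyTorch's current CUDA stream.
    #include <ATen/cuda/CUDAContext.h>
    #include <torch/all.h>

    __global__ void scale_kernel(int numel, const float* in, float s, float* out) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < numel) out[idx] = in[idx] * s;
    }

    at::Tensor scale_cuda(const at::Tensor& x, double s) {
      TORCH_CHECK(x.dtype() == at::kFloat);
      at::Tensor x_contig = x.contiguous();
      at::Tensor out = torch::empty_like(x_contig);
      int numel = x_contig.numel();
      // Ask PyTorch which stream the current thread is using (the one torch.compile /
      // CUDA graph capture issues work on) and pass it as the 4th launch argument
      // instead of implicitly launching on the default stream.
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      scale_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(
          numel, x_contig.data_ptr<float>(), static_cast<float>(s), out.data_ptr<float>());
      return out;
    }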