From 69d3b2c13b7b1f586264e882bfb3a3dee055336f Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Tue, 8 Jul 2025 10:40:21 -0400 Subject: [PATCH 1/3] Fix CUDA stream synchronization in custom kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit fixes GitHub issue pytorch/pytorch#157363 where custom CUDA kernels were not properly synchronized with PyTorch's CUDA stream when used with torch.compile in reduce-overhead mode. Changes: - Add #include for getCurrentCUDAStream() - Use at::cuda::getCurrentCUDAStream() to get PyTorch's current CUDA stream - Launch all kernels with the correct stream parameter The issue occurred because custom kernels launched on the default CUDA stream while PyTorch operations (like nn.Linear) run on PyTorch's managed stream. This created race conditions where custom kernels would execute before PyTorch operations completed, resulting in incorrect output values. With this fix, all custom kernels are properly synchronized with PyTorch's CUDA stream, ensuring correct execution order and preventing race conditions when used with torch.compile. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- extension_cpp/csrc/cuda/muladd.cu | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/csrc/cuda/muladd.cu index b7ae5f7..769513b 100644 --- a/extension_cpp/csrc/cuda/muladd.cu +++ b/extension_cpp/csrc/cuda/muladd.cu @@ -4,6 +4,7 @@ #include #include +#include namespace extension_cpp { @@ -26,7 +27,8 @@ at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) { float* result_ptr = result.data_ptr(); int numel = a_contig.numel(); - muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr); return result; } @@ -48,7 +50,8 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) { const float* b_ptr = b_contig.data_ptr(); float* result_ptr = result.data_ptr(); int numel = a_contig.numel(); - mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr); return result; } @@ -73,7 +76,8 @@ void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) { const float* b_ptr = b_contig.data_ptr(); float* result_ptr = out.data_ptr(); int numel = a_contig.numel(); - add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr); } From 1c93ab28ea54132d5abe110bd4283439a3ec360b Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Tue, 8 Jul 2025 10:54:17 -0400 Subject: [PATCH 2/3] Add tests for torch.compile stream synchronization fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added comprehensive tests to verify the fix for GitHub issue pytorch/pytorch#157363: 1. test_compile_with_linear_layer: - Tests custom CUDA kernels with nn.Linear + torch.compile - Verifies correct behavior with various input sizes (1000, 5000, 10000) - Uses reduce-overhead mode to reproduce the original issue conditions 2. test_compile_custom_only: - Tests custom operations without linear layers - Ensures custom operations work correctly with torch.compile These tests ensure that custom CUDA kernels properly synchronize with PyTorch's CUDA stream when used with torch.compile, preventing race conditions that previously caused incorrect outputs. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- test/test_extension.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/test/test_extension.py b/test/test_extension.py index 3b7e39c..00d5462 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -6,6 +6,7 @@ from torch import Tensor from typing import Tuple import torch.nn.functional as F +import torch.nn as nn def reference_muladd(a, b, c): @@ -119,5 +120,54 @@ def test_opcheck_cuda(self): self._opcheck("cuda") +class TestTorchCompileStreamSync(TestCase): + """Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile""" + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_compile_with_linear_layer(self): + """Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)""" + + class Model(nn.Module): + def __init__(self, size): + super().__init__() + self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32) + + def forward(self, x): + return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0) + + # Test sizes that previously failed + for size in [1000, 5000, 10000]: + with self.subTest(size=size): + torch.manual_seed(42) + model = Model(size) + x = torch.randn((1, size), device="cuda", dtype=torch.float32) + + with torch.no_grad(): + expected = model(x) + compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + actual = compiled_model(x) + + torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_compile_custom_only(self): + """Test custom operations alone with torch.compile""" + + def model(x): + return extension_cpp.ops.mymuladd(x, x, 1.0) + + for size in [1000, 5000, 10000]: + with self.subTest(size=size): + torch.manual_seed(42) + x = torch.randn((size,), device="cuda", dtype=torch.float32) + + with torch.no_grad(): + expected = model(x) + compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + actual = compiled_model(x) + + torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": unittest.main() From 5513989aa28d7c4130d78cf1b7a1894c2f10a711 Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Tue, 8 Jul 2025 11:52:47 -0400 Subject: [PATCH 3/3] Use self.assertEqual instead of torch.testing.assert_close in tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace manual tolerance specification with self.assertEqual which automatically handles appropriate tolerances for tensor comparisons. This makes the tests more concise and follows PyTorch testing conventions. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- test/test_extension.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_extension.py b/test/test_extension.py index 00d5462..f17d7da 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -147,7 +147,7 @@ def forward(self, x): compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True) actual = compiled_model(x) - torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) + self.assertEqual(actual, expected) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_compile_custom_only(self): @@ -166,7 +166,7 @@ def model(x): compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True) actual = compiled_model(x) - torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) + self.assertEqual(actual, expected) if __name__ == "__main__":