
Commit 0ec4969

soumith and claude authored
uses current CUDAStream correctly (#118)
* Fix CUDA stream synchronization in custom kernels

This commit fixes GitHub issue pytorch/pytorch#157363, where custom CUDA kernels were not properly synchronized with PyTorch's CUDA stream when used with torch.compile in reduce-overhead mode.

Changes:
- Add #include <ATen/cuda/CUDAContext.h> for getCurrentCUDAStream()
- Use at::cuda::getCurrentCUDAStream() to get PyTorch's current CUDA stream
- Launch all kernels with the correct stream parameter

The issue occurred because custom kernels launched on the default CUDA stream, while PyTorch operations (like nn.Linear) run on PyTorch's managed stream. This created race conditions in which a custom kernel could execute before the PyTorch operations feeding it had completed, resulting in incorrect output values.

With this fix, all custom kernels are properly synchronized with PyTorch's CUDA stream, ensuring correct execution order and preventing race conditions when used with torch.compile.

* Add tests for torch.compile stream synchronization fix

Added comprehensive tests to verify the fix for GitHub issue pytorch/pytorch#157363:

1. test_compile_with_linear_layer:
   - Tests custom CUDA kernels with nn.Linear + torch.compile
   - Verifies correct behavior with various input sizes (1000, 5000, 10000)
   - Uses reduce-overhead mode to reproduce the original issue conditions

2. test_compile_custom_only:
   - Tests custom operations without linear layers
   - Ensures custom operations work correctly with torch.compile

These tests ensure that custom CUDA kernels properly synchronize with PyTorch's CUDA stream when used with torch.compile, preventing the race conditions that previously caused incorrect outputs.

* Use self.assertEqual instead of torch.testing.assert_close in tests

Replace manual tolerance specification with self.assertEqual, which automatically handles appropriate tolerances for tensor comparisons. This makes the tests more concise and follows PyTorch testing conventions.

---------

Co-authored-by: Claude <noreply@anthropic.com>
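The core pattern the commit applies is small enough to show on its own. Below is a minimal sketch, not code from this repo: scale_kernel and scale_cuda are hypothetical names introduced for illustration, and the sketch assumes a contiguous float32 input. The actual diff further down applies the same two lines to muladd_kernel, mul_kernel, and add_kernel.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream()

// Hypothetical kernel for illustration only (not part of extension-cpp).
__global__ void scale_kernel(int numel, const float* in, float* out, float s) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) out[idx] = in[idx] * s;
}

at::Tensor scale_cuda(const at::Tensor& a, double s) {
  at::Tensor a_contig = a.contiguous();      // assumes float32 input
  at::Tensor result = at::empty_like(a_contig);
  int numel = a_contig.numel();

  // The fix: launch on PyTorch's current stream (4th launch parameter)
  // instead of the default stream, so the kernel is ordered after the
  // preceding ATen ops (nn.Linear, etc.) enqueued on that same stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(
      numel, a_contig.data_ptr<float>(), result.data_ptr<float>(),
      static_cast<float>(s));
  return result;
}

Without the fourth launch argument, <<<grid, block>>> enqueues the kernel on CUDA's default stream, which is not ordered with respect to the stream PyTorch uses under torch.compile's reduce-overhead mode; that is the race described above.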
1 parent e9a1c5d · commit 0ec4969

File tree: 2 files changed, +57 −3 lines


extension_cpp/csrc/cuda/muladd.cu
Lines changed: 7 additions & 3 deletions

@@ -4,6 +4,7 @@
 
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <ATen/cuda/CUDAContext.h>
 
 namespace extension_cpp {
 
@@ -26,7 +27,8 @@ at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
   float* result_ptr = result.data_ptr<float>();
 
   int numel = a_contig.numel();
-  muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
 
@@ -48,7 +50,8 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
   const float* b_ptr = b_contig.data_ptr<float>();
   float* result_ptr = result.data_ptr<float>();
   int numel = a_contig.numel();
-  mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
 
@@ -73,7 +76,8 @@ void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
   const float* b_ptr = b_contig.data_ptr<float>();
   float* result_ptr = out.data_ptr<float>();
   int numel = a_contig.numel();
-  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }
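One detail worth noting in the diff above: at::cuda::getCurrentCUDAStream() returns not a raw handle but an at::cuda::CUDAStream wrapper; the assignment compiles because the wrapper converts implicitly to cudaStream_t. A sketch of the two equivalent spellings, to my knowledge of the c10/ATen API:

// Implicit conversion, as used in this commit:
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Explicit equivalent via the wrapper's accessor:
at::cuda::CUDAStream cur = at::cuda::getCurrentCUDAStream();
cudaStream_t raw = cur.stream();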

test/test_extension.py
Lines changed: 50 additions & 0 deletions

@@ -6,6 +6,7 @@
 from torch import Tensor
 from typing import Tuple
 import torch.nn.functional as F
+import torch.nn as nn
 
 
 def reference_muladd(a, b, c):
@@ -119,5 +120,54 @@ def test_opcheck_cuda(self):
         self._opcheck("cuda")
 
 
+class TestTorchCompileStreamSync(TestCase):
+    """Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile"""
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_compile_with_linear_layer(self):
+        """Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)"""
+
+        class Model(nn.Module):
+            def __init__(self, size):
+                super().__init__()
+                self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32)
+
+            def forward(self, x):
+                return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0)
+
+        # Test sizes that previously failed
+        for size in [1000, 5000, 10000]:
+            with self.subTest(size=size):
+                torch.manual_seed(42)
+                model = Model(size)
+                x = torch.randn((1, size), device="cuda", dtype=torch.float32)
+
+                with torch.no_grad():
+                    expected = model(x)
+                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+                    actual = compiled_model(x)
+
+                self.assertEqual(actual, expected)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_compile_custom_only(self):
+        """Test custom operations alone with torch.compile"""
+
+        def model(x):
+            return extension_cpp.ops.mymuladd(x, x, 1.0)
+
+        for size in [1000, 5000, 10000]:
+            with self.subTest(size=size):
+                torch.manual_seed(42)
+                x = torch.randn((size,), device="cuda", dtype=torch.float32)
+
+                with torch.no_grad():
+                    expected = model(x)
+                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+                    actual = compiled_model(x)
+
+                self.assertEqual(actual, expected)
+
+
 if __name__ == "__main__":
     unittest.main()
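Since the file keeps unittest.main() as its entry point, the new tests can be run directly with python test/test_extension.py (or through pytest); the unittest.skipIf guards make both tests no-ops on machines without a CUDA device, and each size is reported separately via self.subTest.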
