Skip to content

uses current CUDAStream correctly #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions extension_cpp/csrc/cuda/muladd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>

namespace extension_cpp {

Expand All @@ -26,7 +27,8 @@ at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
float* result_ptr = result.data_ptr<float>();

int numel = a_contig.numel();
muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}

Expand All @@ -48,7 +50,8 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
int numel = a_contig.numel();
mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
return result;
}

Expand All @@ -73,7 +76,8 @@ void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = out.data_ptr<float>();
int numel = a_contig.numel();
add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
}


Expand Down
50 changes: 50 additions & 0 deletions test/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor
from typing import Tuple
import torch.nn.functional as F
import torch.nn as nn


def reference_muladd(a, b, c):
Expand Down Expand Up @@ -119,5 +120,54 @@ def test_opcheck_cuda(self):
self._opcheck("cuda")


class TestTorchCompileStreamSync(TestCase):
"""Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile"""

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_compile_with_linear_layer(self):
"""Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)"""

class Model(nn.Module):
def __init__(self, size):
super().__init__()
self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32)

def forward(self, x):
return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0)

# Test sizes that previously failed
for size in [1000, 5000, 10000]:
with self.subTest(size=size):
torch.manual_seed(42)
model = Model(size)
x = torch.randn((1, size), device="cuda", dtype=torch.float32)

with torch.no_grad():
expected = model(x)
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
actual = compiled_model(x)

self.assertEqual(actual, expected)

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_compile_custom_only(self):
"""Test custom operations alone with torch.compile"""

def model(x):
return extension_cpp.ops.mymuladd(x, x, 1.0)

for size in [1000, 5000, 10000]:
with self.subTest(size=size):
torch.manual_seed(42)
x = torch.randn((size,), device="cuda", dtype=torch.float32)

with torch.no_grad():
expected = model(x)
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
actual = compiled_model(x)

self.assertEqual(actual, expected)


if __name__ == "__main__":
unittest.main()
Loading