@@ -6,6 +6,7 @@
 from torch import Tensor
 from typing import Tuple
 import torch.nn.functional as F
+import torch.nn as nn


 def reference_muladd(a, b, c):
@@ -119,5 +120,56 @@ def test_opcheck_cuda(self):
         self._opcheck("cuda")


+class TestTorchCompileStreamSync(TestCase):
+    """Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile"""
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_compile_with_linear_layer(self):
+        """Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)"""
+
+        class Model(nn.Module):
+            def __init__(self, size):
+                super().__init__()
+                self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32)
+
+            def forward(self, x):
+                return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0)
+
+        # Test sizes that previously failed
+        for size in [1000, 5000, 10000]:
+            with self.subTest(size=size):
+                torch.manual_seed(42)
+                model = Model(size)
+                x = torch.randn((1, size), device="cuda", dtype=torch.float32)
+
+                with torch.no_grad():
+                    expected = model(x)
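+                    # mode="reduce-overhead" enables CUDA graphs, which capture and
+                    # replay kernels on a non-default stream (the sync path under test).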
+                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+                    actual = compiled_model(x)
+
+                self.assertEqual(actual, expected)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_compile_custom_only(self):
+        """Test custom operations alone with torch.compile"""
+
+        def model(x):
+            return extension_cpp.ops.mymuladd(x, x, 1.0)
+
+        for size in [1000, 5000, 10000]:
+            with self.subTest(size=size):
+                torch.manual_seed(42)
+                x = torch.randn((size,), device="cuda", dtype=torch.float32)
+
+                with torch.no_grad():
+                    expected = model(x)
+                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
+                    actual = compiled_model(x)
+
+                self.assertEqual(actual, expected)
+
+
 if __name__ == "__main__":
     unittest.main()