Uses current CUDAStream correctly #118
Merged (+57 −3)

Conversation
This commit fixes GitHub issue pytorch/pytorch#157363, where custom CUDA kernels were not properly synchronized with PyTorch's CUDA stream when used with torch.compile in reduce-overhead mode.

Changes:
- Add `#include <ATen/cuda/CUDAContext.h>` for `getCurrentCUDAStream()`
- Use `at::cuda::getCurrentCUDAStream()` to get PyTorch's current CUDA stream
- Launch all kernels with the correct stream parameter

The issue occurred because the custom kernels were launched on the default CUDA stream while PyTorch operations (like nn.Linear) run on PyTorch's managed stream. This created race conditions where custom kernels would execute before PyTorch operations completed, resulting in incorrect output values.

With this fix, all custom kernels are properly synchronized with PyTorch's CUDA stream, ensuring correct execution order and preventing race conditions when used with torch.compile.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
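For illustration, here is a minimal sketch of the launch pattern this commit describes: fetch PyTorch's current stream via `at::cuda::getCurrentCUDAStream()` and pass it as the fourth launch-configuration argument. The kernel and function names below are hypothetical, not the extension's actual code.

```cuda
// Hypothetical extension kernel illustrating the fix: launch on PyTorch's
// current stream rather than the default stream.
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream()
#include <torch/extension.h>

__global__ void mul_add_kernel(const float* a, const float* b, float* out,
                               int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) out[i] = a[i] * b[i] + 1.0f;
}

// Assumes contiguous float32 CUDA tensors of equal shape.
torch::Tensor mul_add(torch::Tensor a, torch::Tensor b) {
  auto out = torch::empty_like(a);
  int64_t n = a.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  // PyTorch ops (e.g. nn.Linear) enqueue work on this stream; launching the
  // custom kernel on the same stream orders it after that work completes.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  mul_add_kernel<<<blocks, threads, 0, stream>>>(
      a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), n);
  return out;
}
```

Without the explicit stream argument, the kernel is enqueued on the default stream, which is not the stream torch.compile's reduce-overhead mode runs PyTorch's own ops on — hence the race the commit describes.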
Added comprehensive tests to verify the fix for GitHub issue pytorch/pytorch#157363:

1. test_compile_with_linear_layer:
   - Tests custom CUDA kernels with nn.Linear + torch.compile
   - Verifies correct behavior with various input sizes (1000, 5000, 10000)
   - Uses reduce-overhead mode to reproduce the original issue conditions

2. test_compile_custom_only:
   - Tests custom operations without linear layers
   - Ensures custom operations work correctly with torch.compile

These tests ensure that custom CUDA kernels properly synchronize with PyTorch's CUDA stream when used with torch.compile, preventing race conditions that previously caused incorrect outputs.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
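As a sketch of what a test along these lines might look like — `my_custom_op` is a placeholder, not the extension's real op, and PyTorch's internal TestCase is assumed for its tensor-aware assertEqual:

```python
# Sketch of a regression test in the spirit described above.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

def my_custom_op(x):
    # Placeholder; the real test would call the extension's CUDA kernel.
    return x * x + 1

class TestCompileStreamSync(TestCase):
    def test_compile_with_linear_layer(self):
        if not torch.cuda.is_available():
            self.skipTest("requires CUDA")
        for size in (1000, 5000, 10000):
            linear = torch.nn.Linear(size, size, device="cuda")

            def fn(x):
                # The custom kernel must run after the linear's kernels,
                # which execute on PyTorch's current stream.
                return my_custom_op(linear(x))

            x = torch.randn(8, size, device="cuda")
            expected = fn(x)
            compiled = torch.compile(fn, mode="reduce-overhead")
            self.assertEqual(compiled(x), expected)

if __name__ == "__main__":
    run_tests()
```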
jansel approved these changes on Jul 8, 2025
zou3519 approved these changes on Jul 8, 2025
zou3519: "i see you have embraced claude code"
zou3519 reviewed on Jul 8, 2025
Replace manual tolerance specification with self.assertEqual, which automatically handles appropriate tolerances for tensor comparisons. This makes the tests more concise and follows PyTorch testing conventions.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
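Illustratively, the change is of the following shape; the "before" line is an assumed example of a manual-tolerance comparison, not the test's verbatim code:

```python
# Before: manually chosen tolerances for the tensor comparison (assumed)
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)

# After: TestCase.assertEqual picks dtype-appropriate tolerances itself
self.assertEqual(actual, expected)
```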
Fixes pytorch/pytorch#157363
Thanks to @vlejd for finding, debugging, and reporting the issue.