[Bugfix] Fix CUDA arch flags for MoE permute #21426
Merged
tests/kernels/test_shuffle_rows.py (new file, +294 lines)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the shuffle_rows function

Run `pytest tests/kernels/test_shuffle_rows.py`.
"""

import pytest
import torch

from vllm._custom_ops import shuffle_rows
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_shuffle_rows_basic(num_tokens: int, hidden_size: int,
                            dtype: torch.dtype):
    """Test basic functionality of shuffle_rows with various tensor sizes and
    dtypes."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a simple permutation map (identity mapping)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # With identity mapping, output should be identical to input
    torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)

    # Check output shape
    assert output.shape == (num_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device


@pytest.mark.parametrize("num_tokens", [16, 64, 128])
@pytest.mark.parametrize("hidden_size", [128, 512, 1024])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_shuffle_rows_permutation(num_tokens: int, hidden_size: int,
                                  dtype: torch.dtype):
    """Test shuffle_rows with actual permutation."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a reverse permutation map
    dst2src_map = torch.arange(num_tokens - 1,
                               -1,
                               -1,
                               device="cuda",
                               dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check that the output is the reverse of the input
    expected_output = torch.flip(input_tensor, dims=[0])
    torch.testing.assert_close(output, expected_output, atol=1e-6, rtol=1e-5)

    # Check output shape and properties
    assert output.shape == (num_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device


@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [256, 512])
def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int):
    """Test shuffle_rows with expansion (more output tokens than input
    tokens)."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a mapping that duplicates some tokens (expansion)
    expanded_size = num_tokens * 2
    dst2src_map = torch.randint(0,
                                num_tokens, (expanded_size, ),
                                device="cuda",
                                dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check output shape
    assert output.shape == (expanded_size, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device

    # Verify that each output row matches the corresponding input row
    for i in range(expanded_size):
        src_idx = dst2src_map[i].item()
        torch.testing.assert_close(output[i],
                                   input_tensor[src_idx],
                                   atol=1e-6,
                                   rtol=1e-5)


@pytest.mark.parametrize("num_tokens", [16, 64])
@pytest.mark.parametrize("hidden_size", [128, 512])
def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int):
    """Test shuffle_rows with random permutation."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16

    # Set seed for reproducibility
    torch.manual_seed(42)

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a random permutation map
    dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check output shape and properties
    assert output.shape == (num_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device

    # Verify that each output row matches the corresponding input row
    for i in range(num_tokens):
        src_idx = dst2src_map[i].item()
        torch.testing.assert_close(output[i],
                                   input_tensor[src_idx],
                                   atol=1e-6,
                                   rtol=1e-5)


def test_shuffle_rows_edge_cases():
    """Test shuffle_rows with edge cases."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16

    # Test with single token
    input_tensor = torch.randn(1, 128, device="cuda", dtype=dtype)
    dst2src_map = torch.tensor([0], device="cuda", dtype=torch.int32)
    output = shuffle_rows(input_tensor, dst2src_map)
    torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)

    # Test with single feature dimension
    input_tensor = torch.randn(16, 1, device="cuda", dtype=dtype)
    dst2src_map = torch.arange(16, device="cuda", dtype=torch.int32)
    output = shuffle_rows(input_tensor, dst2src_map)
    torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)


def test_shuffle_rows_moe_like_scenario():
    """Test shuffle_rows in a scenario similar to MoE usage."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16
    batch_size = 32
    hidden_size = 1024
    topk = 2

    # Simulate input tokens
    input_tensor = torch.randn(batch_size,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Simulate expert assignment (each token goes to topk experts)
    # This creates a mapping where tokens are duplicated for multiple experts
    total_tokens = batch_size * topk
    dst2src_map = torch.zeros(total_tokens, device="cuda", dtype=torch.int32)

    # Fill the mapping to simulate MoE token distribution
    for i in range(batch_size):
        for k in range(topk):
            dst2src_map[i * topk + k] = i

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check output shape
    assert output.shape == (total_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device

    # Verify that tokens are correctly duplicated
    for i in range(batch_size):
        for k in range(topk):
            output_idx = i * topk + k
            torch.testing.assert_close(output[output_idx],
                                       input_tensor[i],
                                       atol=1e-6,
                                       rtol=1e-5)


@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_shuffle_rows_dtype_consistency(dtype: torch.dtype):
    """Test that shuffle_rows preserves dtype correctly."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    num_tokens = 64
    hidden_size = 512

    # Create input tensor with specific dtype
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Verify dtype is preserved
    assert output.dtype == dtype
    assert output.device == input_tensor.device
    torch.testing.assert_close(output, input_tensor, atol=1e-6, rtol=1e-5)


def test_shuffle_rows_device_consistency():
    """Test that shuffle_rows maintains device consistency."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    num_tokens = 32
    hidden_size = 256
    dtype = torch.float16

    # Create input tensor on CUDA
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Verify device is maintained
    assert output.device == input_tensor.device
    assert output.device.type == "cuda"


def test_shuffle_rows_contiguous_output():
    """Test that shuffle_rows produces contiguous output."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    num_tokens = 64
    hidden_size = 512
    dtype = torch.float16

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Verify output is contiguous
    assert output.is_contiguous()
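
In case the intended semantics aren't obvious from the parametrized cases: what these tests pin down is a row gather, where output row i is input row dst2src_map[i], and the map may reorder or duplicate rows. A minimal pure-PyTorch reference of those semantics (a sketch for illustration only, not vLLM's implementation; the function name is made up):

import torch


def shuffle_rows_reference(input_tensor: torch.Tensor,
                           dst2src_map: torch.Tensor) -> torch.Tensor:
    """Reference semantics: out[i] = input_tensor[dst2src_map[i]]."""
    # index_select gathers whole rows and returns a new contiguous tensor
    # with the input's dtype and device -- the properties the tests assert.
    return input_tensor.index_select(0, dst2src_map.long())

Against this reference, the identity, reversal, and top-k duplication scenarios above are all just different choices of dst2src_map.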
The CUDA arch flags are now correctly set to `${CUDA_ARCHS}`. This ensures that the MoE permute kernel is compiled with the appropriate architecture flags, resolving the original error. This change is critical to ensure the kernel runs on the intended hardware.
I guess someone just copied it over, lol
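
For context, a sketch of what such a fix typically looks like in vLLM's CMake, assuming the `set_gencode_flags_for_srcs` helper from `cmake/utils.cmake`; the source path and variable name here are illustrative, and the actual diff may differ:

# Illustrative source list for the MoE permute kernel (hypothetical name).
set(MOE_PERMUTE_SRC "csrc/moe/moe_permute_unpermute_op.cu")

# Compile the permute sources for the architectures this build was
# configured with, rather than an arch list copied from another kernel.
set_gencode_flags_for_srcs(
  SRCS "${MOE_PERMUTE_SRC}"
  CUDA_ARCHS "${CUDA_ARCHS}")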