|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +"""Tests for the shuffle_rows function |
| 4 | +
|
| 5 | +Run `pytest tests/kernels/test_shuffle_rows.py`. |
| 6 | +""" |
| 7 | + |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | + |
| 11 | +from vllm._custom_ops import shuffle_rows |
| 12 | +from vllm.platforms import current_platform |
| 13 | + |
| 14 | + |
| 15 | +@pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024]) |
| 16 | +@pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096]) |
| 17 | +@pytest.mark.parametrize("dtype", |
| 18 | + [torch.float16, torch.bfloat16, torch.float32]) |
| 19 | +def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, |
| 20 | + dtype: torch.dtype): |
| 21 | + """Test basic functionality of shuffle_rows with various tensor sizes and |
| 22 | + dtypes.""" |
| 23 | + if not current_platform.is_cuda(): |
| 24 | + pytest.skip("shuffle_rows requires CUDA") |
| 25 | + |
| 26 | + # Create input tensor |
| 27 | + input_tensor = torch.randn(num_tokens, |
| 28 | + hidden_size, |
| 29 | + device="cuda", |
| 30 | + dtype=dtype) |
| 31 | + |
| 32 | + # Create a simple permutation map (identity mapping) |
| 33 | + dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) |
| 34 | + |
| 35 | + # Test shuffle_rows |
| 36 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 37 | + |
| 38 | + # With identity mapping, output should be identical to input |
| 39 | + torch.testing.assert_close(output, input_tensor, atol=0, rtol=0) |
| 40 | + |
| 41 | + # Check output shape |
| 42 | + assert output.shape == (num_tokens, hidden_size) |
| 43 | + assert output.dtype == dtype |
| 44 | + assert output.device == input_tensor.device |
| 45 | + |
| 46 | + |
| 47 | +@pytest.mark.parametrize("num_tokens", [16, 64, 128]) |
| 48 | +@pytest.mark.parametrize("hidden_size", [128, 512, 1024]) |
| 49 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 50 | +def test_shuffle_rows_permutation(num_tokens: int, hidden_size: int, |
| 51 | + dtype: torch.dtype): |
| 52 | + """Test shuffle_rows with actual permutation.""" |
| 53 | + if not current_platform.is_cuda(): |
| 54 | + pytest.skip("shuffle_rows requires CUDA") |
| 55 | + |
| 56 | + # Create input tensor |
| 57 | + input_tensor = torch.randn(num_tokens, |
| 58 | + hidden_size, |
| 59 | + device="cuda", |
| 60 | + dtype=dtype) |
| 61 | + |
| 62 | + # Create a reverse permutation map |
| 63 | + dst2src_map = torch.arange(num_tokens - 1, |
| 64 | + -1, |
| 65 | + -1, |
| 66 | + device="cuda", |
| 67 | + dtype=torch.int32) |
| 68 | + |
| 69 | + # Test shuffle_rows |
| 70 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 71 | + |
| 72 | + # Check that the output is the reverse of the input |
| 73 | + expected_output = torch.flip(input_tensor, dims=[0]) |
| 74 | + torch.testing.assert_close(output, expected_output, atol=1e-6, rtol=1e-5) |
| 75 | + |
| 76 | + # Check output shape and properties |
| 77 | + assert output.shape == (num_tokens, hidden_size) |
| 78 | + assert output.dtype == dtype |
| 79 | + assert output.device == input_tensor.device |
| 80 | + |
| 81 | + |
| 82 | +@pytest.mark.parametrize("num_tokens", [32, 64]) |
| 83 | +@pytest.mark.parametrize("hidden_size", [256, 512]) |
| 84 | +def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int): |
| 85 | + """Test shuffle_rows with expansion (more output tokens than input |
| 86 | + tokens).""" |
| 87 | + if not current_platform.is_cuda(): |
| 88 | + pytest.skip("shuffle_rows requires CUDA") |
| 89 | + |
| 90 | + dtype = torch.float16 |
| 91 | + |
| 92 | + # Create input tensor |
| 93 | + input_tensor = torch.randn(num_tokens, |
| 94 | + hidden_size, |
| 95 | + device="cuda", |
| 96 | + dtype=dtype) |
| 97 | + |
| 98 | + # Create a mapping that duplicates some tokens (expansion) |
| 99 | + expanded_size = num_tokens * 2 |
| 100 | + dst2src_map = torch.randint(0, |
| 101 | + num_tokens, (expanded_size, ), |
| 102 | + device="cuda", |
| 103 | + dtype=torch.int32) |
| 104 | + |
| 105 | + # Test shuffle_rows |
| 106 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 107 | + |
| 108 | + # Check output shape |
| 109 | + assert output.shape == (expanded_size, hidden_size) |
| 110 | + assert output.dtype == dtype |
| 111 | + assert output.device == input_tensor.device |
| 112 | + |
| 113 | + # Verify that each output row matches the corresponding input row |
| 114 | + for i in range(expanded_size): |
| 115 | + src_idx = dst2src_map[i].item() |
| 116 | + torch.testing.assert_close(output[i], |
| 117 | + input_tensor[src_idx], |
| 118 | + atol=1e-6, |
| 119 | + rtol=1e-5) |
| 120 | + |
| 121 | + |
| 122 | +@pytest.mark.parametrize("num_tokens", [16, 64]) |
| 123 | +@pytest.mark.parametrize("hidden_size", [128, 512]) |
| 124 | +def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int): |
| 125 | + """Test shuffle_rows with random permutation.""" |
| 126 | + if not current_platform.is_cuda(): |
| 127 | + pytest.skip("shuffle_rows requires CUDA") |
| 128 | + |
| 129 | + dtype = torch.float16 |
| 130 | + |
| 131 | + # Set seed for reproducibility |
| 132 | + torch.manual_seed(42) |
| 133 | + |
| 134 | + # Create input tensor |
| 135 | + input_tensor = torch.randn(num_tokens, |
| 136 | + hidden_size, |
| 137 | + device="cuda", |
| 138 | + dtype=dtype) |
| 139 | + |
| 140 | + # Create a random permutation map |
| 141 | + dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32) |
| 142 | + |
| 143 | + # Test shuffle_rows |
| 144 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 145 | + |
| 146 | + # Check output shape and properties |
| 147 | + assert output.shape == (num_tokens, hidden_size) |
| 148 | + assert output.dtype == dtype |
| 149 | + assert output.device == input_tensor.device |
| 150 | + |
| 151 | + # Verify that each output row matches the corresponding input row |
| 152 | + for i in range(num_tokens): |
| 153 | + src_idx = dst2src_map[i].item() |
| 154 | + torch.testing.assert_close(output[i], |
| 155 | + input_tensor[src_idx], |
| 156 | + atol=1e-6, |
| 157 | + rtol=1e-5) |
| 158 | + |
| 159 | + |
| 160 | +def test_shuffle_rows_edge_cases(): |
| 161 | + """Test shuffle_rows with edge cases.""" |
| 162 | + if not current_platform.is_cuda(): |
| 163 | + pytest.skip("shuffle_rows requires CUDA") |
| 164 | + |
| 165 | + dtype = torch.float16 |
| 166 | + |
| 167 | + # Test with single token |
| 168 | + input_tensor = torch.randn(1, 128, device="cuda", dtype=dtype) |
| 169 | + dst2src_map = torch.tensor([0], device="cuda", dtype=torch.int32) |
| 170 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 171 | + torch.testing.assert_close(output, input_tensor, atol=0, rtol=0) |
| 172 | + |
| 173 | + # Test with single feature dimension |
| 174 | + input_tensor = torch.randn(16, 1, device="cuda", dtype=dtype) |
| 175 | + dst2src_map = torch.arange(16, device="cuda", dtype=torch.int32) |
| 176 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 177 | + torch.testing.assert_close(output, input_tensor, atol=0, rtol=0) |
| 178 | + |
| 179 | + |
| 180 | +def test_shuffle_rows_moe_like_scenario(): |
| 181 | + """Test shuffle_rows in a scenario similar to MoE usage.""" |
| 182 | + if not current_platform.is_cuda(): |
| 183 | + pytest.skip("shuffle_rows requires CUDA") |
| 184 | + |
| 185 | + dtype = torch.float16 |
| 186 | + batch_size = 32 |
| 187 | + hidden_size = 1024 |
| 188 | + topk = 2 |
| 189 | + |
| 190 | + # Simulate input tokens |
| 191 | + input_tensor = torch.randn(batch_size, |
| 192 | + hidden_size, |
| 193 | + device="cuda", |
| 194 | + dtype=dtype) |
| 195 | + |
| 196 | + # Simulate expert assignment (each token goes to topk experts) |
| 197 | + # This creates a mapping where tokens are duplicated for multiple experts |
| 198 | + total_tokens = batch_size * topk |
| 199 | + dst2src_map = torch.zeros(total_tokens, device="cuda", dtype=torch.int32) |
| 200 | + |
| 201 | + # Fill the mapping to simulate MoE token distribution |
| 202 | + for i in range(batch_size): |
| 203 | + for k in range(topk): |
| 204 | + dst2src_map[i * topk + k] = i |
| 205 | + |
| 206 | + # Test shuffle_rows |
| 207 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 208 | + |
| 209 | + # Check output shape |
| 210 | + assert output.shape == (total_tokens, hidden_size) |
| 211 | + assert output.dtype == dtype |
| 212 | + assert output.device == input_tensor.device |
| 213 | + |
| 214 | + # Verify that tokens are correctly duplicated |
| 215 | + for i in range(batch_size): |
| 216 | + for k in range(topk): |
| 217 | + output_idx = i * topk + k |
| 218 | + torch.testing.assert_close(output[output_idx], |
| 219 | + input_tensor[i], |
| 220 | + atol=1e-6, |
| 221 | + rtol=1e-5) |
| 222 | + |
| 223 | + |
| 224 | +@pytest.mark.parametrize("dtype", |
| 225 | + [torch.float16, torch.bfloat16, torch.float32]) |
| 226 | +def test_shuffle_rows_dtype_consistency(dtype: torch.dtype): |
| 227 | + """Test that shuffle_rows preserves dtype correctly.""" |
| 228 | + if not current_platform.is_cuda(): |
| 229 | + pytest.skip("shuffle_rows requires CUDA") |
| 230 | + |
| 231 | + num_tokens = 64 |
| 232 | + hidden_size = 512 |
| 233 | + |
| 234 | + # Create input tensor with specific dtype |
| 235 | + input_tensor = torch.randn(num_tokens, |
| 236 | + hidden_size, |
| 237 | + device="cuda", |
| 238 | + dtype=dtype) |
| 239 | + dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) |
| 240 | + |
| 241 | + # Test shuffle_rows |
| 242 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 243 | + |
| 244 | + # Verify dtype is preserved |
| 245 | + assert output.dtype == dtype |
| 246 | + assert output.device == input_tensor.device |
| 247 | + torch.testing.assert_close(output, input_tensor, atol=1e-6, rtol=1e-5) |
| 248 | + |
| 249 | + |
| 250 | +def test_shuffle_rows_device_consistency(): |
| 251 | + """Test that shuffle_rows maintains device consistency.""" |
| 252 | + if not current_platform.is_cuda(): |
| 253 | + pytest.skip("shuffle_rows requires CUDA") |
| 254 | + |
| 255 | + num_tokens = 32 |
| 256 | + hidden_size = 256 |
| 257 | + dtype = torch.float16 |
| 258 | + |
| 259 | + # Create input tensor on CUDA |
| 260 | + input_tensor = torch.randn(num_tokens, |
| 261 | + hidden_size, |
| 262 | + device="cuda", |
| 263 | + dtype=dtype) |
| 264 | + dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) |
| 265 | + |
| 266 | + # Test shuffle_rows |
| 267 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 268 | + |
| 269 | + # Verify device is maintained |
| 270 | + assert output.device == input_tensor.device |
| 271 | + assert output.device.type == "cuda" |
| 272 | + |
| 273 | + |
| 274 | +def test_shuffle_rows_contiguous_output(): |
| 275 | + """Test that shuffle_rows produces contiguous output.""" |
| 276 | + if not current_platform.is_cuda(): |
| 277 | + pytest.skip("shuffle_rows requires CUDA") |
| 278 | + |
| 279 | + num_tokens = 64 |
| 280 | + hidden_size = 512 |
| 281 | + dtype = torch.float16 |
| 282 | + |
| 283 | + # Create input tensor |
| 284 | + input_tensor = torch.randn(num_tokens, |
| 285 | + hidden_size, |
| 286 | + device="cuda", |
| 287 | + dtype=dtype) |
| 288 | + dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) |
| 289 | + |
| 290 | + # Test shuffle_rows |
| 291 | + output = shuffle_rows(input_tensor, dst2src_map) |
| 292 | + |
| 293 | + # Verify output is contiguous |
| 294 | + assert output.is_contiguous() |
0 commit comments