
Commit 2ded067

[Bugfix] Fix CUDA arch flags for MoE permute (#21426)
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent 13abd0e commit 2ded067

2 files changed: 297 additions, 3 deletions

CMakeLists.txt

Lines changed: 3 additions & 3 deletions

@@ -635,7 +635,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
         "in CUDA target architectures.")
     endif()
   endif()
-
+
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")

@@ -842,8 +842,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/moe/moe_permute_unpermute_op.cu")

     set_gencode_flags_for_srcs(
-      SRCS "${MARLIN_PERMUTE_SRC}"
-      CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
+      SRCS "${MOE_PERMUTE_SRC}"
+      CUDA_ARCHS "${CUDA_ARCHS}")

     list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
   endif()
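The fix is a two-token change: the gencode flags for the MoE permute sources were previously attached to `${MARLIN_PERMUTE_SRC}` instead of `${MOE_PERMUTE_SRC}`, and used `${MOE_PERMUTE_ARCHS}` instead of `${CUDA_ARCHS}`, so `moe_permute_unpermute_op.cu` was not built with the intended CUDA arch flags. Once rebuilt with the corrected flags, the op can be smoke-tested from Python; a minimal sketch, assuming a CUDA build of vLLM (it condenses the checks in the new test file below):

import torch

from vllm._custom_ops import shuffle_rows

# Output row i is gathered from input row dst2src_map[i].
x = torch.randn(8, 128, device="cuda", dtype=torch.float16)
perm = torch.randperm(8, device="cuda", dtype=torch.int32)
out = shuffle_rows(x, perm)
torch.testing.assert_close(out, x[perm.long()])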

tests/kernels/test_shuffle_rows.py

Lines changed: 294 additions & 0 deletions (new file)

@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the shuffle_rows function

Run `pytest tests/kernels/test_shuffle_rows.py`.
"""

import pytest
import torch

from vllm._custom_ops import shuffle_rows
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_shuffle_rows_basic(num_tokens: int, hidden_size: int,
                            dtype: torch.dtype):
    """Test basic functionality of shuffle_rows with various tensor sizes and
    dtypes."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a simple permutation map (identity mapping)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # With identity mapping, output should be identical to input
    torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)

    # Check output shape
    assert output.shape == (num_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device


@pytest.mark.parametrize("num_tokens", [16, 64, 128])
@pytest.mark.parametrize("hidden_size", [128, 512, 1024])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_shuffle_rows_permutation(num_tokens: int, hidden_size: int,
                                  dtype: torch.dtype):
    """Test shuffle_rows with actual permutation."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a reverse permutation map
    dst2src_map = torch.arange(num_tokens - 1,
                               -1,
                               -1,
                               device="cuda",
                               dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check that the output is the reverse of the input
    expected_output = torch.flip(input_tensor, dims=[0])
    torch.testing.assert_close(output, expected_output, atol=1e-6, rtol=1e-5)

    # Check output shape and properties
    assert output.shape == (num_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device


@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [256, 512])
def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int):
    """Test shuffle_rows with expansion (more output tokens than input
    tokens)."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a mapping that duplicates some tokens (expansion)
    expanded_size = num_tokens * 2
    dst2src_map = torch.randint(0,
                                num_tokens, (expanded_size, ),
                                device="cuda",
                                dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check output shape
    assert output.shape == (expanded_size, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device

    # Verify that each output row matches the corresponding input row
    for i in range(expanded_size):
        src_idx = dst2src_map[i].item()
        torch.testing.assert_close(output[i],
                                   input_tensor[src_idx],
                                   atol=1e-6,
                                   rtol=1e-5)


@pytest.mark.parametrize("num_tokens", [16, 64])
@pytest.mark.parametrize("hidden_size", [128, 512])
def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int):
    """Test shuffle_rows with random permutation."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16

    # Set seed for reproducibility
    torch.manual_seed(42)

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Create a random permutation map
    dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check output shape and properties
    assert output.shape == (num_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device

    # Verify that each output row matches the corresponding input row
    for i in range(num_tokens):
        src_idx = dst2src_map[i].item()
        torch.testing.assert_close(output[i],
                                   input_tensor[src_idx],
                                   atol=1e-6,
                                   rtol=1e-5)


def test_shuffle_rows_edge_cases():
    """Test shuffle_rows with edge cases."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16

    # Test with single token
    input_tensor = torch.randn(1, 128, device="cuda", dtype=dtype)
    dst2src_map = torch.tensor([0], device="cuda", dtype=torch.int32)
    output = shuffle_rows(input_tensor, dst2src_map)
    torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)

    # Test with single feature dimension
    input_tensor = torch.randn(16, 1, device="cuda", dtype=dtype)
    dst2src_map = torch.arange(16, device="cuda", dtype=torch.int32)
    output = shuffle_rows(input_tensor, dst2src_map)
    torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)


def test_shuffle_rows_moe_like_scenario():
    """Test shuffle_rows in a scenario similar to MoE usage."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    dtype = torch.float16
    batch_size = 32
    hidden_size = 1024
    topk = 2

    # Simulate input tokens
    input_tensor = torch.randn(batch_size,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)

    # Simulate expert assignment (each token goes to topk experts)
    # This creates a mapping where tokens are duplicated for multiple experts
    total_tokens = batch_size * topk
    dst2src_map = torch.zeros(total_tokens, device="cuda", dtype=torch.int32)

    # Fill the mapping to simulate MoE token distribution
    for i in range(batch_size):
        for k in range(topk):
            dst2src_map[i * topk + k] = i

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Check output shape
    assert output.shape == (total_tokens, hidden_size)
    assert output.dtype == dtype
    assert output.device == input_tensor.device

    # Verify that tokens are correctly duplicated
    for i in range(batch_size):
        for k in range(topk):
            output_idx = i * topk + k
            torch.testing.assert_close(output[output_idx],
                                       input_tensor[i],
                                       atol=1e-6,
                                       rtol=1e-5)


@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_shuffle_rows_dtype_consistency(dtype: torch.dtype):
    """Test that shuffle_rows preserves dtype correctly."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    num_tokens = 64
    hidden_size = 512

    # Create input tensor with specific dtype
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Verify dtype is preserved
    assert output.dtype == dtype
    assert output.device == input_tensor.device
    torch.testing.assert_close(output, input_tensor, atol=1e-6, rtol=1e-5)


def test_shuffle_rows_device_consistency():
    """Test that shuffle_rows maintains device consistency."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    num_tokens = 32
    hidden_size = 256
    dtype = torch.float16

    # Create input tensor on CUDA
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Verify device is maintained
    assert output.device == input_tensor.device
    assert output.device.type == "cuda"


def test_shuffle_rows_contiguous_output():
    """Test that shuffle_rows produces contiguous output."""
    if not current_platform.is_cuda():
        pytest.skip("shuffle_rows requires CUDA")

    num_tokens = 64
    hidden_size = 512
    dtype = torch.float16

    # Create input tensor
    input_tensor = torch.randn(num_tokens,
                               hidden_size,
                               device="cuda",
                               dtype=dtype)
    dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)

    # Test shuffle_rows
    output = shuffle_rows(input_tensor, dst2src_map)

    # Verify output is contiguous
    assert output.is_contiguous()
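Taken together, the tests pin down the semantics of `shuffle_rows` as a row gather that permits duplicate source indices. A pure-PyTorch reference of that behavior, as a sketch only (the helper name `shuffle_rows_reference` is illustrative, not part of vLLM):

import torch


def shuffle_rows_reference(input_tensor: torch.Tensor,
                           dst2src_map: torch.Tensor) -> torch.Tensor:
    """Reference semantics: output[i] == input_tensor[dst2src_map[i]].

    dst2src_map may repeat source indices, so the output can have more
    rows than the input (the MoE expansion case exercised above).
    """
    return input_tensor.index_select(0, dst2src_map.long()).contiguous()


# Example: duplicating each of 4 rows twice (a topk = 2 expansion).
x = torch.randn(4, 8)
dst2src_map = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int32)
expanded = shuffle_rows_reference(x, dst2src_map)
assert expanded.shape == (8, 8)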
